Skip to content

Commit 2601749

Browse files
committed
[Backend Tester] Add SNR validation (pytorch#12924)
Add SNR validation for model outputs. This is generally more robust than tight element-wise tolerances, as it takes the entire tensor into account. I'm relaxing atol to handle outliers and reduce noise.
1 parent dd62a07 commit 2601749

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

backends/test/harness/tester.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import random
23
from collections import Counter, OrderedDict
34
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -18,6 +19,7 @@
1819
ToExecutorch,
1920
)
2021
from executorch.exir.dim_order_utils import get_memory_format
22+
from torch.ao.ns.fx.utils import compute_sqnr
2123

2224
from torch.export import ExportedProgram
2325
from torch.testing import FileCheck
@@ -304,13 +306,14 @@ def run_method_and_compare_outputs(
304306
rtol=1e-03,
305307
qtol=0,
306308
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
309+
snr: float | None = None,
307310
):
308311
number_of_runs = 1 if inputs is not None else num_runs
309312
reference_stage = self.stages[StageType.EXPORT]
310313

311314
stage = stage or self.cur
312315

313-
for _ in range(number_of_runs):
316+
for run_iteration in range(number_of_runs):
314317
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
315318
316319
input_shapes = [
@@ -338,6 +341,7 @@ def run_method_and_compare_outputs(
338341
atol,
339342
rtol,
340343
qtol,
344+
snr,
341345
statistics_callback,
342346
)
343347

@@ -349,6 +353,7 @@ def _assert_outputs_equal(
349353
ref_output,
350354
atol=1e-03,
351355
rtol=1e-03,
356+
snr: float | None = None,
352357
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
353358
):
354359
"""
@@ -380,15 +385,22 @@ def _assert_outputs_equal(
380385
f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n"
381386
)
382387
else:
383-
assert torch.allclose(
384-
model,
385-
ref,
386-
atol=atol,
387-
rtol=rtol,
388-
equal_nan=True,
388+
computed_snr = compute_sqnr(model.to(torch.float), ref.to(torch.float))
389+
snr = snr or float("-inf")
390+
391+
assert (
392+
torch.allclose(
393+
model,
394+
ref,
395+
atol=atol,
396+
rtol=rtol,
397+
equal_nan=True,
398+
)
399+
and computed_snr >= snr
400+
or math.isnan(computed_snr)
389401
), (
390402
f"Output {i} does not match reference output.\n"
391-
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
403+
f"\tGiven atol: {atol}, rtol: {rtol}, snr: {snr}.\n"
392404
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
393405
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref).to(torch.double))}.\n"
394406
f"\t-- Model vs. Reference --\n"
@@ -397,6 +409,7 @@ def _assert_outputs_equal(
397409
f"\t Mean: {model.to(torch.double).mean()}, {ref.to(torch.double).mean()}\n"
398410
f"\t Max: {model.max()}, {ref.max()}\n"
399411
f"\t Min: {model.min()}, {ref.min()}\n"
412+
f"\t SNR: {computed_snr}\n"
400413
)
401414

402415
@staticmethod
@@ -407,6 +420,7 @@ def _compare_outputs(
407420
atol=1e-03,
408421
rtol=1e-03,
409422
qtol=0,
423+
snr: float | None = None,
410424
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
411425
):
412426
"""
@@ -430,6 +444,7 @@ def _compare_outputs(
430444
reference_output,
431445
atol=atol,
432446
rtol=rtol,
447+
snr=snr,
433448
statistics_callback=statistics_callback,
434449
)
435450

backends/test/suite/runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ def build_result(
136136
# AssertionErrors to catch output mismatches, but this might catch more than that.
137137
try:
138138
tester.run_method_and_compare_outputs(
139-
inputs=None if generate_random_test_inputs else inputs
139+
inputs=None if generate_random_test_inputs else inputs,
140+
atol=5e-2,
141+
rtol=5e-2,
142+
snr=40,
140143
statistics_callback=lambda stats: error_statistics.append(stats)
141144
)
142145
except AssertionError as e:

0 commit comments

Comments
 (0)