1
+ import math
1
2
import random
2
3
from collections import Counter , OrderedDict
3
4
from typing import Any , Callable , Dict , List , Optional , Tuple
18
19
ToExecutorch ,
19
20
)
20
21
from executorch .exir .dim_order_utils import get_memory_format
22
+ from torch .ao .ns .fx .utils import compute_sqnr
21
23
22
24
from torch .export import ExportedProgram
23
25
from torch .testing import FileCheck
@@ -304,13 +306,14 @@ def run_method_and_compare_outputs(
304
306
rtol = 1e-03 ,
305
307
qtol = 0 ,
306
308
statistics_callback : Callable [[ErrorStatistics ], None ] | None = None ,
309
+ snr : float | None = None ,
307
310
):
308
311
number_of_runs = 1 if inputs is not None else num_runs
309
312
reference_stage = self .stages [StageType .EXPORT ]
310
313
311
314
stage = stage or self .cur
312
315
313
- for _ in range (number_of_runs ):
316
+ for run_iteration in range (number_of_runs ):
314
317
inputs_to_run = inputs if inputs else next (self .generate_random_inputs ())
315
318
< << << << HEAD
316
319
input_shapes = [
@@ -338,6 +341,7 @@ def run_method_and_compare_outputs(
338
341
atol ,
339
342
rtol ,
340
343
qtol ,
344
+ snr ,
341
345
statistics_callback ,
342
346
)
343
347
@@ -349,6 +353,7 @@ def _assert_outputs_equal(
349
353
ref_output ,
350
354
atol = 1e-03 ,
351
355
rtol = 1e-03 ,
356
+ snr : float | None = None ,
352
357
statistics_callback : Callable [[ErrorStatistics ], None ] | None = None ,
353
358
):
354
359
"""
@@ -380,15 +385,22 @@ def _assert_outputs_equal(
380
385
f"\t Mismatched count: { (model != ref ).sum ().item ()} / { model .numel ()} \n "
381
386
)
382
387
else :
383
- assert torch .allclose (
384
- model ,
385
- ref ,
386
- atol = atol ,
387
- rtol = rtol ,
388
- equal_nan = True ,
388
+ computed_snr = compute_sqnr (model .to (torch .float ), ref .to (torch .float ))
389
+ snr = snr or float ("-inf" )
390
+
391
+ assert (
392
+ torch .allclose (
393
+ model ,
394
+ ref ,
395
+ atol = atol ,
396
+ rtol = rtol ,
397
+ equal_nan = True ,
398
+ )
399
+ and computed_snr >= snr
400
+ or math .isnan (computed_snr )
389
401
), (
390
402
f"Output { i } does not match reference output.\n "
391
- f"\t Given atol: { atol } , rtol: { rtol } .\n "
403
+ f"\t Given atol: { atol } , rtol: { rtol } , snr: { snr } .\n "
392
404
f"\t Output tensor shape: { model .shape } , dtype: { model .dtype } \n "
393
405
f"\t Difference: max: { torch .max (model - ref )} , abs: { torch .max (torch .abs (model - ref ))} , mean abs error: { torch .mean (torch .abs (model - ref ).to (torch .double ))} .\n "
394
406
f"\t -- Model vs. Reference --\n "
@@ -397,6 +409,7 @@ def _assert_outputs_equal(
397
409
f"\t Mean: { model .to (torch .double ).mean ()} , { ref .to (torch .double ).mean ()} \n "
398
410
f"\t Max: { model .max ()} , { ref .max ()} \n "
399
411
f"\t Min: { model .min ()} , { ref .min ()} \n "
412
+ f"\t SNR: { computed_snr } \n "
400
413
)
401
414
402
415
@staticmethod
@@ -407,6 +420,7 @@ def _compare_outputs(
407
420
atol = 1e-03 ,
408
421
rtol = 1e-03 ,
409
422
qtol = 0 ,
423
+ snr : float | None = None ,
410
424
statistics_callback : Callable [[ErrorStatistics ], None ] | None = None ,
411
425
):
412
426
"""
@@ -430,6 +444,7 @@ def _compare_outputs(
430
444
reference_output ,
431
445
atol = atol ,
432
446
rtol = rtol ,
447
+ snr = snr ,
433
448
statistics_callback = statistics_callback ,
434
449
)
435
450
0 commit comments