Skip to content

Commit 4875007

Browse files
authored
Extract JIT tests from filter test module and put in JIT test module. (#507)
1 parent 2126924 commit 4875007

File tree

2 files changed

+152
-25
lines changed

2 files changed

+152
-25
lines changed

test/test_functional_filtering.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,6 @@
99
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir
1010

1111

12-
def _test_torchscript_functional(py_method, *args, **kwargs):
13-
jit_method = torch.jit.script(py_method)
14-
15-
jit_out = jit_method(*args, **kwargs)
16-
py_out = py_method(*args, **kwargs)
17-
18-
assert torch.allclose(jit_out, py_out)
19-
20-
2112
class TestFunctionalFiltering(unittest.TestCase):
2213
test_dirpath, test_dir = create_temp_assets_dir()
2314

@@ -88,7 +79,6 @@ def _test_lfilter(self, waveform, device):
8879
assert len(output_waveform.size()) == 2
8980
assert output_waveform.size(0) == waveform.size(0)
9081
assert output_waveform.size(1) == waveform.size(1)
91-
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
9282

9383
def test_lfilter(self):
9484

@@ -189,7 +179,6 @@ def test_lowpass(self):
189179
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
190180

191181
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
192-
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
193182

194183
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
195184
@AudioBackendScope("sox")
@@ -211,7 +200,6 @@ def test_highpass(self):
211200

212201
# TBD - this fails at the 1e-4 level, debug why
213202
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
214-
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ)
215203

216204
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
217205
@AudioBackendScope("sox")
@@ -233,7 +221,6 @@ def test_allpass(self):
233221
output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
234222

235223
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
236-
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)
237224

238225
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
239226
@AudioBackendScope("sox")
@@ -256,7 +243,6 @@ def test_bandpass_with_csg(self):
256243
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
257244

258245
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
259-
_test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
260246

261247
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
262248
@AudioBackendScope("sox")
@@ -279,7 +265,6 @@ def test_bandpass_without_csg(self):
279265
output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
280266

281267
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
282-
_test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)
283268

284269
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
285270
@AudioBackendScope("sox")
@@ -301,7 +286,6 @@ def test_bandreject(self):
301286
output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)
302287

303288
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
304-
_test_torchscript_functional(F.bandreject_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)
305289

306290
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
307291
@AudioBackendScope("sox")
@@ -324,7 +308,6 @@ def test_band_with_noise(self):
324308
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
325309

326310
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
327-
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
328311

329312
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
330313
@AudioBackendScope("sox")
@@ -347,7 +330,6 @@ def test_band_without_noise(self):
347330
output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
348331

349332
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
350-
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)
351333

352334
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
353335
@AudioBackendScope("sox")
@@ -370,7 +352,6 @@ def test_treble(self):
370352
output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
371353

372354
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
373-
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)
374355

375356
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
376357
@AudioBackendScope("sox")
@@ -389,7 +370,6 @@ def test_deemph(self):
389370
output_waveform = F.deemph_biquad(waveform, sample_rate)
390371

391372
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
392-
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
393373

394374
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
395375
@AudioBackendScope("sox")
@@ -408,7 +388,6 @@ def test_riaa(self):
408388
output_waveform = F.riaa_biquad(waveform, sample_rate)
409389

410390
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
411-
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
412391

413392
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
414393
@AudioBackendScope("sox")
@@ -431,7 +410,6 @@ def test_equalizer(self):
431410
output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)
432411

433412
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
434-
_test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)
435413

436414
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
437415
@AudioBackendScope("sox")
@@ -458,9 +436,6 @@ def test_perf_biquad_filtering(self):
458436
)
459437

460438
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
461-
_test_torchscript_functional(
462-
F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
463-
)
464439

465440

466441
if __name__ == "__main__":

test/test_torchscript_consistency.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,42 @@ def _test_torchscript_functional(py_method, *args, **kwargs):
2525
assert torch.allclose(jit_out, py_out)
2626

2727

28+
def _test_lfilter(waveform):
29+
"""
30+
Design an IIR lowpass filter using scipy.signal filter design
31+
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
32+
33+
Example
34+
>>> from scipy.signal import iirdesign
35+
>>> b, a = iirdesign(0.2, 0.3, 1, 60)
36+
"""
37+
b_coeffs = torch.tensor(
38+
[
39+
0.00299893,
40+
-0.0051152,
41+
0.00841964,
42+
-0.00747802,
43+
0.00841964,
44+
-0.0051152,
45+
0.00299893,
46+
],
47+
device=waveform.device,
48+
)
49+
a_coeffs = torch.tensor(
50+
[
51+
1.0,
52+
-4.8155751,
53+
10.2217618,
54+
-12.14481273,
55+
8.49018171,
56+
-3.3066882,
57+
0.56088705,
58+
],
59+
device=waveform.device,
60+
)
61+
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)
62+
63+
2864
class TestFunctional(unittest.TestCase):
2965
"""Test functions in `functional` module."""
3066
def test_spectrogram(self):
@@ -151,6 +187,122 @@ def test_dither(self):
151187
_test_torchscript_functional_shape(F.dither, tensor, "RPDF")
152188
_test_torchscript_functional_shape(F.dither, tensor, "GPDF")
153189

190+
def test_lfilter(self):
191+
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
192+
waveform, _ = torchaudio.load(filepath, normalization=True)
193+
_test_lfilter(waveform)
194+
195+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
196+
def test_lfilter_cuda(self):
197+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
198+
waveform, _ = torchaudio.load(filepath, normalization=True)
199+
_test_lfilter(waveform.cuda(device=torch.device("cuda:0")))
200+
201+
def test_lowpass(self):
202+
cutoff_freq = 3000
203+
204+
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
205+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
206+
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)
207+
208+
def test_highpass(self):
209+
cutoff_freq = 2000
210+
211+
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
212+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
213+
_test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, cutoff_freq)
214+
215+
def test_allpass(self):
216+
central_freq = 1000
217+
q = 0.707
218+
219+
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
220+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
221+
_test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, central_freq, q)
222+
223+
def test_bandpass_with_csg(self):
224+
central_freq = 1000
225+
q = 0.707
226+
const_skirt_gain = True
227+
228+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
229+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
230+
_test_torchscript_functional(
231+
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
232+
233+
def test_bandpass_withou_csg(self):
234+
central_freq = 1000
235+
q = 0.707
236+
const_skirt_gain = False
237+
238+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
239+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
240+
_test_torchscript_functional(
241+
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
242+
243+
def test_bandreject(self):
244+
central_freq = 1000
245+
q = 0.707
246+
247+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
248+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
249+
_test_torchscript_functional(
250+
F.bandreject_biquad, waveform, sample_rate, central_freq, q)
251+
252+
def test_band_with_noise(self):
253+
central_freq = 1000
254+
q = 0.707
255+
noise = True
256+
257+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
258+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
259+
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
260+
261+
def test_band_without_noise(self):
262+
central_freq = 1000
263+
q = 0.707
264+
noise = False
265+
266+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
267+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
268+
_test_torchscript_functional(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
269+
270+
def test_treble(self):
271+
gain = 40
272+
central_freq = 1000
273+
q = 0.707
274+
275+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
276+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
277+
_test_torchscript_functional(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)
278+
279+
def test_deemph(self):
280+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
281+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
282+
_test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)
283+
284+
def test_riaa(self):
285+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
286+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
287+
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)
288+
289+
def test_equalizer(self):
290+
center_freq = 300
291+
gain = 1
292+
q = 0.707
293+
294+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
295+
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
296+
_test_torchscript_functional(
297+
F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q)
298+
299+
def test_perf_biquad_filtering(self):
300+
a = torch.tensor([0.7, 0.2, 0.6])
301+
b = torch.tensor([0.4, 0.2, 0.9])
302+
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
303+
waveform, _ = torchaudio.load(filepath, normalization=True)
304+
_test_torchscript_functional(F.lfilter, waveform, a, b)
305+
154306

155307
RUN_CUDA = torch.cuda.is_available()
156308
print("Run test with cuda:", RUN_CUDA)

0 commit comments

Comments
 (0)