Skip to content

Commit 774ebc7

Browse files
authored
Backend switch (#355)
* move sox inside function calls. * add backend switch mechanism. * import sox at runtime, not import. * add backend list. * backend tests. * creating hidden modules for backend. * naming backend same as file: soundfile. * remove docstring in backend file. * test soundfile info. * soundfile doesn't support int64. * adding test for wav file. * error with incorrect parameter instead of silent ignore. * adding test across backend. using float32 as done in sox. * backend guard decorator.
1 parent 4887ff4 commit 774ebc7

File tree

7 files changed

+455
-65
lines changed

7 files changed

+455
-65
lines changed
1.06 MB
Binary file not shown.

test/test.py

Lines changed: 111 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,41 @@
77
import os
88

99

10+
class AudioBackendScope:
11+
def __init__(self, backend):
12+
self.new_backend = backend
13+
self.previous_backend = torchaudio.get_audio_backend()
14+
15+
def __enter__(self):
16+
torchaudio.set_audio_backend(self.new_backend)
17+
return self.new_backend
18+
19+
def __exit__(self, type, value, traceback):
20+
backend = self.previous_backend
21+
torchaudio.set_audio_backend(backend)
22+
23+
1024
class Test_LoadSave(unittest.TestCase):
1125
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
1226
test_filepath = os.path.join(test_dirpath, "assets",
1327
"steam-train-whistle-daniel_simon.mp3")
28+
test_filepath_wav = os.path.join(test_dirpath, "assets",
29+
"steam-train-whistle-daniel_simon.wav")
1430

1531
def test_1_save(self):
32+
for backend in ["sox"]:
33+
with self.subTest():
34+
with AudioBackendScope(backend):
35+
self._test_1_save(self.test_filepath, False)
36+
37+
for backend in ["sox", "soundfile"]:
38+
with self.subTest():
39+
with AudioBackendScope(backend):
40+
self._test_1_save(self.test_filepath_wav, True)
41+
42+
def _test_1_save(self, test_filepath, normalization):
1643
# load signal
17-
x, sr = torchaudio.load(self.test_filepath, normalization=False)
44+
x, sr = torchaudio.load(test_filepath, normalization=normalization)
1845

1946
# check save
2047
new_filepath = os.path.join(self.test_dirpath, "test.wav")
@@ -52,6 +79,14 @@ def test_1_save(self):
5279
"test.wav")
5380
torchaudio.save(new_filepath, x, sr)
5481

82+
def test_1_save_sine(self):
83+
for backend in ["sox", "soundfile"]:
84+
with self.subTest():
85+
with AudioBackendScope(backend):
86+
self._test_1_save_sine()
87+
88+
def _test_1_save_sine(self):
89+
5590
# save created file
5691
sinewave_filepath = os.path.join(self.test_dirpath, "assets",
5792
"sinewave.wav")
@@ -78,34 +113,36 @@ def test_1_save(self):
78113
os.unlink(new_filepath)
79114

80115
def test_2_load(self):
116+
for backend in ["sox"]:
117+
with self.subTest():
118+
with AudioBackendScope(backend):
119+
self._test_2_load(self.test_filepath, 278756)
120+
121+
for backend in ["sox", "soundfile"]:
122+
with self.subTest():
123+
with AudioBackendScope(backend):
124+
self._test_2_load(self.test_filepath_wav, 276858)
125+
126+
def _test_2_load(self, test_filepath, length):
81127
# check normal loading
82-
x, sr = torchaudio.load(self.test_filepath)
128+
x, sr = torchaudio.load(test_filepath)
83129
self.assertEqual(sr, 44100)
84-
self.assertEqual(x.size(), (2, 278756))
85-
86-
# check no normalizing
87-
x, _ = torchaudio.load(self.test_filepath, normalization=False)
88-
self.assertTrue(x.min() <= -1.0)
89-
self.assertTrue(x.max() >= 1.0)
130+
self.assertEqual(x.size(), (2, length))
90131

91132
# check offset
92133
offset = 15
93-
x, _ = torchaudio.load(self.test_filepath)
94-
x_offset, _ = torchaudio.load(self.test_filepath, offset=offset)
134+
x, _ = torchaudio.load(test_filepath)
135+
x_offset, _ = torchaudio.load(test_filepath, offset=offset)
95136
self.assertTrue(x[:, offset:].allclose(x_offset))
96137

97138
# check number of frames
98139
n = 201
99-
x, _ = torchaudio.load(self.test_filepath, num_frames=n)
140+
x, _ = torchaudio.load(test_filepath, num_frames=n)
100141
self.assertTrue(x.size(), (2, n))
101142

102143
# check channels first
103-
x, _ = torchaudio.load(self.test_filepath, channels_first=False)
104-
self.assertEqual(x.size(), (278756, 2))
105-
106-
# check different input tensor type
107-
x, _ = torchaudio.load(self.test_filepath, torch.LongTensor(), normalization=False)
108-
self.assertTrue(isinstance(x, torch.LongTensor))
144+
x, _ = torchaudio.load(test_filepath, channels_first=False)
145+
self.assertEqual(x.size(), (length, 2))
109146

110147
# check raising errors
111148
with self.assertRaises(OSError):
@@ -116,7 +153,30 @@ def test_2_load(self):
116153
os.path.dirname(self.test_dirpath), "torchaudio")
117154
torchaudio.load(tdir)
118155

156+
def test_2_load_nonormalization(self):
157+
for backend in ["sox"]:
158+
with self.subTest():
159+
with AudioBackendScope(backend):
160+
self._test_2_load_nonormalization(self.test_filepath, 278756)
161+
162+
def _test_2_load_nonormalization(self, test_filepath, length):
163+
164+
# check no normalizing
165+
x, _ = torchaudio.load(test_filepath, normalization=False)
166+
self.assertTrue(x.min() <= -1.0)
167+
self.assertTrue(x.max() >= 1.0)
168+
169+
# check different input tensor type
170+
x, _ = torchaudio.load(test_filepath, torch.LongTensor(), normalization=False)
171+
self.assertTrue(isinstance(x, torch.LongTensor))
172+
119173
def test_3_load_and_save_is_identity(self):
174+
for backend in ["sox", "soundfile"]:
175+
with self.subTest():
176+
with AudioBackendScope(backend):
177+
self._test_3_load_and_save_is_identity()
178+
179+
def _test_3_load_and_save_is_identity(self):
120180
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
121181
tensor, sample_rate = torchaudio.load(input_path)
122182
output_path = os.path.join(self.test_dirpath, 'test.wav')
@@ -126,7 +186,35 @@ def test_3_load_and_save_is_identity(self):
126186
self.assertEqual(sample_rate, sample_rate2)
127187
os.unlink(output_path)
128188

189+
def test_3_load_and_save_is_identity_across_backend(self):
190+
with self.subTest():
191+
self._test_3_load_and_save_is_identity_across_backend("sox", "soundfile")
192+
with self.subTest():
193+
self._test_3_load_and_save_is_identity_across_backend("soundfile", "sox")
194+
195+
def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2):
196+
with AudioBackendScope(backend1):
197+
198+
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
199+
tensor1, sample_rate1 = torchaudio.load(input_path)
200+
201+
output_path = os.path.join(self.test_dirpath, 'test.wav')
202+
torchaudio.save(output_path, tensor1, sample_rate1)
203+
204+
with AudioBackendScope(backend2):
205+
tensor2, sample_rate2 = torchaudio.load(output_path)
206+
207+
self.assertTrue(tensor1.allclose(tensor2))
208+
self.assertEqual(sample_rate1, sample_rate2)
209+
os.unlink(output_path)
210+
129211
def test_4_load_partial(self):
212+
for backend in ["sox"]:
213+
with self.subTest():
214+
with AudioBackendScope(backend):
215+
self._test_4_load_partial()
216+
217+
def _test_4_load_partial(self):
130218
num_frames = 101
131219
offset = 201
132220
# load entire mono sinewave wav file, load a partial copy and then compare
@@ -163,6 +251,12 @@ def test_4_load_partial(self):
163251
torchaudio.load(input_sine_path, offset=100000)
164252

165253
def test_5_get_info(self):
254+
for backend in ["sox", "soundfile"]:
255+
with self.subTest():
256+
with AudioBackendScope(backend):
257+
self._test_5_get_info()
258+
259+
def _test_5_get_info(self):
166260
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
167261
channels, samples, rate, precision = (1, 64000, 16000, 16)
168262
si, ei = torchaudio.info(input_path)

0 commit comments

Comments
 (0)