7
7
import os
8
8
9
9
10
class AudioBackendScope:
    """Context manager that temporarily switches the torchaudio audio backend.

    Entering the scope activates ``backend`` through
    ``torchaudio.set_audio_backend``; leaving it puts back the backend that
    was active at the moment the ``with`` block was entered.

    Attributes:
        new_backend: Backend activated while the scope is open.
        previous_backend: Backend reported by
            ``torchaudio.get_audio_backend`` when this object was constructed.
    """

    def __init__(self, backend):
        self.new_backend = backend
        self.previous_backend = torchaudio.get_audio_backend()
        # One saved backend per active ``with`` entry, so the same instance
        # can be entered more than once and still unwind correctly.
        self._saved_backends = []

    def __enter__(self):
        # Snapshot at entry (not at construction): the backend may have been
        # changed between creating the scope object and entering it.
        self._saved_backends.append(torchaudio.get_audio_backend())
        torchaudio.set_audio_backend(self.new_backend)
        return self.new_backend

    def __exit__(self, exc_type, exc_value, traceback):
        # Implicitly returns None, so an exception raised inside the block
        # still propagates after the backend has been restored.
        torchaudio.set_audio_backend(self._saved_backends.pop())
22
+
23
+
10
24
class Test_LoadSave (unittest .TestCase ):
11
25
test_dirpath , test_dir = common_utils .create_temp_assets_dir ()
12
26
test_filepath = os .path .join (test_dirpath , "assets" ,
13
27
"steam-train-whistle-daniel_simon.mp3" )
28
+ test_filepath_wav = os .path .join (test_dirpath , "assets" ,
29
+ "steam-train-whistle-daniel_simon.wav" )
14
30
15
31
def test_1_save(self):
    """Run ``_test_1_save`` for each asset/backend combination.

    The mp3 asset (``test_filepath``) is exercised under "sox" only, with
    ``normalization=False``; the wav asset (``test_filepath_wav``) is
    exercised under both "sox" and "soundfile" with ``normalization=True``.
    """
    for backend in ["sox"]:
        # Label the sub-test so a failure report names the failing combination.
        with self.subTest(backend=backend, asset="mp3"), AudioBackendScope(backend):
            self._test_1_save(self.test_filepath, False)

    for backend in ["sox", "soundfile"]:
        with self.subTest(backend=backend, asset="wav"), AudioBackendScope(backend):
            self._test_1_save(self.test_filepath_wav, True)
41
+
42
+ def _test_1_save (self , test_filepath , normalization ):
16
43
# load signal
17
- x , sr = torchaudio .load (self . test_filepath , normalization = False )
44
+ x , sr = torchaudio .load (test_filepath , normalization = normalization )
18
45
19
46
# check save
20
47
new_filepath = os .path .join (self .test_dirpath , "test.wav" )
@@ -52,6 +79,14 @@ def test_1_save(self):
52
79
"test.wav" )
53
80
torchaudio .save (new_filepath , x , sr )
54
81
82
def test_1_save_sine(self):
    """Run ``_test_1_save_sine`` once under each of "sox" and "soundfile"."""
    for backend in ["sox", "soundfile"]:
        # Label the sub-test so a failure report names the backend in use.
        with self.subTest(backend=backend), AudioBackendScope(backend):
            self._test_1_save_sine()
87
+
88
+ def _test_1_save_sine (self ):
89
+
55
90
# save created file
56
91
sinewave_filepath = os .path .join (self .test_dirpath , "assets" ,
57
92
"sinewave.wav" )
@@ -78,34 +113,36 @@ def test_1_save(self):
78
113
os .unlink (new_filepath )
79
114
80
115
def test_2_load(self):
    """Run ``_test_2_load`` for each asset/backend combination.

    The mp3 asset is loaded under "sox" only with an expected length of
    278756; the wav asset is loaded under both "sox" and "soundfile" with an
    expected length of 276858.
    """
    for backend in ["sox"]:
        # Label the sub-test so a failure report names the failing combination.
        with self.subTest(backend=backend, asset="mp3"), AudioBackendScope(backend):
            self._test_2_load(self.test_filepath, 278756)

    for backend in ["sox", "soundfile"]:
        with self.subTest(backend=backend, asset="wav"), AudioBackendScope(backend):
            self._test_2_load(self.test_filepath_wav, 276858)
125
+
126
+ def _test_2_load (self , test_filepath , length ):
81
127
# check normal loading
82
- x , sr = torchaudio .load (self . test_filepath )
128
+ x , sr = torchaudio .load (test_filepath )
83
129
self .assertEqual (sr , 44100 )
84
- self .assertEqual (x .size (), (2 , 278756 ))
85
-
86
- # check no normalizing
87
- x , _ = torchaudio .load (self .test_filepath , normalization = False )
88
- self .assertTrue (x .min () <= - 1.0 )
89
- self .assertTrue (x .max () >= 1.0 )
130
+ self .assertEqual (x .size (), (2 , length ))
90
131
91
132
# check offset
92
133
offset = 15
93
- x , _ = torchaudio .load (self . test_filepath )
94
- x_offset , _ = torchaudio .load (self . test_filepath , offset = offset )
134
+ x , _ = torchaudio .load (test_filepath )
135
+ x_offset , _ = torchaudio .load (test_filepath , offset = offset )
95
136
self .assertTrue (x [:, offset :].allclose (x_offset ))
96
137
97
138
# check number of frames
98
139
n = 201
99
- x , _ = torchaudio .load (self . test_filepath , num_frames = n )
140
+ x , _ = torchaudio .load (test_filepath , num_frames = n )
100
141
self .assertTrue (x .size (), (2 , n ))
101
142
102
143
# check channels first
103
- x , _ = torchaudio .load (self .test_filepath , channels_first = False )
104
- self .assertEqual (x .size (), (278756 , 2 ))
105
-
106
- # check different input tensor type
107
- x , _ = torchaudio .load (self .test_filepath , torch .LongTensor (), normalization = False )
108
- self .assertTrue (isinstance (x , torch .LongTensor ))
144
+ x , _ = torchaudio .load (test_filepath , channels_first = False )
145
+ self .assertEqual (x .size (), (length , 2 ))
109
146
110
147
# check raising errors
111
148
with self .assertRaises (OSError ):
@@ -116,7 +153,30 @@ def test_2_load(self):
116
153
os .path .dirname (self .test_dirpath ), "torchaudio" )
117
154
torchaudio .load (tdir )
118
155
156
def test_2_load_nonormalization(self):
    """Run ``_test_2_load_nonormalization`` on the mp3 asset under "sox"."""
    for backend in ["sox"]:
        # Label the sub-test so a failure report names the backend in use.
        with self.subTest(backend=backend), AudioBackendScope(backend):
            self._test_2_load_nonormalization(self.test_filepath, 278756)
161
+
162
def _test_2_load_nonormalization(self, test_filepath, length):
    """Check ``torchaudio.load`` with ``normalization=False``.

    Args:
        test_filepath: Path passed to ``torchaudio.load``.
        length: Not used by these checks; kept so existing callers that
            pass it keep working.
    """
    # check no normalizing
    x, _ = torchaudio.load(test_filepath, normalization=False)
    # Dedicated comparisons report the offending value when they fail,
    # unlike assertTrue on a pre-evaluated boolean.
    self.assertLessEqual(x.min().item(), -1.0)
    self.assertGreaterEqual(x.max().item(), 1.0)

    # check different input tensor type
    x, _ = torchaudio.load(test_filepath, torch.LongTensor(), normalization=False)
    self.assertIsInstance(x, torch.LongTensor)
172
+
119
173
def test_3_load_and_save_is_identity(self):
    """Run ``_test_3_load_and_save_is_identity`` under "sox" and "soundfile"."""
    for backend in ["sox", "soundfile"]:
        # Label the sub-test so a failure report names the backend in use.
        with self.subTest(backend=backend), AudioBackendScope(backend):
            self._test_3_load_and_save_is_identity()
178
+
179
+ def _test_3_load_and_save_is_identity (self ):
120
180
input_path = os .path .join (self .test_dirpath , 'assets' , 'sinewave.wav' )
121
181
tensor , sample_rate = torchaudio .load (input_path )
122
182
output_path = os .path .join (self .test_dirpath , 'test.wav' )
@@ -126,7 +186,35 @@ def test_3_load_and_save_is_identity(self):
126
186
self .assertEqual (sample_rate , sample_rate2 )
127
187
os .unlink (output_path )
128
188
189
def test_3_load_and_save_is_identity_across_backend(self):
    """Run the cross-backend round trip in both directions.

    Each pair is handed to
    ``_test_3_load_and_save_is_identity_across_backend`` as
    ``(backend1, backend2)``: first sox then soundfile, and the reverse.
    """
    for backend1, backend2 in [("sox", "soundfile"), ("soundfile", "sox")]:
        # Label the sub-test so a failure report names the failing direction.
        with self.subTest(backend1=backend1, backend2=backend2):
            self._test_3_load_and_save_is_identity_across_backend(backend1, backend2)
194
+
195
def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2):
    """Save with one backend, reload with another, and compare the results.

    ``sinewave.wav`` is loaded and written to ``test.wav`` while ``backend1``
    is active; ``test.wav`` is then loaded while ``backend2`` is active. The
    two tensors must be close and the two sample rates equal.

    Args:
        backend1: Backend used for the initial load and the save.
        backend2: Backend used to load the saved file back.
    """
    input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    output_path = os.path.join(self.test_dirpath, 'test.wav')

    try:
        with AudioBackendScope(backend1):
            tensor1, sample_rate1 = torchaudio.load(input_path)
            torchaudio.save(output_path, tensor1, sample_rate1)

        with AudioBackendScope(backend2):
            tensor2, sample_rate2 = torchaudio.load(output_path)

        self.assertTrue(tensor1.allclose(tensor2))
        self.assertEqual(sample_rate1, sample_rate2)
    finally:
        # Remove the temporary file even when a load or an assertion fails,
        # so a failing run does not leave it behind for later tests.
        if os.path.exists(output_path):
            os.unlink(output_path)
210
+
129
211
def test_4_load_partial(self):
    """Run ``_test_4_load_partial`` under the "sox" backend."""
    for backend in ["sox"]:
        # Label the sub-test so a failure report names the backend in use.
        with self.subTest(backend=backend), AudioBackendScope(backend):
            self._test_4_load_partial()
216
+
217
+ def _test_4_load_partial (self ):
130
218
num_frames = 101
131
219
offset = 201
132
220
# load entire mono sinewave wav file, load a partial copy and then compare
@@ -163,6 +251,12 @@ def test_4_load_partial(self):
163
251
torchaudio .load (input_sine_path , offset = 100000 )
164
252
165
253
def test_5_get_info(self):
    """Run ``_test_5_get_info`` once under each of "sox" and "soundfile"."""
    for backend in ["sox", "soundfile"]:
        # Label the sub-test so a failure report names the backend in use.
        with self.subTest(backend=backend), AudioBackendScope(backend):
            self._test_5_get_info()
258
+
259
+ def _test_5_get_info (self ):
166
260
input_path = os .path .join (self .test_dirpath , 'assets' , 'sinewave.wav' )
167
261
channels , samples , rate , precision = (1 , 64000 , 16000 , 16 )
168
262
si , ei = torchaudio .info (input_path )
0 commit comments