Skip to content

Commit a11a5a6

Browse files
jaeyeun97vincentqb
authored andcommitted
Module GPU test fixes (#369)
* Fixed GPU tests
1 parent 774ebc7 commit a11a5a6

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchaudio/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None,
5151
self.win_length = win_length if win_length is not None else n_fft
5252
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
5353
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
54-
self.window = window
54+
self.register_buffer('window', window)
5555
self.pad = pad
5656
self.power = power
5757
self.normalized = normalized
@@ -136,7 +136,7 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N
136136

137137
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
138138
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
139-
self.fb = fb
139+
self.register_buffer('fb', fb)
140140

141141
def forward(self, specgram):
142142
r"""
@@ -260,7 +260,7 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m
260260
if self.n_mfcc > self.MelSpectrogram.n_mels:
261261
raise ValueError('Cannot select more MFCC coefficients than # mel bins')
262262
dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
263-
self.dct_mat = dct_mat
263+
self.register_buffer('dct_mat', dct_mat)
264264
self.log_mels = log_mels
265265

266266
def forward(self, waveform):

0 commit comments

Comments
 (0)