@@ -51,7 +51,7 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None,
51
51
self .win_length = win_length if win_length is not None else n_fft
52
52
self .hop_length = hop_length if hop_length is not None else self .win_length // 2
53
53
window = window_fn (self .win_length ) if wkwargs is None else window_fn (self .win_length , ** wkwargs )
54
- self .window = window
54
+ self .register_buffer ( ' window' , window )
55
55
self .pad = pad
56
56
self .power = power
57
57
self .normalized = normalized
@@ -136,7 +136,7 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N
136
136
137
137
fb = torch .empty (0 ) if n_stft is None else F .create_fb_matrix (
138
138
n_stft , self .f_min , self .f_max , self .n_mels , self .sample_rate )
139
- self .fb = fb
139
+ self .register_buffer ( 'fb' , fb )
140
140
141
141
def forward (self , specgram ):
142
142
r"""
@@ -260,7 +260,7 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m
260
260
if self .n_mfcc > self .MelSpectrogram .n_mels :
261
261
raise ValueError ('Cannot select more MFCC coefficients than # mel bins' )
262
262
dct_mat = F .create_dct (self .n_mfcc , self .MelSpectrogram .n_mels , self .norm )
263
- self .dct_mat = dct_mat
263
+ self .register_buffer ( ' dct_mat' , dct_mat )
264
264
self .log_mels = log_mels
265
265
266
266
def forward (self , waveform ):
0 commit comments