Skip to content

Commit 04e6847

Browse files
authored
Fix Fade device compatibility (#508)
* Fix Fade device compatibility
1 parent adef7b9 commit 04e6847

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchaudio/transforms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -723,8 +723,9 @@ def forward(self, waveform: Tensor) -> Tensor:
723723
Tensor: Tensor of audio of dimension (..., time).
724724
"""
725725
waveform_length = waveform.size()[-1]
726-
727-
return self._fade_in(waveform_length) * self._fade_out(waveform_length) * waveform
726+
device = waveform.device
727+
return self._fade_in(waveform_length).to(device) * \
728+
self._fade_out(waveform_length).to(device) * waveform
728729

729730
def _fade_in(self, waveform_length: int) -> Tensor:
730731
fade = torch.linspace(0, 1, self.fade_in_len)

0 commit comments

Comments
 (0)