Skip to content

Commit 57c8b97

Browse files
committed
Fix vocoder interface (#1895)
1 parent c4fc8f9 commit 57c8b97

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

torchaudio/pipelines/_tts/impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(
8383
def sample_rate(self):
8484
return self._sample_rate
8585

86-
def forward(self, mel_spec, lengths):
86+
def forward(self, mel_spec, lengths=None):
8787
mel_spec = torch.exp(mel_spec)
8888
mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5))
8989
if self._min_level_db is not None:
@@ -120,7 +120,7 @@ def __init__(self):
120120
def sample_rate(self):
121121
return self._sample_rate
122122

123-
def forward(self, mel_spec, lengths):
123+
def forward(self, mel_spec, lengths=None):
124124
mel_spec = torch.exp(mel_spec)
125125
mel_spec = mel_spec.clone().detach().requires_grad_(True)
126126
spec = self._inv_mel(mel_spec)

torchaudio/pipelines/_tts/interface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def sample_rate(self):
4747
"""
4848

4949
@abstractmethod
50-
def __call__(self, specgrams: Tensor, lengths: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
50+
def __call__(self, specgrams: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
5151
"""Generate waveform from the given input, such as spectrogram
5252
5353
See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage.
@@ -58,6 +58,7 @@ def __call__(self, specgrams: Tensor, lengths: Optional[Tensor]) -> Tuple[Tensor
5858
The expected shape depends on the implementation.
5959
lengths (Tensor, or None, optional):
6060
The valid length of each sample in the batch. Shape: `(batch, )`.
61+
(Default: `None`)
6162
6263
Returns:
6364
(Tensor, Optional[Tensor]):

0 commit comments

Comments
 (0)