diff --git a/wekws/model/mdtc.py b/wekws/model/mdtc.py index 736eca4..385cd63 100644 --- a/wekws/model/mdtc.py +++ b/wekws/model/mdtc.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import torch import torch.nn as nn import torch.nn.functional as F @@ -205,7 +207,11 @@ class MDTC(nn.Module): self.half_receptive_fields = self.receptive_fields // 2 print('Receptive Fields: %d' % self.receptive_fields) - def forward(self, x: torch.Tensor): + def forward( + self, + x: torch.Tensor, + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + ) -> Tuple[torch.Tensor, torch.Tensor]: if self.causal: outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0), 'constant') @@ -241,7 +247,8 @@ class MDTC(nn.Module): for x in normalized_outputs: outputs += x outputs = outputs.transpose(1, 2) - return outputs, None + # TODO(Binbin Zhang): Fix cache + return outputs, in_cache if __name__ == '__main__':