[fix] fix mdtc training cache (#82)

This commit is contained in:
Binbin Zhang 2022-09-01 18:25:43 +08:00 committed by GitHub
parent 490a474d4e
commit 1ad3102c8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -205,7 +207,11 @@ class MDTC(nn.Module):
self.half_receptive_fields = self.receptive_fields // 2 self.half_receptive_fields = self.receptive_fields // 2
print('Receptive Fields: %d' % self.receptive_fields) print('Receptive Fields: %d' % self.receptive_fields)
def forward(self, x: torch.Tensor): def forward(
self,
x: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.causal: if self.causal:
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0), outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
'constant') 'constant')
@ -241,7 +247,8 @@ class MDTC(nn.Module):
for x in normalized_outputs: for x in normalized_outputs:
outputs += x outputs += x
outputs = outputs.transpose(1, 2) outputs = outputs.transpose(1, 2)
return outputs, None # TODO(Binbin Zhang): Fix cache
return outputs, in_cache
if __name__ == '__main__': if __name__ == '__main__':