[fix] fix mdtc training cache (#82)
This commit is contained in:
parent
490a474d4e
commit
1ad3102c8c
@ -13,6 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -205,7 +207,11 @@ class MDTC(nn.Module):
|
|||||||
self.half_receptive_fields = self.receptive_fields // 2
|
self.half_receptive_fields = self.receptive_fields // 2
|
||||||
print('Receptive Fields: %d' % self.receptive_fields)
|
print('Receptive Fields: %d' % self.receptive_fields)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if self.causal:
|
if self.causal:
|
||||||
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
|
outputs = F.pad(x, (0, 0, self.receptive_fields, 0, 0, 0),
|
||||||
'constant')
|
'constant')
|
||||||
@ -241,7 +247,8 @@ class MDTC(nn.Module):
|
|||||||
for x in normalized_outputs:
|
for x in normalized_outputs:
|
||||||
outputs += x
|
outputs += x
|
||||||
outputs = outputs.transpose(1, 2)
|
outputs = outputs.transpose(1, 2)
|
||||||
return outputs, None
|
# TODO(Binbin Zhang): Fix cache
|
||||||
|
return outputs, in_cache
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user