Modifications to make the MDTC model TorchScript-compatible (torch.jit.script)

This commit is contained in:
lxiao336 2021-11-27 00:09:05 +08:00
parent b5dde01df3
commit 0f50633ee2
3 changed files with 29 additions and 11 deletions

View File

@ -18,7 +18,7 @@ from typing import Optional
import torch
from kws.model.cmvn import GlobalCMVN
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
from kws.model.mdtc import MDTC
from kws.utils.cmvn import load_cmvn
@ -52,8 +52,7 @@ class KWSModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.global_cmvn is not None:
x = self.global_cmvn(x)
if self.preprocessing:
x = self.preprocessing(x)
x = self.preprocessing(x)
x, _ = self.backbone(x)
x = self.classifier(x)
x = torch.sigmoid(x)
@ -82,7 +81,7 @@ def init_model(configs):
elif prep_type == 'cnn1d_s1':
preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
elif prep_type == 'none':
preprocessing = None
preprocessing = NoSubsampling()
else:
print('Unknown preprocessing type {}'.format(prep_type))
sys.exit(1)

View File

@ -241,16 +241,27 @@ class MDTC(nn.Module):
outputs = outputs.transpose(1, 2)
outputs_list = []
outputs = self.relu(self.preprocessor(outputs))
for i in range(len(self.blocks)):
outputs = self.blocks[i](outputs)
for block in self.blocks:
outputs = block(outputs)
outputs_list.append(outputs)
if self.causal:
outputs_list = self.normalize_length_causal(outputs_list)
else:
outputs_list = self.normalize_length(outputs_list)
normalized_outputs = []
output_size = outputs_list[-1].shape[-1]
for x in outputs_list:
remove_length = x.shape[-1] - output_size
if remove_length != 0:
if self.causal:
normalized_outputs.append(x[:, :, remove_length:])
else:
remove_length = remove_length // 2
normalized_outputs.append(x[:, :,
remove_length:-remove_length])
else:
normalized_outputs.append(x)
outputs = sum(outputs_list)
outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype)
for x in normalized_outputs:
outputs += x
outputs = outputs.transpose(1, 2)
return outputs, None

View File

@ -23,6 +23,14 @@ class SubsamplingBase(torch.nn.Module):
super().__init__()
self.subsampling_rate = 1
class NoSubsampling(SubsamplingBase):
    """Identity preprocessing layer for the 'none' preprocessing type.

    The input is handed back untouched, so no subsampling takes place.
    """

    # No layers to build here, so the constructor inherited from
    # SubsamplingBase is used as-is.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged."""
        return x
class LinearSubsampling1(SubsamplingBase):
"""Linear transform the input without subsampling