modifications to get the mdtc model torch-scriptable (#14)

* modifying some implementations of mdtc so that the model passes torch scripting

* modifications to get the mdtc model torch-scriptable

Co-authored-by: lxiao336 <shawl336@163.com>
This commit is contained in:
lxiao336 2021-11-29 11:15:30 +08:00 committed by GitHub
parent bf3029c0fa
commit ba6919baaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 12 deletions

View File

@ -102,7 +102,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--test_data data/test/data.list \ --test_data data/test/data.list \
--batch_size 256 \ --batch_size 256 \
--checkpoint $score_checkpoint \ --checkpoint $score_checkpoint \
--score_file $result_dir/score.txt --score_file $result_dir/score.txt \
--num_workers 8
fi fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then

View File

@ -18,7 +18,7 @@ from typing import Optional
import torch import torch
from kws.model.cmvn import GlobalCMVN from kws.model.cmvn import GlobalCMVN
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
from kws.model.mdtc import MDTC from kws.model.mdtc import MDTC
from kws.utils.cmvn import load_cmvn from kws.utils.cmvn import load_cmvn
@ -52,8 +52,7 @@ class KWSModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.global_cmvn is not None: if self.global_cmvn is not None:
x = self.global_cmvn(x) x = self.global_cmvn(x)
if self.preprocessing: x = self.preprocessing(x)
x = self.preprocessing(x)
x, _ = self.backbone(x) x, _ = self.backbone(x)
x = self.classifier(x) x = self.classifier(x)
x = torch.sigmoid(x) x = torch.sigmoid(x)
@ -82,7 +81,7 @@ def init_model(configs):
elif prep_type == 'cnn1d_s1': elif prep_type == 'cnn1d_s1':
preprocessing = Conv1dSubsampling1(input_dim, hidden_dim) preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
elif prep_type == 'none': elif prep_type == 'none':
preprocessing = None preprocessing = NoSubsampling()
else: else:
print('Unknown preprocessing type {}'.format(prep_type)) print('Unknown preprocessing type {}'.format(prep_type))
sys.exit(1) sys.exit(1)

View File

@ -241,16 +241,27 @@ class MDTC(nn.Module):
outputs = outputs.transpose(1, 2) outputs = outputs.transpose(1, 2)
outputs_list = [] outputs_list = []
outputs = self.relu(self.preprocessor(outputs)) outputs = self.relu(self.preprocessor(outputs))
for i in range(len(self.blocks)): for block in self.blocks:
outputs = self.blocks[i](outputs) outputs = block(outputs)
outputs_list.append(outputs) outputs_list.append(outputs)
if self.causal: normalized_outputs = []
outputs_list = self.normalize_length_causal(outputs_list) output_size = outputs_list[-1].shape[-1]
else: for x in outputs_list:
outputs_list = self.normalize_length(outputs_list) remove_length = x.shape[-1] - output_size
if remove_length != 0:
if self.causal:
normalized_outputs.append(x[:, :, remove_length:])
else:
remove_length = remove_length // 2
normalized_outputs.append(x[:, :,
remove_length:-remove_length])
else:
normalized_outputs.append(x)
outputs = sum(outputs_list) outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype)
for x in normalized_outputs:
outputs += x
outputs = outputs.transpose(1, 2) outputs = outputs.transpose(1, 2)
return outputs, None return outputs, None

View File

@ -23,6 +23,14 @@ class SubsamplingBase(torch.nn.Module):
super().__init__() super().__init__()
self.subsampling_rate = 1 self.subsampling_rate = 1
class NoSubsampling(SubsamplingBase):
    """Pass-through module matching the 'none' preprocessing type.

    No transformation and no subsampling is applied: ``forward`` hands
    back exactly the tensor it was given.
    """

    def __init__(self):
        # No parameters of its own; all setup is delegated to the base class.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input tensor ``x`` unchanged."""
        return x
class LinearSubsampling1(SubsamplingBase): class LinearSubsampling1(SubsamplingBase):
"""Linear transform the input without subsampling """Linear transform the input without subsampling