modifications to get the mdtc model torch-scriptable (#14)
* modifying some implmentations of mdtc to get the model torch-scripting through * modifications to get the mdtc model torch-scriptable Co-authored-by: lxiao336 <shawl336@163.com>
This commit is contained in:
parent
bf3029c0fa
commit
ba6919baaf
@ -102,7 +102,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
--test_data data/test/data.list \
|
||||
--batch_size 256 \
|
||||
--checkpoint $score_checkpoint \
|
||||
--score_file $result_dir/score.txt
|
||||
--score_file $result_dir/score.txt \
|
||||
--num_workers 8
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
|
||||
@ -18,7 +18,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from kws.model.cmvn import GlobalCMVN
|
||||
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
|
||||
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
|
||||
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
||||
from kws.model.mdtc import MDTC
|
||||
from kws.utils.cmvn import load_cmvn
|
||||
@ -52,8 +52,7 @@ class KWSModel(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.global_cmvn is not None:
|
||||
x = self.global_cmvn(x)
|
||||
if self.preprocessing:
|
||||
x = self.preprocessing(x)
|
||||
x = self.preprocessing(x)
|
||||
x, _ = self.backbone(x)
|
||||
x = self.classifier(x)
|
||||
x = torch.sigmoid(x)
|
||||
@ -82,7 +81,7 @@ def init_model(configs):
|
||||
elif prep_type == 'cnn1d_s1':
|
||||
preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
|
||||
elif prep_type == 'none':
|
||||
preprocessing = None
|
||||
preprocessing = NoSubsampling()
|
||||
else:
|
||||
print('Unknown preprocessing type {}'.format(prep_type))
|
||||
sys.exit(1)
|
||||
|
||||
@ -241,16 +241,27 @@ class MDTC(nn.Module):
|
||||
outputs = outputs.transpose(1, 2)
|
||||
outputs_list = []
|
||||
outputs = self.relu(self.preprocessor(outputs))
|
||||
for i in range(len(self.blocks)):
|
||||
outputs = self.blocks[i](outputs)
|
||||
for block in self.blocks:
|
||||
outputs = block(outputs)
|
||||
outputs_list.append(outputs)
|
||||
|
||||
if self.causal:
|
||||
outputs_list = self.normalize_length_causal(outputs_list)
|
||||
else:
|
||||
outputs_list = self.normalize_length(outputs_list)
|
||||
normalized_outputs = []
|
||||
output_size = outputs_list[-1].shape[-1]
|
||||
for x in outputs_list:
|
||||
remove_length = x.shape[-1] - output_size
|
||||
if remove_length != 0:
|
||||
if self.causal:
|
||||
normalized_outputs.append(x[:, :, remove_length:])
|
||||
else:
|
||||
remove_length = remove_length // 2
|
||||
normalized_outputs.append(x[:, :,
|
||||
remove_length:-remove_length])
|
||||
else:
|
||||
normalized_outputs.append(x)
|
||||
|
||||
outputs = sum(outputs_list)
|
||||
outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype)
|
||||
for x in normalized_outputs:
|
||||
outputs += x
|
||||
outputs = outputs.transpose(1, 2)
|
||||
return outputs, None
|
||||
|
||||
|
||||
@ -23,6 +23,14 @@ class SubsamplingBase(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.subsampling_rate = 1
|
||||
|
||||
class NoSubsampling(SubsamplingBase):
|
||||
"""No subsampling in accordance to the 'none' preprocessing
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
class LinearSubsampling1(SubsamplingBase):
|
||||
"""Linear transform the input without subsampling
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user