modifications to get the mdtc model torch-scriptable (#14)
* modifying some implementations of mdtc so that the model passes torch-scripting * modifications to get the mdtc model torch-scriptable Co-authored-by: lxiao336 <shawl336@163.com>
This commit is contained in:
parent
bf3029c0fa
commit
ba6919baaf
@ -102,7 +102,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
--batch_size 256 \
|
--batch_size 256 \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--score_file $result_dir/score.txt
|
--score_file $result_dir/score.txt \
|
||||||
|
--num_workers 8
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from kws.model.cmvn import GlobalCMVN
|
from kws.model.cmvn import GlobalCMVN
|
||||||
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
|
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
|
||||||
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
||||||
from kws.model.mdtc import MDTC
|
from kws.model.mdtc import MDTC
|
||||||
from kws.utils.cmvn import load_cmvn
|
from kws.utils.cmvn import load_cmvn
|
||||||
@ -52,8 +52,7 @@ class KWSModel(torch.nn.Module):
|
|||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
x = self.global_cmvn(x)
|
x = self.global_cmvn(x)
|
||||||
if self.preprocessing:
|
x = self.preprocessing(x)
|
||||||
x = self.preprocessing(x)
|
|
||||||
x, _ = self.backbone(x)
|
x, _ = self.backbone(x)
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
x = torch.sigmoid(x)
|
x = torch.sigmoid(x)
|
||||||
@ -82,7 +81,7 @@ def init_model(configs):
|
|||||||
elif prep_type == 'cnn1d_s1':
|
elif prep_type == 'cnn1d_s1':
|
||||||
preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
|
preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
|
||||||
elif prep_type == 'none':
|
elif prep_type == 'none':
|
||||||
preprocessing = None
|
preprocessing = NoSubsampling()
|
||||||
else:
|
else:
|
||||||
print('Unknown preprocessing type {}'.format(prep_type))
|
print('Unknown preprocessing type {}'.format(prep_type))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|||||||
@ -241,16 +241,27 @@ class MDTC(nn.Module):
|
|||||||
outputs = outputs.transpose(1, 2)
|
outputs = outputs.transpose(1, 2)
|
||||||
outputs_list = []
|
outputs_list = []
|
||||||
outputs = self.relu(self.preprocessor(outputs))
|
outputs = self.relu(self.preprocessor(outputs))
|
||||||
for i in range(len(self.blocks)):
|
for block in self.blocks:
|
||||||
outputs = self.blocks[i](outputs)
|
outputs = block(outputs)
|
||||||
outputs_list.append(outputs)
|
outputs_list.append(outputs)
|
||||||
|
|
||||||
if self.causal:
|
normalized_outputs = []
|
||||||
outputs_list = self.normalize_length_causal(outputs_list)
|
output_size = outputs_list[-1].shape[-1]
|
||||||
else:
|
for x in outputs_list:
|
||||||
outputs_list = self.normalize_length(outputs_list)
|
remove_length = x.shape[-1] - output_size
|
||||||
|
if remove_length != 0:
|
||||||
|
if self.causal:
|
||||||
|
normalized_outputs.append(x[:, :, remove_length:])
|
||||||
|
else:
|
||||||
|
remove_length = remove_length // 2
|
||||||
|
normalized_outputs.append(x[:, :,
|
||||||
|
remove_length:-remove_length])
|
||||||
|
else:
|
||||||
|
normalized_outputs.append(x)
|
||||||
|
|
||||||
outputs = sum(outputs_list)
|
outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype)
|
||||||
|
for x in normalized_outputs:
|
||||||
|
outputs += x
|
||||||
outputs = outputs.transpose(1, 2)
|
outputs = outputs.transpose(1, 2)
|
||||||
return outputs, None
|
return outputs, None
|
||||||
|
|
||||||
|
|||||||
@ -23,6 +23,14 @@ class SubsamplingBase(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.subsampling_rate = 1
|
self.subsampling_rate = 1
|
||||||
|
|
||||||
|
class NoSubsampling(SubsamplingBase):
|
||||||
|
"""No subsampling in accordance to the 'none' preprocessing
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x
|
||||||
|
|
||||||
class LinearSubsampling1(SubsamplingBase):
|
class LinearSubsampling1(SubsamplingBase):
|
||||||
"""Linear transform the input without subsampling
|
"""Linear transform the input without subsampling
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user