diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index fe4b782..7c97999 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -102,7 +102,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --test_data data/test/data.list \ --batch_size 256 \ --checkpoint $score_checkpoint \ - --score_file $result_dir/score.txt + --score_file $result_dir/score.txt \ + --num_workers 8 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index ab543dd..dd6142a 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -18,7 +18,7 @@ from typing import Optional import torch from kws.model.cmvn import GlobalCMVN -from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1 +from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling from kws.model.tcn import TCN, CnnBlock, DsCnnBlock from kws.model.mdtc import MDTC from kws.utils.cmvn import load_cmvn @@ -52,8 +52,7 @@ class KWSModel(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: if self.global_cmvn is not None: x = self.global_cmvn(x) - if self.preprocessing: - x = self.preprocessing(x) + x = self.preprocessing(x) x, _ = self.backbone(x) x = self.classifier(x) x = torch.sigmoid(x) @@ -82,7 +81,7 @@ def init_model(configs): elif prep_type == 'cnn1d_s1': preprocessing = Conv1dSubsampling1(input_dim, hidden_dim) elif prep_type == 'none': - preprocessing = None + preprocessing = NoSubsampling() else: print('Unknown preprocessing type {}'.format(prep_type)) sys.exit(1) diff --git a/kws/model/mdtc.py b/kws/model/mdtc.py index a5ab929..a50d401 100644 --- a/kws/model/mdtc.py +++ b/kws/model/mdtc.py @@ -241,16 +241,27 @@ class MDTC(nn.Module): outputs = outputs.transpose(1, 2) outputs_list = [] outputs = self.relu(self.preprocessor(outputs)) - for i in range(len(self.blocks)): - outputs = self.blocks[i](outputs) + for block in self.blocks: + outputs = block(outputs) outputs_list.append(outputs) - if self.causal: - outputs_list = self.normalize_length_causal(outputs_list) - else: - outputs_list = self.normalize_length(outputs_list) + normalized_outputs = [] + output_size = outputs_list[-1].shape[-1] + for x in outputs_list: + remove_length = x.shape[-1] - output_size + if remove_length != 0: + if self.causal: + normalized_outputs.append(x[:, :, remove_length:]) + else: + remove_length = remove_length // 2 + normalized_outputs.append(x[:, :, + remove_length:-remove_length]) + else: + normalized_outputs.append(x) - outputs = sum(outputs_list) + outputs = torch.zeros_like(outputs_list[-1], dtype=outputs_list[-1].dtype) + for x in normalized_outputs: + outputs += x outputs = outputs.transpose(1, 2) return outputs, None diff --git a/kws/model/subsampling.py b/kws/model/subsampling.py index 044ef24..1aba989 100644 --- a/kws/model/subsampling.py +++ b/kws/model/subsampling.py @@ -23,6 +23,14 @@ class SubsamplingBase(torch.nn.Module): super().__init__() self.subsampling_rate = 1 +class NoSubsampling(SubsamplingBase): + """No subsampling in accordance to the 'none' preprocessing + """ + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x class LinearSubsampling1(SubsamplingBase): """Linear transform the input without subsampling