wekws/kws/model/kws_model.py
xiaohou c48c959807
[recipe] support speech command dataset (#21)
* [recipe] support speech command dataset

* format

* format

* format

* update run.sh
2021-12-03 21:07:42 +08:00

146 lines
5.1 KiB
Python

# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Optional
import torch
import torch.nn as nn
from kws.model.cmvn import GlobalCMVN
from kws.model.classifier import GlobalClassifier, LastClassifier
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
from kws.model.mdtc import MDTC
from kws.utils.cmvn import load_cmvn
class KWSModel(torch.nn.Module):
    """Keyword-spotting model built from four stages applied in order.

    1. global_cmvn: Optional, (idim, idim)
    2. preprocessing: Optional, feature dimension projection, (idim, hdim)
    3. backbone: backbone or feature extractor of the whole network,
       (hdim, hdim)
    4. classifier: output layer or classifier of KWS model, (hdim, odim)

    An optional stage passed as ``None`` is skipped in ``forward``.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        hdim: int,
        global_cmvn: Optional[torch.nn.Module],
        preprocessing: Optional[torch.nn.Module],
        backbone: torch.nn.Module,
        classifier: torch.nn.Module,
    ):
        """Store the dimensions and the sub-modules of the model.

        Args:
            idim: input feature dimension.
            odim: output dimension.
            hdim: hidden dimension.
            global_cmvn: optional normalization module, or None.
            preprocessing: optional projection module, or None.
            backbone: module whose call returns a pair; only the first
                element of that pair is used by ``forward``.
            classifier: module mapping the backbone output to the result.
        """
        super().__init__()
        self.idim = idim
        self.odim = odim
        self.hdim = hdim
        self.global_cmvn = global_cmvn
        self.preprocessing = preprocessing
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through cmvn, preprocessing, backbone and classifier.

        Args:
            x: input tensor handed to the first non-None stage.

        Returns:
            The output of ``classifier``.
        """
        if self.global_cmvn is not None:
            x = self.global_cmvn(x)
        # preprocessing is declared Optional, so it needs the same guard as
        # global_cmvn; calling None would raise a TypeError.
        if self.preprocessing is not None:
            x = self.preprocessing(x)
        # The backbone returns a pair; the second element is discarded.
        x, _ = self.backbone(x)
        return self.classifier(x)
def init_model(configs):
    """Build a KWSModel from a configuration dictionary.

    Keys read from ``configs``:
        cmvn (optional): dict. When it holds a non-None ``cmvn_file``, the
            statistics are loaded with ``load_cmvn`` and wrapped in
            ``GlobalCMVN`` together with ``cmvn['norm_var']``.
        input_dim, output_dim, hidden_dim: model dimensions.
        preprocessing: dict with ``type`` in 'linear', 'cnn1d_s1', 'none'.
        backbone: dict with ``type`` in 'gru', 'tcn', 'mdtc' plus the
            options of the selected type.
        classifier: dict with ``type`` in 'linear', 'global', 'last' and
            a ``dropout`` value.

    Returns:
        The assembled KWSModel.

    An unknown preprocessing, backbone or classifier type prints a message
    and terminates the process via ``sys.exit(1)``.
    """
    # Optional global mean/variance normalization.
    cmvn = configs.get('cmvn', {})
    if cmvn.get('cmvn_file') is not None:
        mean, istd = load_cmvn(cmvn['cmvn_file'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float(),
            cmvn['norm_var'],
        )
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    output_dim = configs['output_dim']
    hidden_dim = configs['hidden_dim']

    # Preprocessing stage.
    prep_type = configs['preprocessing']['type']
    if prep_type == 'linear':
        preprocessing = LinearSubsampling1(input_dim, hidden_dim)
    elif prep_type == 'cnn1d_s1':
        preprocessing = Conv1dSubsampling1(input_dim, hidden_dim)
    elif prep_type == 'none':
        preprocessing = NoSubsampling()
    else:
        print('Unknown preprocessing type {}'.format(prep_type))
        sys.exit(1)

    # Backbone stage.
    backbone_conf = configs['backbone']
    backbone_type = backbone_conf['type']
    if backbone_type == 'gru':
        num_layers = backbone_conf['num_layers']
        backbone = torch.nn.GRU(hidden_dim,
                                hidden_dim,
                                num_layers=num_layers,
                                batch_first=True)
    elif backbone_type == 'tcn':
        num_layers = backbone_conf['num_layers']
        # ds selects the depthwise separable block variant.
        block_class = DsCnnBlock if backbone_conf.get('ds', False) else CnnBlock
        kernel_size = backbone_conf.get('kernel_size', 8)
        # Earlier code only looked up the misspelled key 'drouput', so a
        # configured 'dropout' was silently ignored. Prefer the correct
        # spelling and keep the old key as a fallback for existing configs.
        dropout = backbone_conf.get('dropout',
                                    backbone_conf.get('drouput', 0.1))
        backbone = TCN(num_layers, hidden_dim, kernel_size, dropout,
                       block_class)
    elif backbone_type == 'mdtc':
        stack_size = backbone_conf['stack_size']
        num_stack = backbone_conf['num_stack']
        kernel_size = backbone_conf['kernel_size']
        # hidden_dim is replaced by the backbone's own value from here on,
        # so the classifier and KWSModel below use it. The preprocessing
        # module above was already built with the top-level hidden_dim, and
        # MDTC itself is given input_dim as its input size.
        # NOTE(review): presumably mdtc configs use 'none' preprocessing
        # so the feature size reaching MDTC is still input_dim -- confirm.
        hidden_dim = backbone_conf['hidden_dim']
        causal = backbone_conf['causal']
        backbone = MDTC(num_stack,
                        stack_size,
                        input_dim,
                        hidden_dim,
                        kernel_size,
                        causal=causal)
    else:
        print('Unknown body type {}'.format(backbone_type))
        sys.exit(1)

    # Classifier stage: a shared two-layer head, optionally wrapped.
    classifier_type = configs['classifier']['type']
    dropout = configs['classifier']['dropout']
    classifier_base = nn.Sequential(
        nn.Linear(hidden_dim, 64),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(64, output_dim),
    )
    if classifier_type == 'linear':
        classifier = classifier_base
    elif classifier_type == 'global':
        classifier = GlobalClassifier(classifier_base)
    elif classifier_type == 'last':
        classifier = LastClassifier(classifier_base)
    else:
        print('Unknown classifier type {}'.format(classifier_type))
        sys.exit(1)

    return KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
                    preprocessing, backbone, classifier)