This reverts commit c48c959807e7e80cdd514be9bd019b16e3b816eb.
parent c48c959807
commit dfe8b2536b
@@ -1,52 +0,0 @@
-dataset_conf:
-    filter_conf:
-        max_length: 2048
-        min_length: 0
-    resample_conf:
-        resample_rate: 16000
-    speed_perturb: false
-    feature_extraction_conf:
-        feature_type: 'mfcc'
-        num_ceps: 80
-        num_mel_bins: 80
-        frame_shift: 10
-        frame_length: 25
-        dither: 1.0
-    feature_dither: 0.0
-    spec_aug: true
-    spec_aug_conf:
-        num_t_mask: 2
-        num_f_mask: 2
-        max_t: 20
-        max_f: 40
-    shuffle: true
-    shuffle_conf:
-        shuffle_size: 1500
-    batch_conf:
-        batch_size: 100
-
-model:
-    hidden_dim: 64
-    preprocessing:
-        type: none
-    backbone:
-        type: mdtc
-        num_stack: 4
-        stack_size: 4
-        kernel_size: 5
-        hidden_dim: 64
-        causal: False
-    classifier:
-        type: global
-        dropout: 0.5
-
-optim: adam
-optim_conf:
-    lr: 0.0002
-    weight_decay: 0.00005
-
-training_config:
-    grad_clip: 50
-    max_epoch: 100
-    log_interval: 10
-    criterion: ce
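For reference, this is the WeNet-style experiment config the deleted example trained from: dataset_conf drives feature extraction and batching, model is consumed by init_model, and optim/optim_conf/training_config drive the optimizer and training loop. A minimal sketch of how such a file is read (same PyYAML idiom as kws/bin/test.py below; the path is assumed):

    import yaml

    with open('conf/mdtc.yaml') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    print(configs['optim'])                      # 'adam'
    print(configs['optim_conf']['lr'])           # 0.0002
    print(configs['model']['backbone']['type'])  # 'mdtc'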
@@ -1 +0,0 @@
-../../../kws
@@ -1,30 +0,0 @@
-#!/bin/bash
-
-# Copyright 2021 Jingyong Hou (houjingyong@gmail.com)
-[ -f ./path.sh ] && . ./path.sh
-
-dl_dir=./data/local
-
-. tools/parse_options.sh || exit 1;
-data_dir=$dl_dir
-file_name=speech_commands_v0.01.tar.gz
-speech_command_dir=$data_dir/speech_commands_v1
-audio_dir=$data_dir/speech_commands_v1/audio
-url=http://download.tensorflow.org/data/$file_name
-mkdir -p $data_dir
-if [ ! -f $data_dir/$file_name ]; then
-  echo "downloading $url..."
-  wget -O $data_dir/$file_name $url
-else
-  echo "$file_name exists in $data_dir, skipping download"
-fi
-
-if [ ! -f $speech_command_dir/.extracted ]; then
-  mkdir -p $audio_dir
-  tar -xzvf $data_dir/$file_name -C $audio_dir
-  touch $speech_command_dir/.extracted
-else
-  echo "$speech_command_dir/.extracted exists, skipping extraction"
-fi
-
-exit 0
@@ -1,35 +0,0 @@
-#!/usr/bin/env python
-import os
-import argparse
-
-CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
-    ', ')
-CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(
-        description='prepare kaldi format files for google speech command')
-    parser.add_argument(
-        '--wav_list',
-        required=True,
-        help='file listing full paths of wav files in google speech command dataset')
-    parser.add_argument('--data_dir',
-                        required=True,
-                        help='folder to write kaldi format files')
-    args = parser.parse_args()
-
-    data_dir = args.data_dir
-    f_wav_scp = open(os.path.join(data_dir, 'wav.scp'), 'w')
-    f_text = open(os.path.join(data_dir, 'text'), 'w')
-    with open(args.wav_list) as f:
-        for line in f.readlines():
-            keyword, file_name = line.strip().split('/')[-2:]
-            file_name_new = file_name.split('.')[0]
-            wav_id = '_'.join([keyword, file_name_new])
-            file_dir = line.strip()
-            f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n')
-            label = CLASS_TO_IDX[
-                keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
-            f_text.writelines(wav_id + ' ' + str(label) + '\n')
-    f_wav_scp.close()
-    f_text.close()
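For reference, each line of --wav_list yields one entry in wav.scp and one in text; given a hypothetical list line ending in yes/0a7c2a8d_nohash_0.wav, the loop above writes ('yes' is index 1 in CLASS_TO_IDX):

    wav.scp:  yes_0a7c2a8d_nohash_0 /path/to/audio/yes/0a7c2a8d_nohash_0.wav
    text:     yes_0a7c2a8d_nohash_0 1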
@@ -1,42 +0,0 @@
-#!/usr/bin/env python
-'''Splits the google speech commands into train, validation and test sets'''
-
-import os
-import shutil
-import argparse
-
-
-def move_files(src_folder, to_folder, list_file):
-    with open(list_file) as f:
-        for line in f.readlines():
-            line = line.rstrip()
-            dirname = os.path.dirname(line)
-            dest = os.path.join(to_folder, dirname)
-            if not os.path.exists(dest):
-                os.mkdir(dest)
-            shutil.move(os.path.join(src_folder, line), dest)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(
-        description='Split google command dataset.')
-    parser.add_argument(
-        'root',
-        type=str,
-        help='the path to the root folder of the google commands dataset')
-    args = parser.parse_args()
-
-    audio_folder = os.path.join(args.root, 'audio')
-    validation_path = os.path.join(audio_folder, 'validation_list.txt')
-    test_path = os.path.join(audio_folder, 'testing_list.txt')
-
-    valid_folder = os.path.join(args.root, 'valid')
-    test_folder = os.path.join(args.root, 'test')
-    train_folder = os.path.join(args.root, 'train')
-
-    os.mkdir(valid_folder)
-    os.mkdir(test_folder)
-
-    move_files(audio_folder, test_folder, test_path)
-    move_files(audio_folder, valid_folder, validation_path)
-    os.rename(audio_folder, train_folder)
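For reference, run.sh (later in this diff) drives this script right after the download stage; the dataset's own validation_list.txt and testing_list.txt decide which utterances are moved out, and whatever remains in audio/ becomes the training set:

    python local/split_dataset.py $download_dir/speech_commands_v1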
@@ -1 +0,0 @@
-../../hi_xiaowen/s0/path.sh
@@ -1,108 +0,0 @@
-#!/bin/bash
-# Copyright 2021 Binbin Zhang
-#                Jingyong Hou
-
-. ./path.sh
-
-export CUDA_VISIBLE_DEVICES="0"
-
-stage=-1
-stop_stage=4
-num_keywords=11
-
-config=conf/mdtc.yaml
-norm_mean=false
-norm_var=false
-gpu_id=4
-
-checkpoint=
-dir=exp/mdtc_debug
-
-num_average=10
-score_checkpoint=$dir/avg_${num_average}.pt
-
-# your data dir
-download_dir=/mnt/mnt-data-3/jingyong.hou/data
-speech_command_dir=$download_dir/speech_commands_v1
-. tools/parse_options.sh || exit 1;
-
-set -euo pipefail
-
-if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
-  echo "Download and extract all datasets"
-  local/data_download.sh --dl_dir $download_dir
-  python local/split_dataset.py $download_dir/speech_commands_v1
-fi
-
-
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-  echo "Start preparing Kaldi format files"
-  for x in train test valid;
-  do
-    data=data/$x
-    mkdir -p $data
-    # make wav.scp utt2spk text file
-    find $speech_command_dir/$x -name "*.wav" | grep -v "_background_noise_" > $data/wav.list
-    python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data
-  done
-fi
-
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  echo "Compute CMVN and Format datasets"
-  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
-    --in_scp data/train/wav.scp \
-    --out_cmvn data/train/global_cmvn
-
-  for x in train valid test; do
-    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
-    tools/make_list.py data/$x/wav.scp data/$x/text \
-      data/$x/wav.dur data/$x/data.list
-  done
-fi
-
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  echo "Start training ..."
-  mkdir -p $dir
-  cmvn_opts=
-  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
-  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
-  python kws/bin/train.py --gpu $gpu_id \
-    --config $config \
-    --train_data data/train/data.list \
-    --cv_data data/valid/data.list \
-    --model_dir $dir \
-    --num_workers 8 \
-    --num_keywords $num_keywords \
-    --min_duration 50 \
-    $cmvn_opts \
-    ${checkpoint:+--checkpoint $checkpoint}
-fi
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-  # Do model average
-  python kws/bin/average_model.py \
-    --dst_model $score_checkpoint \
-    --src_path $dir \
-    --num ${num_average} \
-    --val_best
-
-  # Testing
-  result_dir=$dir/test_$(basename $score_checkpoint)
-  mkdir -p $result_dir
-  python kws/bin/test.py --gpu 3 \
-    --config $dir/config.yaml \
-    --test_data data/test/data.list \
-    --batch_size 256 \
-    --num_workers 8 \
-    --checkpoint $score_checkpoint
-fi
-
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-  python kws/bin/export_jit.py --config $dir/config.yaml \
-    --checkpoint $score_checkpoint \
-    --output_file $dir/final.zip \
-    --output_quant_file $dir/final.quant.zip
-fi
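For reference, tools/parse_options.sh lets every variable defined before it be overridden from the command line, so individual stages can be re-run in isolation, e.g.:

    bash run.sh --stage 2 --stop_stage 2 --gpu_id 0

Any of the script's own variables (stage, stop_stage, config, dir, ...) can be set the same way.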
@@ -1 +0,0 @@
-../../../tools
kws/bin/test.py
@@ -1,102 +0,0 @@
-# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import argparse
-import copy
-import logging
-import os
-
-import torch
-import yaml
-from torch.utils.data import DataLoader
-
-from kws.dataset.dataset import Dataset
-from kws.model.kws_model import init_model
-from kws.utils.checkpoint import load_checkpoint
-from kws.utils.executor import Executor
-
-
-def get_args():
-    parser = argparse.ArgumentParser(description='recognize with your model')
-    parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--test_data', required=True, help='test data file')
-    parser.add_argument('--gpu',
-                        type=int,
-                        default=-1,
-                        help='gpu id for this rank, -1 for cpu')
-    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
-    parser.add_argument('--batch_size',
-                        default=16,
-                        type=int,
-                        help='batch size for inference')
-    parser.add_argument('--num_workers',
-                        default=0,
-                        type=int,
-                        help='num of subprocess workers for reading')
-    parser.add_argument('--pin_memory',
-                        action='store_true',
-                        default=False,
-                        help='Use pinned memory buffers for reading')
-    parser.add_argument('--prefetch',
-                        default=100,
-                        type=int,
-                        help='prefetch number')
-    args = parser.parse_args()
-    return args
-
-
-def main():
-    args = get_args()
-    logging.basicConfig(level=logging.DEBUG,
-                        format='%(asctime)s %(levelname)s %(message)s')
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
-
-    with open(args.config, 'r') as fin:
-        configs = yaml.load(fin, Loader=yaml.FullLoader)
-
-    test_conf = copy.deepcopy(configs['dataset_conf'])
-    test_conf['filter_conf']['max_length'] = 102400
-    test_conf['filter_conf']['min_length'] = 0
-    test_conf['speed_perturb'] = False
-    test_conf['spec_aug'] = False
-    test_conf['shuffle'] = False
-    test_conf['feature_extraction_conf']['dither'] = 0.0
-    test_conf['batch_conf']['batch_size'] = args.batch_size
-
-    test_dataset = Dataset(args.test_data, test_conf)
-    test_data_loader = DataLoader(test_dataset,
-                                  batch_size=None,
-                                  pin_memory=args.pin_memory,
-                                  num_workers=args.num_workers)
-
-    # Init kws model from configs
-    model = init_model(configs['model'])
-
-    load_checkpoint(model, args.checkpoint)
-    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-    device = torch.device('cuda' if use_cuda else 'cpu')
-    model = model.to(device)
-    executor = Executor()
-    model.eval()
-    training_config = configs['training_config']
-    with torch.no_grad():
-        test_loss, test_acc = executor.test(model, test_data_loader, device,
-                                            training_config)
-    logging.info('Test Loss {} Acc {}'.format(test_loss, test_acc))
-
-
-if __name__ == '__main__':
-    main()
@@ -164,9 +164,9 @@ def main():
     # !!!IMPORTANT!!!
     # Try to export the model by script, if fails, we should refine
     # the code to satisfy the script export requirements
-    if args.rank == 0:
-        script_model = torch.jit.script(model)
-        script_model.save(os.path.join(args.model_dir, 'init.zip'))
+    # if args.rank == 0:
+    #     script_model = torch.jit.script(model)
+    #     script_model.save(os.path.join(args.model_dir, 'init.zip'))
     executor = Executor()
     # If specify checkpoint, load some info from checkpoint
     if args.checkpoint is not None:
@@ -196,7 +196,8 @@ def main():
     model = model.to(device)

     optimizer = optim.Adam(model.parameters(),
-                           lr=configs['optim_conf']['lr'])
+                           lr=configs['optim_conf']['lr'],
+                           weight_decay=configs['optim_conf']['weight_decay'])
     scheduler = optim.lr_scheduler.ReduceLROnPlateau(
         optimizer,
         mode='min',
@@ -222,9 +223,8 @@ def main():
         logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
         executor.train(model, optimizer, train_data_loader, device, writer,
                        training_config)
-        cv_loss, cv_acc = executor.cv(model, cv_data_loader, device, training_config)
-        logging.info('Epoch {} CV info cv_loss {} cv_acc {}'
-                     .format(epoch, cv_loss, cv_acc))
+        cv_loss = executor.cv(model, cv_data_loader, device, training_config)
+        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))

         if args.rank == 0:
             save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
@@ -234,7 +234,6 @@ def main():
                 'cv_loss': cv_loss,
             })
             writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
-            writer.add_scalar('epoch/cv_acc', cv_acc, epoch)
             writer.add_scalar('epoch/lr', lr, epoch)
             final_epoch = epoch
         scheduler.step(cv_loss)
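For context, the optimizer/scheduler pairing shown above is the standard PyTorch ReduceLROnPlateau pattern: the scheduler watches the CV loss returned by executor.cv and cuts the learning rate when it stops improving. A minimal self-contained sketch (the model and loss values are stand-ins, not this repo's code):

    import torch
    import torch.optim as optim

    model = torch.nn.Linear(80, 11)      # stand-in for the KWS model
    optimizer = optim.Adam(model.parameters(), lr=0.0002, weight_decay=0.00005)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

    for epoch in range(3):
        cv_loss = 1.0 / (epoch + 1)      # stand-in for executor.cv(...)
        scheduler.step(cv_loss)          # lr drops only when cv_loss plateaus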
@@ -1,33 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class GlobalClassifier(nn.Module):
-    """Add a global average pooling before the classifier"""
-    def __init__(self, classifier: nn.Module):
-        super(GlobalClassifier, self).__init__()
-        self.classifier = classifier
-
-    def forward(self, x: torch.Tensor):
-        x = torch.mean(x, dim=1)
-        return self.classifier(x)
-
-
-class LastClassifier(nn.Module):
-    """Select last frame to do the classification"""
-    def __init__(self, classifier: nn.Module):
-        super(LastClassifier, self).__init__()
-        self.classifier = classifier
-
-    def forward(self, x: torch.Tensor):
-        x = x[:, -1, :]
-        return self.classifier(x)
-
-class ElementClassifier(nn.Module):
-    """Classify all the frames in an utterance"""
-    def __init__(self, classifier: nn.Module):
-        super(ElementClassifier, self).__init__()
-        self.classifier = classifier
-
-    def forward(self, x: torch.Tensor):
-        return self.classifier(x)
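For context, these wrappers turn a frame-level head into an utterance-level classifier by choosing how the time axis is reduced (mean, last frame, or none). A minimal sketch of the GlobalClassifier idea on a (batch, frames, hidden) tensor, with a hypothetical head:

    import torch
    import torch.nn as nn

    class GlobalClassifier(nn.Module):  # as in the removed file above
        def __init__(self, classifier: nn.Module):
            super().__init__()
            self.classifier = classifier

        def forward(self, x: torch.Tensor):
            return self.classifier(torch.mean(x, dim=1))  # pool over frames

    head = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 11))
    out = GlobalClassifier(head)(torch.randn(8, 100, 64))
    print(out.shape)  # torch.Size([8, 11])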
@@ -16,10 +16,8 @@ import sys
 from typing import Optional

 import torch
-import torch.nn as nn

 from kws.model.cmvn import GlobalCMVN
-from kws.model.classifier import GlobalClassifier, LastClassifier
 from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1, NoSubsampling
 from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
 from kws.model.mdtc import MDTC
@@ -41,7 +39,6 @@ class KWSModel(torch.nn.Module):
                 global_cmvn: Optional[torch.nn.Module],
                 preprocessing: Optional[torch.nn.Module],
                 backbone: torch.nn.Module,
-                classifier: torch.nn.Module
                 ):
        super().__init__()
        self.idim = idim
@@ -50,7 +47,7 @@ class KWSModel(torch.nn.Module):
         self.global_cmvn = global_cmvn
         self.preprocessing = preprocessing
         self.backbone = backbone
-        self.classifier = classifier
+        self.classifier = torch.nn.Linear(hdim, odim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.global_cmvn is not None:
@@ -58,6 +55,7 @@ class KWSModel(torch.nn.Module):
             x = self.preprocessing(x)
         x, _ = self.backbone(x)
         x = self.classifier(x)
+        x = torch.sigmoid(x)
         return x


@@ -112,34 +110,17 @@ def init_model(configs):
         num_stack = configs['backbone']['num_stack']
         kernel_size = configs['backbone']['kernel_size']
         hidden_dim = configs['backbone']['hidden_dim']
-        causal = configs['backbone']['causal']
         backbone = MDTC(num_stack,
                         stack_size,
                         input_dim,
                         hidden_dim,
                         kernel_size,
-                        causal=causal)
+                        causal=True)
     else:
         print('Unknown body type {}'.format(backbone_type))
         sys.exit(1)

-    classifier_type = configs['classifier']['type']
-    dropout = configs['classifier']['dropout']
-    classifier_base = nn.Sequential(
-        nn.Linear(hidden_dim, 64),
-        nn.ReLU(),
-        nn.Dropout(dropout),
-        nn.Linear(64, output_dim),
-    )
-    if classifier_type == 'linear':
-        classifier = classifier_base
-    elif classifier_type == 'global':
-        classifier = GlobalClassifier(classifier_base)
-    elif classifier_type == 'last':
-        classifier = LastClassifier(classifier_base)
-    else:
-        print('Unknown classifier type {}'.format(classifier_type))
-        sys.exit(1)
     kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
-                         preprocessing, backbone, classifier)
+                         preprocessing, backbone)
     return kws_model
@@ -13,12 +13,11 @@
 # limitations under the License.

 import torch
-import torch.nn as nn

 from kws.utils.mask import padding_mask


-def max_pooling_loss(logits: torch.Tensor,
+def max_polling_loss(logits: torch.Tensor,
                      target: torch.Tensor,
                      lengths: torch.Tensor,
                      min_duration: int = 0):
@@ -38,7 +37,6 @@ def max_pooling_loss(logits: torch.Tensor,
         (float): loss of current batch
         (float): accuracy of current batch
     '''
-    logits = torch.sigmoid(logits)
     mask = padding_mask(lengths)
     num_utts = logits.size(0)
     num_keywords = logits.size(2)
@@ -82,46 +80,3 @@ def max_pooling_loss(logits: torch.Tensor,
     acc = num_correct / num_utts
     # acc = 0.0
     return loss, acc
-
-
-def acc_frame(
-    logits: torch.Tensor,
-    target: torch.Tensor,
-):
-    if logits is None:
-        return 0
-    pred = logits.max(1, keepdim=True)[1]
-    correct = pred.eq(target.long().view_as(pred)).sum().item()
-    return correct * 100.0 / logits.size(0)
-
-
-def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
-    """ Cross Entropy Loss
-    Attributes:
-        logits: (B, D), D is the number of keywords plus 1 (non-keyword)
-        target: (B)
-        lengths: (B)
-        min_duration: min duration of the keyword
-    Returns:
-        (float): loss of current batch
-        (float): accuracy of current batch
-    """
-    cross_entropy = nn.CrossEntropyLoss()
-    loss = cross_entropy(logits, target)
-    acc = acc_frame(logits, target)
-    return loss, acc
-
-
-def criterion(type: str,
-              logits: torch.Tensor,
-              target: torch.Tensor,
-              lengths: torch.Tensor,
-              min_duration: int = 0):
-    if type == 'ce':
-        loss, acc = cross_entropy(logits, target)
-        return loss, acc
-    elif type == 'max_pooling':
-        loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
-        return loss, acc
-    else:
-        exit(1)
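For context, the removed criterion helper dispatched between the utterance-level cross-entropy path ('ce', the setting in the deleted mdtc config) and the sequence-level max-pooling loss. A minimal sketch of what the 'ce' branch computed (shapes illustrative):

    import torch
    import torch.nn as nn

    num_keywords = 11                       # 10 commands + 'unknown', as in the config
    logits = torch.randn(4, num_keywords)   # (B, D) utterance-level scores
    target = torch.tensor([1, 0, 5, 10])    # (B) class indices

    loss = nn.CrossEntropyLoss()(logits, target)
    pred = logits.max(1, keepdim=True)[1]
    acc = pred.eq(target.view_as(pred)).sum().item() * 100.0 / logits.size(0)
    print(loss.item(), acc)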
@@ -17,7 +17,7 @@ import logging
 import torch
 from torch.nn.utils import clip_grad_norm_

-from kws.model.loss import criterion
+from kws.model.loss import max_polling_loss


 class Executor:
@@ -44,8 +44,8 @@ class Executor:
             if num_utts == 0:
                 continue
             logits = model(feats)
-            loss_type = args.get('criterion', 'max_pooling')
-            loss, acc = criterion(loss_type, logits, target, feats_lengths)
+            loss, acc = max_polling_loss(logits, target, feats_lengths,
+                                         min_duration)
             loss.backward()
             grad_norm = clip_grad_norm_(model.parameters(), clip)
             if torch.isfinite(grad_norm):
@@ -64,7 +64,6 @@ class Executor:
         # in order to avoid division by 0
         num_seen_utts = 1
         total_loss = 0.0
-        total_acc = 0.0
         with torch.no_grad():
             for batch_idx, batch in enumerate(data_loader):
                 key, feats, target, feats_lengths = batch
@@ -74,19 +73,15 @@ class Executor:
                 num_utts = feats_lengths.size(0)
                 if num_utts == 0:
                     continue
+                num_seen_utts += num_utts
                 logits = model(feats)
-                loss, acc = criterion(args.get('criterion', 'max_pooling'),
-                                      logits, target, feats_lengths)
+                loss, acc = max_polling_loss(logits, target, feats_lengths)
                 if torch.isfinite(loss):
                     num_seen_utts += num_utts
                     total_loss += loss.item() * num_utts
-                    total_acc += acc * num_utts
                 if batch_idx % log_interval == 0:
                     logging.debug(
                         'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
                         .format(epoch, batch_idx, loss.item(), acc,
                                 total_loss / num_seen_utts))
-        return total_loss / num_seen_utts, total_acc / num_seen_utts
-
-    def test(self, model, data_loader, device, args):
-        return self.cv(model, data_loader, device, args)
+        return total_loss / num_seen_utts