[kws] add kws base code
This commit is contained in:
parent
f629c0fa54
commit
aa0b0c11a8
89
kws/bin/average_model.py
Normal file
89
kws/bin/average_model.py
Normal file
@ -0,0 +1,89 @@
|
||||
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
|
||||
# Author: di.wu@mobvoi.com (DI WU)
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import yaml
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line options of the checkpoint averaging script.

    Returns:
        argparse.Namespace: parsed options (dst_model, src_path, val_best,
        num, min_epoch, max_epoch). The namespace is also printed to stdout.
    """
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument('--src_path',
                        required=True,
                        help='src model path for average')
    # The help text used to be a copy of --dst_model's; describe the flag.
    parser.add_argument('--val_best',
                        action="store_true",
                        help='average the checkpoints with the best cv_loss '
                        'instead of the most recent ones')
    parser.add_argument('--num',
                        default=5,
                        type=int,
                        help='nums for averaged model')
    parser.add_argument('--min_epoch',
                        default=0,
                        type=int,
                        help='min epoch used for averaging model')
    parser.add_argument(
        '--max_epoch',
        default=65536,  # Big enough
        type=int,
        help='max epoch used for averaging model')
    args = parser.parse_args()
    print(args)
    return args
|
||||
|
||||
|
||||
def main():
    """Average the parameters of several checkpoints and save the result.

    With ``--val_best`` the ``num`` checkpoints with the lowest ``cv_loss``
    (read from the per-epoch yaml files, restricted to
    ``[min_epoch, max_epoch]``) are selected; otherwise the ``num`` most
    recently modified ``*.pt`` files in ``src_path`` are used.

    Raises:
        ValueError: if no checkpoint can be selected, or if the number of
            selected checkpoints differs from ``--num``.
    """
    args = get_args()
    if args.val_best:
        # Collect [epoch, cv_loss] pairs from the per-epoch yaml files.
        val_scores = []
        # NOTE: '[!config]' is a single-character class (it skips names
        # starting with c/o/n/f/i/g), which is enough to skip config.yaml.
        yamls = glob.glob('{}/[!config]*.yaml'.format(args.src_path))
        for y in yamls:
            with open(y, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
            print(y, dic_yaml)
            loss = dic_yaml['cv_loss']
            epoch = dic_yaml['epoch']
            if args.min_epoch <= epoch <= args.max_epoch:
                val_scores.append([epoch, loss])
        if not val_scores:
            raise ValueError(
                'no epoch yaml with cv_loss found in {} for epochs '
                '[{}, {}]'.format(args.src_path, args.min_epoch,
                                  args.max_epoch))
        val_scores = np.array(val_scores)
        # Ascending cv_loss: the best epochs come first.
        sort_idx = np.argsort(val_scores[:, -1])
        sorted_val_scores = val_scores[sort_idx]
        print("best val scores = " + str(sorted_val_scores[:args.num, 1]))
        print("selected epochs = " +
              str(sorted_val_scores[:args.num, 0].astype(np.int64)))
        path_list = [
            args.src_path + '/{}.pt'.format(int(epoch))
            for epoch in sorted_val_scores[:args.num, 0]
        ]
    else:
        # The previous pattern '[!avg][!final]*.pt' was a pair of
        # single-character classes rather than a prefix filter, so it
        # silently dropped one-character names such as '0.pt'..'9.pt'.
        # Filter on the basename prefix instead; 'init' is kept excluded
        # because the old pattern also (accidentally) skipped init.pt.
        skipped_prefixes = ('avg', 'final', 'init')
        path_list = [
            p for p in glob.glob('{}/*.pt'.format(args.src_path))
            if not os.path.basename(p).startswith(skipped_prefixes)
        ]
        # Oldest first, so the tail holds the most recent checkpoints.
        path_list = sorted(path_list, key=os.path.getmtime)
        path_list = path_list[-args.num:]
    print(path_list)

    num = args.num
    if not path_list:
        raise ValueError('no checkpoint found in {}'.format(args.src_path))
    if num != len(path_list):
        raise ValueError('expected {} checkpoints to average, found {}'.format(
            num, len(path_list)))

    avg = None
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        if avg is None:
            # The first checkpoint doubles as the accumulator.
            avg = states
        else:
            for k in avg.keys():
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)
    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)


if __name__ == '__main__':
    main()
|
||||
80
kws/bin/compute_det.py
Normal file
80
kws/bin/compute_det.py
Normal file
@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def load_label_and_score(keyword, label_file, score_file):
    """Split utterance scores into keyword and filler tables.

    Args:
        keyword (int): keyword index. Selects column ``keyword + 1`` of each
            score line and is compared with the ``txt`` field of each label.
        label_file (str): json-lines file; each object needs ``key``,
            ``txt`` and ``duration``.
        score_file (str): whitespace-separated text file, one utterance per
            line: ``<key> <score_0> <score_1> ...``.

    Returns:
        tuple: ``(keyword_table, filler_table, filler_duration)`` where the
        tables map utterance key to score and ``filler_duration`` is the sum
        of ``duration`` over all non-keyword utterances.

    Raises:
        AssertionError: if a label lacks a required field or has no score.
    """
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            if not arr:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            score_table[arr[0]] = float(arr[keyword + 1])
    keyword_table = {}
    filler_table = {}
    filler_duration = 0.0
    with open(label_file, 'r', encoding='utf8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # json.loads('') would raise; skip blank lines instead.
                continue
            obj = json.loads(line)
            assert 'key' in obj
            assert 'txt' in obj
            assert 'duration' in obj
            key = obj['key']
            index = obj['txt']
            duration = obj['duration']
            assert key in score_table
            if index == keyword:
                keyword_table[key] = score_table[key]
            else:
                filler_table[key] = score_table[key]
                filler_duration += duration
    return keyword_table, filler_table, filler_duration
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Sweep a decision threshold over [0, 1] and, for every threshold, write
    # "<threshold> <false alarms per hour> <false reject rate>" to stats_file.
    parser = argparse.ArgumentParser(description='compute det curve')
    parser.add_argument('--test_data', required=True, help='label file')
    parser.add_argument('--keyword',
                        type=int,
                        default=0,
                        help='keyword index (score column / positive label)')
    parser.add_argument('--score_file', required=True, help='score file')
    parser.add_argument('--step',
                        type=float,
                        default=0.01,
                        help='threshold step')
    parser.add_argument('--stats_file',
                        required=True,
                        help='false reject/alarm stats file')
    args = parser.parse_args()
    if args.step <= 0:
        # A non-positive step would never reach the end of the sweep.
        parser.error('--step must be positive')

    keyword_table, filler_table, filler_duration = load_label_and_score(
        args.keyword, args.test_data, args.score_file)
    print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))

    # Derive every threshold from an integer index instead of accumulating
    # `threshold += step`: the accumulated float error pushed the last value
    # slightly above 1.0, so the threshold == 1.0 point was never written.
    num_steps = int(1.0 / args.step + 1e-9)
    with open(args.stats_file, 'w', encoding='utf8') as fout:
        for i in range(num_steps + 1):
            threshold = i * args.step
            # Keyword utterances scoring below the threshold are missed.
            num_false_reject = sum(1 for score in keyword_table.values()
                                   if score < threshold)
            # Filler utterances scoring at/above the threshold fire wrongly.
            num_false_alarm = sum(1 for score in filler_table.values()
                                  if score >= threshold)
            false_reject_rate = num_false_reject / len(keyword_table)
            # Floor at a tiny positive value so the rate is never exactly 0.
            num_false_alarm = max(num_false_alarm, 1e-6)
            false_alarm_per_hour = num_false_alarm / (filler_duration / 3600.0)
            fout.write('{:.6f} {:.6f} {:.6f}\n'.format(threshold,
                                                       false_alarm_per_hour,
                                                       false_reject_rate))
|
||||
69
kws/bin/export_jit.py
Normal file
69
kws/bin/export_jit.py
Normal file
@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from kws.model.kws_model import init_model
|
||||
from kws.utils.checkpoint import load_checkpoint
|
||||
|
||||
|
||||
def get_args():
    """Build the export script's command line and return the parsed options.

    Returns:
        argparse.Namespace: with config, checkpoint, output_file and
        output_quant_file (None unless given).
    """
    arg_parser = argparse.ArgumentParser(description='export your script model')
    # All three of these options are mandatory.
    required_options = (
        ('--config', 'config file'),
        ('--checkpoint', 'checkpoint model'),
        ('--output_file', 'output file'),
    )
    for flag, text in required_options:
        arg_parser.add_argument(flag, required=True, help=text)
    arg_parser.add_argument('--output_quant_file',
                            default=None,
                            help='output quantized model file')
    return arg_parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Export a checkpoint as TorchScript, optionally also quantized.

    Builds the model from the ``model`` section of the yaml config, loads the
    checkpoint, scripts the model and saves it to ``--output_file``. When
    ``--output_quant_file`` is given, a dynamically quantized (qint8, Linear
    layers only) copy is scripted and saved as well.
    """
    options = get_args()
    # Export does not need a GPU, so hide every CUDA device.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(options.config, 'r') as config_stream:
        configs = yaml.load(config_stream, Loader=yaml.FullLoader)
    model = init_model(configs['model'])
    print(model)

    load_checkpoint(model, options.checkpoint)

    # Plain TorchScript export.
    torch.jit.script(model).save(options.output_file)
    print('Export model successfully, see {}'.format(options.output_file))

    # The quantized export is optional.
    if not options.output_quant_file:
        return
    quantized_model = torch.quantization.quantize_dynamic(model,
                                                          {torch.nn.Linear},
                                                          dtype=torch.qint8)
    print(quantized_model)
    torch.jit.script(quantized_model).save(options.output_quant_file)
    print('Export quantized model successfully, see {}'.format(
        options.output_quant_file))


if __name__ == '__main__':
    main()
|
||||
104
kws/bin/score.py
Normal file
104
kws/bin/score.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from kws.dataset.dataset import Dataset
|
||||
from kws.model.kws_model import init_model
|
||||
from kws.utils.checkpoint import load_checkpoint
|
||||
from kws.utils.mask import padding_mask
|
||||
|
||||
|
||||
def get_args():
    """Define and parse the command-line options of the scoring script.

    Returns:
        argparse.Namespace: with config, test_data, gpu, checkpoint,
        batch_size and score_file.
    """
    cli = argparse.ArgumentParser(description='recognize with your model')
    add = cli.add_argument
    add('--config', required=True, help='config file')
    add('--test_data', required=True, help='test data file')
    add('--gpu', type=int, default=-1, help='gpu id for this rank, -1 for cpu')
    add('--checkpoint', required=True, help='checkpoint model')
    add('--batch_size', default=16, type=int, help='batch size for inference')
    add('--score_file', required=True, help='output score file')
    return cli.parse_args()
|
||||
|
||||
|
||||
def main():
    """Score a test set with a trained model and write per-utterance scores.

    For every utterance one line ``<key> <score_0> <score_1> ...`` is written
    to ``--score_file``, where each score is the maximum of the masked model
    output along dim 1.
    """
    options = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(options.gpu)

    with open(options.config, 'r') as config_stream:
        configs = yaml.load(config_stream, Loader=yaml.FullLoader)

    # Derive the test-time dataset config from the training one: keep every
    # utterance, disable augmentation, shuffling and dither.
    test_conf = copy.deepcopy(configs['dataset_conf'])
    length_filter = test_conf['filter_conf']
    length_filter['max_length'] = 102400
    length_filter['min_length'] = 0
    for switch in ('speed_perturb', 'spec_aug', 'shuffle'):
        test_conf[switch] = False
    test_conf['fbank_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_size'] = options.batch_size

    # Batching happens inside the dataset, hence batch_size=None here.
    test_data_loader = DataLoader(Dataset(options.test_data, test_conf),
                                  batch_size=None,
                                  num_workers=0)

    # Build the model from the config and restore the trained weights.
    model = init_model(configs['model'])
    load_checkpoint(model, options.checkpoint)

    on_gpu = options.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if on_gpu else 'cpu')
    model = model.to(device)
    model.eval()

    with torch.no_grad(), open(options.score_file, 'w',
                               encoding='utf8') as fout:
        for batch_idx, (keys, feats, target, lengths) in enumerate(
                test_data_loader):
            feats = feats.to(device)
            lengths = lengths.to(device)
            # presumably marks the padded frames of each utterance -- confirm
            # against kws.utils.mask.padding_mask
            mask = padding_mask(lengths).unsqueeze(2)
            logits = model(feats).masked_fill(mask, 0.0)
            # Reduce along dim 1, then move the result to the CPU for writing.
            max_logits = logits.max(dim=1)[0].cpu()
            for i, key in enumerate(keys):
                score = ' '.join(str(x) for x in max_logits[i].tolist())
                fout.write('{} {}\n'.format(key, score))
            if batch_idx % 10 == 0:
                print('Progress batch {}'.format(batch_idx), flush=True)


if __name__ == '__main__':
    main()
|
||||
242
kws/bin/train.py
Normal file
242
kws/bin/train.py
Normal file
@ -0,0 +1,242 @@
|
||||
# Copyright (c) 2020 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.optim as optim
|
||||
import yaml
|
||||
from tensorboardX import SummaryWriter
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from kws.dataset.dataset import Dataset
|
||||
from kws.utils.checkpoint import load_checkpoint, save_checkpoint
|
||||
from kws.model.kws_model import init_model
|
||||
from kws.utils.executor import Executor
|
||||
|
||||
|
||||
def get_args():
    """Define and parse the command-line options of the training script.

    Returns:
        argparse.Namespace: the parsed options. The ``--ddp.*`` flags are
        stored as rank, world_size, dist_backend and init_method.
    """
    cli = argparse.ArgumentParser(description='training your network')
    add = cli.add_argument

    # Data, model and logging locations.
    add('--config', required=True, help='config file')
    add('--train_data', required=True, help='train data file')
    add('--cv_data', required=True, help='cv data file')
    add('--gpu',
        type=int,
        default=-1,
        help='gpu id for this local rank, -1 for cpu')
    add('--model_dir', required=True, help='save model dir')
    add('--checkpoint', help='checkpoint model')
    add('--tensorboard_dir', default='tensorboard', help='tensorboard log dir')

    # Distributed training.
    add('--ddp.rank',
        dest='rank',
        default=0,
        type=int,
        help='global rank for distributed training')
    add('--ddp.world_size',
        dest='world_size',
        default=-1,
        type=int,
        help='number of total processes/gpus for\ndistributed training')
    add('--ddp.dist_backend',
        dest='dist_backend',
        default='nccl',
        choices=['nccl', 'gloo'],
        help='distributed backend')
    add('--ddp.init_method',
        dest='init_method',
        default=None,
        help='ddp init method')

    # Data loading.
    add('--num_workers',
        default=0,
        type=int,
        help='num of subprocess workers for reading')
    add('--pin_memory',
        action='store_true',
        default=False,
        help='Use pinned memory buffers used for reading')

    # Feature normalization and task setup.
    add('--cmvn_file', default=None, help='global cmvn file')
    add('--norm_var',
        action='store_true',
        default=False,
        help='norm var option')
    add('--num_keywords', default=1, type=int, help='number of keywords')
    add('--min_duration',
        default=50,
        type=int,
        help='min duration frames of the keyword')
    add('--prefetch', default=100, type=int, help='prefetch number')

    return cli.parse_args()
|
||||
|
||||
|
||||
def main():
    """Train a model described by the yaml config and save checkpoints.

    Rank 0 writes ``config.yaml``, ``init.zip``, ``init.pt`` (fresh runs
    only), one ``<epoch>.pt`` per epoch and finally a ``final.pt`` symlink
    into ``--model_dir``, and logs the per-epoch cv_loss and lr to
    tensorboard.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Set random seed
    torch.manual_seed(777)
    print(args)
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    distributed = args.world_size > 1
    if distributed:
        logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
        dist.init_process_group(args.dist_backend,
                                init_method=args.init_method,
                                world_size=args.world_size,
                                rank=args.rank)

    model_dir = args.model_dir
    if args.rank == 0:
        # Create the output directory up front: config.yaml and init.zip are
        # written into it below, which used to fail on a fresh model_dir.
        os.makedirs(model_dir, exist_ok=True)

    train_conf = configs['dataset_conf']
    # Cross validation uses the same pipeline without augmentation/shuffle.
    cv_conf = copy.deepcopy(train_conf)
    cv_conf['speed_perturb'] = False
    cv_conf['spec_aug'] = False
    cv_conf['shuffle'] = False

    train_dataset = Dataset(args.train_data, train_conf)
    cv_dataset = Dataset(args.cv_data, cv_conf)

    # Batching happens inside the dataset, hence batch_size=None.
    loader_kwargs = dict(batch_size=None,
                         pin_memory=args.pin_memory,
                         num_workers=args.num_workers)
    if args.num_workers > 0:
        # DataLoader only accepts prefetch_factor together with worker
        # processes; passing it with num_workers == 0 raises ValueError.
        loader_kwargs['prefetch_factor'] = args.prefetch
    train_data_loader = DataLoader(train_dataset, **loader_kwargs)
    cv_data_loader = DataLoader(cv_dataset, **loader_kwargs)

    input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
    output_dim = args.num_keywords

    # Write model_dir/config.yaml for inference and export
    configs['model']['input_dim'] = input_dim
    configs['model']['output_dim'] = output_dim
    if args.cmvn_file is not None:
        configs['model']['cmvn'] = {}
        configs['model']['cmvn']['norm_var'] = args.norm_var
        configs['model']['cmvn']['cmvn_file'] = args.cmvn_file
    if args.rank == 0:
        saved_config_path = os.path.join(model_dir, 'config.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    # Init the model from configs
    model = init_model(configs['model'])
    print(model)
    num_params = sum(p.numel() for p in model.parameters())
    print('the number of model params: {}'.format(num_params))

    # !!!IMPORTANT!!!
    # Try to export the model by script, if fails, we should refine
    # the code to satisfy the script export requirements
    if args.rank == 0:
        script_model = torch.jit.script(model)
        script_model.save(os.path.join(model_dir, 'init.zip'))
    executor = Executor()
    # If specify checkpoint, load some info from checkpoint
    if args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    else:
        infos = {}
    # Resume right after the epoch stored in the checkpoint (0 when fresh).
    start_epoch = infos.get('epoch', -1) + 1
    cv_loss = infos.get('cv_loss', 0.0)

    writer = None
    if args.rank == 0:
        exp_id = os.path.basename(model_dir)
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))

    if distributed:
        assert (torch.cuda.is_available())
        # cuda model is required for nn.parallel.DistributedDataParallel
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
        device = torch.device("cuda")
    else:
        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
        model = model.to(device)

    optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
    # Halve the lr when cv_loss stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        threshold=0.01,
    )

    training_config = configs['training_config']
    training_config['min_duration'] = args.min_duration
    num_epochs = training_config.get('max_epoch', 100)
    final_epoch = None
    if start_epoch == 0 and args.rank == 0:
        save_model_path = os.path.join(model_dir, 'init.pt')
        save_checkpoint(model, save_model_path)

    # Start training loop
    for epoch in range(start_epoch, num_epochs):
        train_dataset.set_epoch(epoch)
        training_config['epoch'] = epoch
        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
        executor.train(model, optimizer, train_data_loader, device, writer,
                       training_config)
        cv_loss = executor.cv(model, cv_data_loader, device, training_config)
        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))

        if args.rank == 0:
            save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
            save_checkpoint(model, save_model_path, {
                'epoch': epoch,
                'lr': lr,
                'cv_loss': cv_loss,
            })
            writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
            writer.add_scalar('epoch/lr', lr, epoch)
        final_epoch = epoch
        scheduler.step(cv_loss)

    if final_epoch is not None and args.rank == 0:
        final_model_path = os.path.join(model_dir, 'final.pt')
        # Replace a stale link left by an earlier (e.g. resumed) run;
        # os.symlink refuses to overwrite an existing path.
        if os.path.islink(final_model_path):
            os.remove(final_model_path)
        os.symlink('{}.pt'.format(final_epoch), final_model_path)
        writer.close()


if __name__ == '__main__':
    main()
|
||||
162
kws/dataset/dataset.py
Normal file
162
kws/dataset/dataset.py
Normal file
@ -0,0 +1,162 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
import kws.dataset.processor as processor
|
||||
from kws.utils.file_utils import read_lists
|
||||
|
||||
|
||||
class Processor(IterableDataset):
|
||||
def __init__(self, source, f, *args, **kw):
|
||||
assert callable(f)
|
||||
self.source = source
|
||||
self.f = f
|
||||
self.args = args
|
||||
self.kw = kw
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.source.set_epoch(epoch)
|
||||
|
||||
def __iter__(self):
|
||||
""" Return an iterator over the source dataset processed by the
|
||||
given processor.
|
||||
"""
|
||||
assert self.source is not None
|
||||
assert callable(self.f)
|
||||
return self.f(iter(self.source), *self.args, **self.kw)
|
||||
|
||||
def apply(self, f):
|
||||
assert callable(f)
|
||||
return Processor(self, f, *self.args, **self.kw)
|
||||
|
||||
|
||||
class DistributedSampler:
    """Pick the slice of a data list that belongs to this rank and worker."""

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        """Refresh rank/world_size and dataloader worker info.

        Returns:
            dict: rank, world_size, worker_id and num_workers.
        """
        assert dist.is_available()
        # Without an initialized process group, act as a single process.
        if not dist.is_initialized():
            self.rank, self.world_size = 0, 1
        else:
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        # get_worker_info() returns None outside a dataloader worker.
        info = torch.utils.data.get_worker_info()
        if info is not None:
            self.worker_id, self.num_workers = info.id, info.num_workers
        else:
            self.worker_id, self.num_workers = 0, 1
        return {
            'rank': self.rank,
            'world_size': self.world_size,
            'worker_id': self.worker_id,
            'num_workers': self.num_workers,
        }

    def set_epoch(self, epoch):
        """Store the epoch; it seeds the shuffle in ``sample``."""
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

            Args:
                data(List): input data list

            Returns:
                List: data list after sample
        """
        picked = data.copy()
        if self.partition:
            if self.shuffle:
                # Seeded with the epoch, so the order is reproducible.
                random.Random(self.epoch).shuffle(picked)
            picked = picked[self.rank::self.world_size]
        # The per-worker split is applied whether or not we partition by rank.
        return picked[self.worker_id::self.num_workers]
|
||||
|
||||
|
||||
class DataList(IterableDataset):
    """Iterable dataset over a list of entries, sliced by a sampler."""

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        """Pass the epoch on to the sampler."""
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        """Yield ``{src, rank, world_size, worker_id, num_workers}`` dicts."""
        # Refresh the rank/worker info on every iteration.
        sampler_info = self.sampler.update()
        for src in self.sampler.sample(self.lists):
            # 'src' first, then the sampler fields (same order as before).
            yield {'src': src, **sampler_info}
|
||||
|
||||
|
||||
def Dataset(data_list_file, conf, partition=True):
    """ Build the processing pipeline for a data list file.

        Shuffling can happen at two stages: once over the list entries
        (inside DataList) and once over the processed samples
        (processor.shuffle). Both are controlled by conf['shuffle'].

        Args:
            data_list_file(str): file handed to read_lists
            conf(dict): pipeline configuration; every ``*_conf`` entry is
                forwarded as keyword arguments to the matching processor
            partition(bool): whether to do data partition in terms of rank

        Returns:
            Processor: the last stage of the pipeline (padding)
    """
    entries = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)

    # Source plus the stages that always run on the raw audio.
    pipeline = DataList(entries, shuffle=shuffle, partition=partition)
    pipeline = Processor(pipeline, processor.parse_raw)
    pipeline = Processor(pipeline, processor.filter,
                         **conf.get('filter_conf', {}))
    pipeline = Processor(pipeline, processor.resample,
                         **conf.get('resample_conf', {}))

    # Speed perturbation is opt-in.
    if conf.get('speed_perturb', False):
        pipeline = Processor(pipeline, processor.speed_perturb)

    pipeline = Processor(pipeline, processor.compute_fbank,
                         **conf.get('fbank_conf', {}))

    # Spec augmentation is on unless explicitly disabled.
    if conf.get('spec_aug', True):
        pipeline = Processor(pipeline, processor.spec_aug,
                             **conf.get('spec_aug_conf', {}))

    if shuffle:
        pipeline = Processor(pipeline, processor.shuffle,
                             **conf.get('shuffle_conf', {}))

    pipeline = Processor(pipeline, processor.batch,
                         **conf.get('batch_conf', {}))
    return Processor(pipeline, processor.padding)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: print every item produced for the data list file
    # given as the first command-line argument, using an empty config.
    import sys
    for data in Dataset(sys.argv[1], {}):
        print(data)
|
||||
266
kws/dataset/processor.py
Normal file
266
kws/dataset/processor.py
Normal file
@ -0,0 +1,266 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import json
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def parse_raw(data):
    """ Parse key/wav/txt from json line

        Samples whose audio cannot be loaded are skipped with a warning.

        Args:
            data: Iterable[{src, ...}], src is a json line has key/wav/txt

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'src' in sample
        obj = json.loads(sample['src'])
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        # Keep the try body limited to the load itself so the consumer is
        # never resumed from inside the exception handler's scope.
        try:
            waveform, sample_rate = torchaudio.load(wav_file)
        except Exception as ex:
            # Best effort: skip unreadable audio, but report the reason.
            logging.warning('Failed to read {}: {}'.format(wav_file, ex))
            continue
        yield dict(key=key, label=txt, wav=waveform, sample_rate=sample_rate)
|
||||
|
||||
|
||||
def filter(data, max_length=10240, min_length=10):
    """ Drop samples whose duration falls outside [min_length, max_length].

        Durations are measured in frames, at 100 frames per second.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        # sample['wav'] is a torch.Tensor; its size along dim 1 is used as
        # the number of audio samples.
        num_samples = sample['wav'].size(1)
        num_frames = num_samples / sample['sample_rate'] * 100
        if min_length <= num_frames <= max_length:
            yield sample
|
||||
|
||||
|
||||
def resample(data, resample_rate=16000):
    """ Bring every sample to the target sample rate.
        Samples are modified in place.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        source_rate = sample['sample_rate']
        audio = sample['wav']
        # Samples already at the target rate pass through untouched.
        if source_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            resampler = torchaudio.transforms.Resample(orig_freq=source_rate,
                                                       new_freq=resample_rate)
            sample['wav'] = resampler(audio)
        yield sample
|
||||
|
||||
|
||||
def speed_perturb(data, speeds=None):
    """ Randomly change the playback speed of each sample.
        Samples are modified in place.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            speeds(List[float]): speeds to choose from, [0.9, 1.0, 1.1]
                when omitted

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    candidates = [0.9, 1.0, 1.1] if speeds is None else speeds
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        rate = sample['sample_rate']
        audio = sample['wav']
        factor = random.choice(candidates)
        if factor == 1.0:
            # Speed 1.0 means "leave the audio as it is".
            yield sample
            continue
        # Apply sox 'speed', then 'rate' with the original sample rate.
        effects = [['speed', str(factor)], ['rate', str(rate)]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
            audio, rate, effects)
        sample['wav'] = perturbed
        yield sample
|
||||
|
||||
|
||||
def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    """ Turn each waveform into filter-bank features.

        Only `key`, `label` and the new `feat` entry are kept in the
        yielded dict; `wav` and `sample_rate` are dropped.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            num_mel_bins: forwarded to `kaldi.fbank`
            frame_length: forwarded to `kaldi.fbank`
            frame_shift: forwarded to `kaldi.fbank`
            dither: forwarded to `kaldi.fbank`

        Returns:
            Iterable[{key, feat, label}]
    """
    fbank_options = dict(num_mel_bins=num_mel_bins,
                         frame_length=frame_length,
                         frame_shift=frame_shift,
                         dither=dither,
                         energy_floor=0.0)
    for item in data:
        for field in ('sample_rate', 'wav', 'key', 'label'):
            assert field in item
        # Scale by 2^15 before feature extraction (presumably the waveform
        # holds normalized floats -- TODO confirm against the reader).
        scaled = item['wav'] * (1 << 15)
        feat = kaldi.fbank(scaled,
                           sample_frequency=item['sample_rate'],
                           **fbank_options)
        yield dict(key=item['key'], label=item['label'], feat=feat)
|
||||
|
||||
|
||||
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10):
    """ Zero out random time and frequency stripes of the features.

        The `feat` entry of each sample is replaced by a masked copy; the
        original tensor is not modified.

        Args:
            data: Iterable[{key, feat, label}]
            num_t_mask: number of time mask to apply
            num_f_mask: number of freq mask to apply
            max_t: max width of time mask
            max_f: max width of freq mask

        Returns
            Iterable[{key, feat, label}]
    """
    for item in data:
        assert 'feat' in item
        feat = item['feat']
        assert isinstance(feat, torch.Tensor)
        masked = feat.clone().detach()
        num_frames, num_bins = masked.size(0), masked.size(1)
        # Time stripes: random start frame, random width, clipped to the end.
        for _ in range(num_t_mask):
            t_begin = random.randint(0, num_frames - 1)
            t_width = random.randint(1, max_t)
            masked[t_begin:min(num_frames, t_begin + t_width), :] = 0
        # Frequency stripes, same scheme along the second axis.
        for _ in range(num_f_mask):
            f_begin = random.randint(0, num_bins - 1)
            f_width = random.randint(1, max_f)
            masked[:, f_begin:min(num_bins, f_begin + f_width)] = 0
        item['feat'] = masked
        yield item
|
||||
|
||||
|
||||
def shuffle(data, shuffle_size=1000):
    """ Shuffle the stream locally, one buffer at a time

        Samples are collected until the buffer holds `shuffle_size` items,
        the buffer is shuffled and emitted, then collection restarts. The
        last, possibly shorter, buffer is shuffled and emitted as well.

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    pending = []
    for item in data:
        pending.append(item)
        if len(pending) < shuffle_size:
            continue
        random.shuffle(pending)
        yield from pending
        pending = []
    # Whatever did not fill a complete buffer
    random.shuffle(pending)
    yield from pending
|
||||
|
||||
|
||||
def batch(data, batch_size=16):
    """ Group consecutive samples into lists of `batch_size`

        The final list may be shorter when the stream length is not a
        multiple of `batch_size`; an empty trailing list is never emitted.

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    chunk = []
    for item in data:
        chunk.append(item)
        if len(chunk) < batch_size:
            continue
        yield chunk
        chunk = []
    if chunk:
        yield chunk
|
||||
|
||||
|
||||
def padding(data):
    """ Padding the data into training data

        Every batch is sorted by feature length (longest first) and the
        features are zero-padded to the longest utterance of the batch.

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths)]
            keys: List of utterance keys, in sorted order
            feats: padded features, (batch, max_len, feat_dim)
            labels: int64 tensor, (batch)
            feats lengths: int32 tensor of unpadded lengths, (batch)
    """
    for sample in data:
        assert isinstance(sample, list)
        feats_lengths = torch.tensor([x['feat'].size(0) for x in sample],
                                     dtype=torch.int32)
        order = torch.argsort(feats_lengths, descending=True)
        # Reorder the lengths with the same permutation instead of
        # recomputing them from the samples a second time.
        feats_lengths = feats_lengths[order]
        order = order.tolist()
        sorted_feats = [sample[i]['feat'] for i in order]
        sorted_keys = [sample[i]['key'] for i in order]
        sorted_labels = torch.tensor([sample[i]['label'] for i in order],
                                     dtype=torch.int64)
        padded_feats = pad_sequence(sorted_feats,
                                    batch_first=True,
                                    padding_value=0)
        yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)
|
||||
47
kws/model/cmvn.py
Normal file
47
kws/model/cmvn.py
Normal file
@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2020 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class GlobalCMVN(torch.nn.Module):
    """ Normalize features with fixed, precomputed statistics.

    The statistics are stored as buffers, so they are part of the
    state dict but are not trainable parameters.
    """
    def __init__(self,
                 mean: torch.Tensor,
                 istd: torch.Tensor,
                 norm_var: bool = True):
        """
        Args:
            mean (torch.Tensor): mean stats
            istd (torch.Tensor): inverse std, i.e. 1.0 / std
            norm_var (bool): also scale by `istd` when True, otherwise
                only subtract the mean
        """
        super().__init__()
        assert mean.shape == istd.shape
        self.norm_var = norm_var
        # Registered buffers are reachable as self.mean / self.istd
        self.register_buffer("mean", mean)
        self.register_buffer("istd", istd)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (batch, max_len, feat_dim)

        Returns:
            (torch.Tensor): normalized feature
        """
        centered = x - self.mean
        if not self.norm_var:
            return centered
        return centered * self.istd
|
||||
101
kws/model/kws_model.py
Normal file
101
kws/model/kws_model.py
Normal file
@ -0,0 +1,101 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from kws.model.cmvn import GlobalCMVN
|
||||
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
|
||||
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
||||
from kws.utils.cmvn import load_cmvn
|
||||
|
||||
|
||||
class KwsModel(torch.nn.Module):
    """ Keyword spotting model assembled from four stages:
        1. global_cmvn: Optional, (idim, idim)
        2. subsampling: subsampling the input, (idim, hdim)
        3. body: body of the whole network, (hdim, hdim)
        4. linear: a linear layer, (hdim, odim)

    The linear output goes through a sigmoid, so every output value
    lies in (0, 1).
    """
    def __init__(self, idim: int, odim: int, hdim: int,
                 global_cmvn: Optional[torch.nn.Module],
                 subsampling: torch.nn.Module, body: torch.nn.Module):
        super().__init__()
        self.idim, self.odim, self.hdim = idim, odim, hdim
        self.global_cmvn = global_cmvn
        self.subsampling = subsampling
        self.body = body
        self.linear = torch.nn.Linear(hdim, odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Run the four stages in order and return the sigmoid output.

        The body is expected to return a pair; only its first element
        is used.
        """
        normalized = x if self.global_cmvn is None else self.global_cmvn(x)
        hidden = self.subsampling(normalized)
        hidden, _ = self.body(hidden)
        return torch.sigmoid(self.linear(hidden))
|
||||
|
||||
|
||||
def init_model(configs):
    """ Build a KwsModel from a configuration dict.

    Args:
        configs (dict): model configuration with the entries
            cmvn (optional): {cmvn_file, norm_var}; no global CMVN is
                used when the section or `cmvn_file` is missing/None
            input_dim, output_dim, hidden_dim: layer sizes
            subsampling: {type}, one of 'linear' or 'cnn1d_s1'
            body: {type, num_layers, ...}, type is 'gru' or 'tcn';
                'tcn' also reads `ds`, `kernel_size` and `dropout`

    Returns:
        KwsModel: the assembled model

    An unknown subsampling or body type prints a message and exits the
    process with status 1.
    """
    # The cmvn section is optional, so never index it directly.
    cmvn = configs.get('cmvn') or {}
    cmvn_file = cmvn.get('cmvn_file')
    if cmvn_file is not None:
        mean, istd = load_cmvn(cmvn_file)
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float(), cmvn.get('norm_var', True))
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    output_dim = configs['output_dim']
    hidden_dim = configs['hidden_dim']

    subsampling_type = configs['subsampling']['type']
    if subsampling_type == 'linear':
        subsampling = LinearSubsampling1(input_dim, hidden_dim)
    elif subsampling_type == 'cnn1d_s1':
        subsampling = Conv1dSubsampling1(input_dim, hidden_dim)
    else:
        print('Unknown subsampling type {}'.format(subsampling_type))
        sys.exit(1)

    body_conf = configs['body']
    body_type = body_conf['type']
    num_layers = body_conf['num_layers']
    if body_type == 'gru':
        body = torch.nn.GRU(hidden_dim,
                            hidden_dim,
                            num_layers=num_layers,
                            batch_first=True)
    elif body_type == 'tcn':
        # Depthwise Separable
        block_class = DsCnnBlock if body_conf.get('ds', False) else CnnBlock
        kernel_size = body_conf.get('kernel_size', 8)
        # The key used to be misspelled 'drouput', which silently ignored
        # a configured 'dropout'. The old spelling is still honored as a
        # fallback so existing configs keep working.
        dropout = body_conf.get('dropout', body_conf.get('drouput', 0.1))
        body = TCN(num_layers, hidden_dim, kernel_size, dropout, block_class)
    else:
        print('Unknown body type {}'.format(body_type))
        sys.exit(1)

    kws_model = KwsModel(input_dim, output_dim, hidden_dim, global_cmvn,
                         subsampling, body)
    return kws_model
|
||||
83
kws/model/loss.py
Normal file
83
kws/model/loss.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from kws.utils.mask import padding_mask
|
||||
|
||||
|
||||
def max_polling_loss(logits: torch.Tensor,
                     target: torch.Tensor,
                     lengths: torch.Tensor,
                     min_duration: int = 0):
    """ Max-pooling loss
        For the target keyword of an utterance, only the frame with the
        highest posterior contributes: the keyword counts as triggered as
        soon as one frame triggers.
        For every other keyword (and for filler utterances), only the
        hardest frame contributes, i.e. the frame with the highest keyword
        posterior: the keyword must stay silent on all frames.

        Attributes:
            logits: (B, T, D), D is the number of keywords
            target: (B)
            lengths: (B)
            min_duration: min duration of the keyword
        Returns:
            (float): loss of current batch
            (float): accuracy of current batch
    """
    pad = padding_mask(lengths)
    batch_size = logits.size(0)
    keyword_count = logits.size(2)

    target = target.cpu()
    loss = 0.0
    for utt in range(batch_size):
        for kw in range(keyword_count):
            # Cross entropy CE = -(t * log(p) + (1 - t) * log(1 - p))
            if target[utt] == kw:
                # Target keyword: best frame, ignoring padding and the
                # first `min_duration` frames
                ignored = pad[utt].clone().detach()
                ignored[:min_duration] = True
                score = logits[utt, :, kw].masked_fill(ignored, 0.0)
                score = torch.clamp(score, 1e-8, 1.0)
                loss += -torch.log(score.max())
            else:
                # Other keywords or filler: worst frame of (1 - p),
                # padding is neutralized with 1.0
                score = (1 - logits[utt, :, kw]).masked_fill(pad[utt], 1.0)
                score = torch.clamp(score, 1e-8, 1.0)
                loss += -torch.log(score.min())
    loss = loss / batch_size

    # Accuracy of the current batch
    visible = logits.masked_fill(pad.unsqueeze(-1), 0.0)
    peak_over_time, _ = visible.max(1)
    num_correct = 0
    for utt in range(batch_size):
        best_p, best_kw = peak_over_time[utt].max(0)
        # Correctly predicted as keyword `best_kw`
        if best_p > 0.5 and best_kw == target[utt]:
            num_correct += 1
        # Correctly predicted as filler, filler id < 0
        if best_p < 0.5 and target[utt] < 0:
            num_correct += 1
    acc = num_correct / batch_size
    return loss, acc
|
||||
58
kws/model/subsampling.py
Normal file
58
kws/model/subsampling.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
# There is no right context or lookahead in our Subsampling design, so
|
||||
# If there is CNN in Subsampling, it's a causal CNN.
|
||||
|
||||
|
||||
class SubsamplingBase(torch.nn.Module):
    """Common base of the subsampling front-ends.

    Exposes `subsampling_rate`, which is 1 unless a subclass overrides it.
    """
    def __init__(self):
        super().__init__()
        self.subsampling_rate = 1
|
||||
|
||||
|
||||
class LinearSubsampling1(SubsamplingBase):
    """Project the input with a linear layer followed by ReLU.

    The number of frames is left unchanged (subsampling rate 1).
    """
    def __init__(self, idim: int, odim: int):
        super().__init__()
        projection = torch.nn.Linear(idim, odim)
        activation = torch.nn.ReLU()
        self.out = torch.nn.Sequential(projection, activation)
        self.subsampling_rate = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out(x)
|
||||
|
||||
|
||||
class Conv1dSubsampling1(SubsamplingBase):
    """Conv1d transform without subsampling

    Takes and returns (B, T, D) tensors like the other front-ends. The
    convolution is causal: the input is padded on the left only, so the
    number of frames is preserved and no future frame is used.
    """
    def __init__(self, idim: int, odim: int):
        super().__init__()
        kernel_size = 3
        # Frames of history needed to keep T unchanged with a causal conv
        self.left_context = kernel_size - 1
        self.out = torch.nn.Sequential(
            torch.nn.Conv1d(idim, odim, kernel_size),
            torch.nn.BatchNorm1d(odim),
            torch.nn.ReLU(),
        )
        self.subsampling_rate = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (B, T, idim)

        Returns:
            torch.Tensor: (B, T, odim)
        """
        # Conv1d/BatchNorm1d work on (B, C, T); feeding (B, T, D) directly
        # would treat the time axis as channels.
        x = x.transpose(1, 2)
        x = torch.nn.functional.pad(x, (self.left_context, 0), value=0.0)
        x = self.out(x)
        return x.transpose(1, 2)
|
||||
148
kws/model/tcn.py
Normal file
148
kws/model/tcn.py
Normal file
@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CnnBlock(nn.Module):
    """ Causal dilated convolution block with a residual connection
    """
    def __init__(self,
                 channel: int,
                 kernel_size: int,
                 dilation: int,
                 dropout: float = 0.1):
        super().__init__()
        # Left context required so that the convolution is causal
        self.padding = (kernel_size - 1) * dilation
        self.cnn = nn.Conv1d(channel,
                             channel,
                             kernel_size,
                             stride=1,
                             dilation=dilation)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        Args:
            x(torch.Tensor): Input tensor (B, D, T)
            cache(torch.Tensor): left context prepended to `x` along the
                time axis; zeros are used when it is None
        Returns:
            torch.Tensor(B, D, T)
            torch.Tensor: the last `self.padding` frames of the extended
                input, usable as `cache` for the next call
        """
        if cache is not None:
            extended = torch.cat((cache, x), dim=2)
        else:
            extended = F.pad(x, (self.padding, 0), value=0.0)
        assert extended.size(2) > self.padding
        new_cache = extended[:, :, -self.padding:]

        out = self.dropout(F.relu(self.cnn(extended)))
        # residual connection
        return out + x, new_cache
|
||||
|
||||
|
||||
class DsCnnBlock(nn.Module):
    """ Causal depthwise separable convolution block with a residual
        connection: a per-channel dilated convolution followed by a 1x1
        convolution that mixes the channels
    """
    def __init__(self,
                 channel: int,
                 kernel_size: int,
                 dilation: int,
                 dropout: float = 0.1):
        super().__init__()
        # Left context required so that the convolution is causal
        self.padding = (kernel_size - 1) * dilation
        self.depthwise_cnn = nn.Conv1d(channel,
                                       channel,
                                       kernel_size,
                                       stride=1,
                                       dilation=dilation,
                                       groups=channel)
        self.pointwise_cnn = nn.Conv1d(channel,
                                       channel,
                                       kernel_size=1,
                                       stride=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        Args:
            x(torch.Tensor): Input tensor (B, D, T)
            cache(torch.Tensor): left context prepended to `x` along the
                time axis; zeros are used when it is None
        Returns:
            torch.Tensor(B, D, T)
            torch.Tensor: the last `self.padding` frames of the extended
                input, usable as `cache` for the next call
        """
        if cache is not None:
            extended = torch.cat((cache, x), dim=2)
        else:
            extended = F.pad(x, (self.padding, 0), value=0.0)
        assert extended.size(2) > self.padding
        new_cache = extended[:, :, -self.padding:]

        out = self.pointwise_cnn(self.depthwise_cnn(extended))
        out = self.dropout(F.relu(out))
        # residual connection
        return out + x, new_cache
|
||||
|
||||
|
||||
class TCN(nn.Module):
    """ Stack of causal convolution blocks with exponentially growing
        dilation (1, 2, 4, ...)
    """
    def __init__(self,
                 num_layers: int,
                 channel: int,
                 kernel_size: int,
                 dropout: float = 0.1,
                 block_class=CnnBlock):
        super().__init__()
        # Total left context of the stack, i.e. the accumulated cache size
        self.padding = 0
        self.network = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2**i
            self.padding += (kernel_size - 1) * dilation
            self.network.append(
                block_class(channel, kernel_size, dilation, dropout))

    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        Args:
            x (torch.Tensor): Input tensor (B, T, D)
            cache (torch.Tensor): optional (B, D, C) cache returned by a
                previous call, C is the accumulated cache size
                (`self.padding`). Zero left context is used when None.

        Returns:
            torch.Tensor(B, T, D)
            torch.Tensor(B, D, C): C is the accumulated cache size

        Raises:
            ValueError: if `cache` does not hold `self.padding` frames
        """
        if cache is not None and cache.size(2) != self.padding:
            raise ValueError('expected a cache of {} frames, got {}'.format(
                self.padding, cache.size(2)))
        x = x.transpose(1, 2)  # (B, D, T)
        out_caches = []
        offset = 0
        for block in self.network:
            # The caches are concatenated in block order, so each block
            # gets back the slice it produced in the previous call.
            block_cache = None
            if cache is not None:
                block_cache = cache[:, :, offset:offset + block.padding]
                offset += block.padding
            x, c = block(x, block_cache)
            out_caches.append(c)
        x = x.transpose(1, 2)  # (B, T, D)
        new_cache = torch.cat(out_caches, dim=2)
        return x, new_cache
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build a small TCN, report its left context and its
    # parameter count, then push an all-zero batch through it.
    model = TCN(4, 64, 8, block_class=CnnBlock)
    print(model)
    print(model.padding)
    total_params = sum(param.numel() for param in model.parameters())
    print('the number of model params: {}'.format(total_params))
    dummy_input = torch.zeros(3, 15, 64)
    y = model(dummy_input)
|
||||
58
kws/utils/checkpoint.py
Normal file
58
kws/utils/checkpoint.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
|
||||
|
||||
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    '''Load model weights from `path` and return the saved infos.

    Args:
        model (torch.nn.Module): module that receives the state dict
        path (str): checkpoint file

    Returns:
        dict: content of the yaml file stored next to the checkpoint
            ('xxx.pt' -> 'xxx.yaml', any other path -> path + '.yaml'),
            or an empty dict when there is no such file
    '''
    if torch.cuda.is_available():
        logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
        checkpoint = torch.load(path)
    else:
        logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
        checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)
    # The dot must be escaped: '.pt$' also matches names such as 'ckpt'.
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if info_path == path:
        # No '.pt' suffix: never try to parse the checkpoint itself as yaml
        info_path = path + '.yaml'
    configs = {}
    if os.path.exists(info_path):
        with open(info_path, 'r') as fin:
            # NOTE(review): FullLoader should only see trusted files
            configs = yaml.load(fin, Loader=yaml.FullLoader) or {}
    return configs
|
||||
|
||||
|
||||
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    '''Save the model weights to `path` and the infos to a yaml file.

    The yaml file is written next to the checkpoint: 'xxx.pt' gives
    'xxx.yaml', any other path gives path + '.yaml'.

    Args:
        model (torch.nn.Module): model to save; the wrapped module is
            saved for DataParallel / DistributedDataParallel
        path (str): checkpoint file
        infos (dict or None): any info you want to save.
    '''
    logging.info('Checkpoint: save to checkpoint %s' % path)
    parallel_wrappers = (torch.nn.DataParallel,
                         torch.nn.parallel.DistributedDataParallel)
    if isinstance(model, parallel_wrappers):
        # Save the wrapped module so the keys carry no 'module.' prefix
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)
    # The dot must be escaped: '.pt$' also matches names such as 'ckpt'.
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if info_path == path:
        # No '.pt' suffix: writing the yaml to `path` would overwrite the
        # checkpoint that was just saved.
        info_path = path + '.yaml'
    if infos is None:
        infos = {}
    with open(info_path, 'w') as fout:
        data = yaml.dump(infos)
        fout.write(data)
|
||||
44
kws/utils/cmvn.py
Normal file
44
kws/utils/cmvn.py
Normal file
@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2020 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_cmvn(json_cmvn_file):
    """ Load the json format cmvn stats file and calculate cmvn

    The file holds accumulated statistics: `mean_stat` (sum of the
    features), `var_stat` (sum of the squared features) and `frame_num`
    (number of accumulated frames).

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, inverse stds], shape (2, feat_dim).
        Note that the second row is 1.0 / std, not the variance.
    """
    with open(json_cmvn_file, encoding='utf8') as f:
        cmvn_stats = json.load(f)

    count = cmvn_stats['frame_num']
    means = np.asarray(cmvn_stats['mean_stat'], dtype=np.float64) / count
    # Var[x] = E[x^2] - E[x]^2
    variance = np.asarray(cmvn_stats['var_stat'],
                          dtype=np.float64) / count - means * means
    # Floor the variance so that the inverse stays finite
    variance = np.maximum(variance, 1.0e-20)
    istd = 1.0 / np.sqrt(variance)
    return np.stack([means, istd])
|
||||
87
kws/utils/executor.py
Normal file
87
kws/utils/executor.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
from kws.model.loss import max_polling_loss
|
||||
|
||||
|
||||
class Executor:
    ''' Runs the training and cross validation loops of one epoch
    '''
    def __init__(self):
        self.step = 0

    def train(self, model, optimizer, data_loader, device, writer, args):
        ''' Train one epoch

        Args:
            model: model to train, called as model(feats)
            optimizer: optimizer updating the model parameters
            data_loader: iterable of (key, feats, target, feats_lengths)
            device: device the batch tensors are moved to
            writer: not used by this method
            args (dict): reads grad_clip (default 50.0), log_interval
                (default 10), epoch (default 0) and min_duration
                (default 0)
        '''
        model.train()
        clip = args.get('grad_clip', 50.0)
        log_interval = args.get('log_interval', 10)
        epoch = args.get('epoch', 0)
        min_duration = args.get('min_duration', 0)

        for batch_idx, batch in enumerate(data_loader):
            key, feats, target, feats_lengths = batch
            feats = feats.to(device)
            target = target.to(device)
            feats_lengths = feats_lengths.to(device)
            num_utts = feats_lengths.size(0)
            if num_utts == 0:
                continue
            logits = model(feats)
            loss, acc = max_polling_loss(logits, target, feats_lengths,
                                         min_duration)
            # Gradients must be cleared for every batch, otherwise they
            # accumulate over the whole epoch (also after a skipped step).
            optimizer.zero_grad()
            loss.backward()
            grad_norm = clip_grad_norm_(model.parameters(), clip)
            # Skip the update when the gradient is inf/nan
            if torch.isfinite(grad_norm):
                optimizer.step()
            if batch_idx % log_interval == 0:
                logging.debug(
                    'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format(
                        epoch, batch_idx, loss.item(), acc))

    def cv(self, model, data_loader, device, args):
        ''' Cross validation on

        Args:
            model: model to evaluate, called as model(feats)
            data_loader: iterable of (key, feats, target, feats_lengths)
            device: device the batch tensors are moved to
            args (dict): reads log_interval (default 10) and epoch
                (default 0)

        Returns:
            float: loss averaged over the utterances of all batches with
                a finite loss, 0.0 when there is no such batch
        '''
        model.eval()
        log_interval = args.get('log_interval', 10)
        epoch = args.get('epoch', 0)
        num_seen_utts = 0
        total_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                num_utts = feats_lengths.size(0)
                if num_utts == 0:
                    continue
                logits = model(feats)
                loss, acc = max_polling_loss(logits, target, feats_lengths)
                # Count an utterance exactly once, and only when its batch
                # contributes to total_loss.
                if torch.isfinite(loss):
                    num_seen_utts += num_utts
                    total_loss += loss.item() * num_utts
                if batch_idx % log_interval == 0:
                    # max(..., 1) avoids a division by 0
                    logging.debug(
                        'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
                        .format(epoch, batch_idx, loss.item(), acc,
                                total_loss / max(num_seen_utts, 1)))
        return total_loss / max(num_seen_utts, 1)
|
||||
31
kws/utils/file_utils.py
Normal file
31
kws/utils/file_utils.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
def read_lists(list_file):
    """ Read a utf8 text file into a list of lines.

    Leading and trailing whitespace is stripped from every line; empty
    lines are kept as empty strings.
    """
    with open(list_file, 'r', encoding='utf8') as fin:
        return [line.strip() for line in fin]
|
||||
|
||||
|
||||
def read_symbol_table(symbol_table_file):
    """ Read a utf8 symbol table into a dict {symbol: id}.

    Every line must hold exactly two whitespace separated fields, the
    symbol and its integer id.
    """
    table = {}
    with open(symbol_table_file, 'r', encoding='utf8') as fin:
        for line in fin:
            fields = line.strip().split()
            assert len(fields) == 2
            symbol, index = fields
            table[symbol] = int(index)
    return table
|
||||
32
kws/utils/mask.py
Normal file
32
kws/utils/mask.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) 2021 Binbin Zhang
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    """ Build a boolean mask that is True on padded positions.

    The mask has shape (batch, max(lengths)); entry [b, t] is True when
    t >= lengths[b].

    Examples:
        >>> lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
        >>> mask = padding_mask(lengths)
        >>> print(mask)
        tensor([[False, False,  True],
                [False, False,  True],
                [False, False, False]])
    """
    num_rows = lengths.size(0)
    longest = int(lengths.max().item())
    positions = torch.arange(longest,
                             dtype=torch.int64,
                             device=lengths.device).expand(num_rows, longest)
    return positions >= lengths.unsqueeze(1)
|
||||
Loading…
x
Reference in New Issue
Block a user