[kws] add kws base code

This commit is contained in:
Binbin Zhang 2021-11-10 18:48:57 +08:00
parent f629c0fa54
commit aa0b0c11a8
17 changed files with 1701 additions and 0 deletions

89
kws/bin/average_model.py Normal file
View File

@ -0,0 +1,89 @@
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: di.wu@mobvoi.com (DI WU)
import os
import argparse
import glob
import yaml
import numpy as np
import torch
def get_args():
parser = argparse.ArgumentParser(description='average model')
parser.add_argument('--dst_model', required=True, help='averaged model')
parser.add_argument('--src_path',
required=True,
help='src model path for average')
parser.add_argument('--val_best',
action="store_true",
help='averaged model')
parser.add_argument('--num',
default=5,
type=int,
help='nums for averaged model')
parser.add_argument('--min_epoch',
default=0,
type=int,
help='min epoch used for averaging model')
parser.add_argument(
'--max_epoch',
default=65536, # Big enough
type=int,
help='max epoch used for averaging model')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
checkpoints = []
val_scores = []
if args.val_best:
yamls = glob.glob('{}/[!config]*.yaml'.format(args.src_path))
for y in yamls:
with open(y, 'r') as f:
dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
print(y, dic_yaml)
loss = dic_yaml['cv_loss']
epoch = dic_yaml['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores += [[epoch, loss]]
val_scores = np.array(val_scores)
sort_idx = np.argsort(val_scores[:, -1])
sorted_val_scores = val_scores[sort_idx][::1]
print("best val scores = " + str(sorted_val_scores[:args.num, 1]))
print("selected epochs = " +
str(sorted_val_scores[:args.num, 0].astype(np.int64)))
path_list = [
args.src_path + '/{}.pt'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0]
]
else:
path_list = glob.glob('{}/[!avg][!final]*.pt'.format(args.src_path))
path_list = sorted(path_list, key=os.path.getmtime)
path_list = path_list[-args.num:]
print(path_list)
avg = None
num = args.num
assert num == len(path_list)
for path in path_list:
print('Processing {}'.format(path))
states = torch.load(path, map_location=torch.device('cpu'))
if avg is None:
avg = states
else:
for k in avg.keys():
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
# pytorch 1.6 use true_divide instead of /=
avg[k] = torch.true_divide(avg[k], num)
print('Saving to {}'.format(args.dst_model))
torch.save(avg, args.dst_model)
if __name__ == '__main__':
main()

80
kws/bin/compute_det.py Normal file
View File

@ -0,0 +1,80 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
def load_label_and_score(keyword, label_file, score_file):
score_table = {}
with open(score_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
key = arr[0]
score = float(arr[keyword + 1])
score_table[key] = score
keyword_table = {}
filler_table = {}
filler_duration = 0.0
with open(label_file, 'r', encoding='utf8') as fin:
for line in fin:
obj = json.loads(line.strip())
assert 'key' in obj
assert 'txt' in obj
assert 'duration' in obj
key = obj['key']
index = obj['txt']
duration = obj['duration']
assert key in score_table
if index == keyword:
keyword_table[key] = score_table[key]
else:
filler_table[key] = score_table[key]
filler_duration += duration
return keyword_table, filler_table, filler_duration
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='compute det curve')
parser.add_argument('--test_data', required=True, help='label file')
parser.add_argument('--keyword', type=int, default=0, help='score file')
parser.add_argument('--score_file', required=True, help='score file')
parser.add_argument('--step', type=float, default=0.01, help='score file')
parser.add_argument('--stats_file',
required=True,
help='false reject/alarm stats file')
args = parser.parse_args()
keyword_table, filler_table, filler_duration = load_label_and_score(
args.keyword, args.test_data, args.score_file)
print('Filler total duration Hours: {}'.format(filler_duration / 3600.0))
with open(args.stats_file, 'w', encoding='utf8') as fout:
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
for key, score in keyword_table.items():
if score < threshold:
num_false_reject += 1
num_false_alarm = 0
for key, score in filler_table.items():
if score >= threshold:
num_false_alarm += 1
false_reject_rate = num_false_reject / len(keyword_table)
num_false_alarm = max(num_false_alarm, 1e-6)
false_alarm_per_hour = num_false_alarm / (filler_duration / 3600.0)
fout.write('{:.6f} {:.6f} {:.6f}\n'.format(threshold,
false_alarm_per_hour,
false_reject_rate))
threshold += args.step

69
kws/bin/export_jit.py Normal file
View File

@ -0,0 +1,69 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import os
import torch
import yaml
from kws.model.kws_model import init_model
from kws.utils.checkpoint import load_checkpoint
def get_args():
parser = argparse.ArgumentParser(description='export your script model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--output_file', required=True, help='output file')
parser.add_argument('--output_quant_file',
default=None,
help='output quantized model file')
args = parser.parse_args()
return args
def main():
args = get_args()
# No need gpu for model export
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
model = init_model(configs['model'])
print(model)
load_checkpoint(model, args.checkpoint)
# Export jit torch script model
script_model = torch.jit.script(model)
script_model.save(args.output_file)
print('Export model successfully, see {}'.format(args.output_file))
# Export quantized jit torch script model
if args.output_quant_file:
quantized_model = torch.quantization.quantize_dynamic(
model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
script_quant_model = torch.jit.script(quantized_model)
script_quant_model.save(args.output_quant_file)
print('Export quantized model successfully, '
'see {}'.format(args.output_quant_file))
if __name__ == '__main__':
main()

104
kws/bin/score.py Normal file
View File

@ -0,0 +1,104 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import torch
import yaml
from torch.utils.data import DataLoader
from kws.dataset.dataset import Dataset
from kws.model.kws_model import init_model
from kws.utils.checkpoint import load_checkpoint
from kws.utils.mask import padding_mask
def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--test_data', required=True, help='test data file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--batch_size',
default=16,
type=int,
help='batch size for inference')
parser.add_argument('--score_file',
required=True,
help='output score file')
args = parser.parse_args()
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
test_conf = copy.deepcopy(configs['dataset_conf'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
test_conf['speed_perturb'] = False
test_conf['spec_aug'] = False
test_conf['shuffle'] = False
test_conf['fbank_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_size'] = args.batch_size
test_dataset = Dataset(args.test_data, test_conf)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
# Init asr model from configs
model = init_model(configs['model'])
load_checkpoint(model, args.checkpoint)
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
with torch.no_grad(), open(args.score_file, 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, lengths = batch
feats = feats.to(device)
lengths = lengths.to(device)
mask = padding_mask(lengths).unsqueeze(2)
logits = model(feats)
logits = logits.masked_fill(mask, 0.0)
max_logits, _ = logits.max(dim=1)
max_logits = max_logits.cpu()
for i in range(len(keys)):
key = keys[i]
score = max_logits[i]
score = ' '.join([str(x) for x in score.tolist()])
fout.write('{} {}\n'.format(key, score))
if batch_idx % 10 == 0:
print('Progress batch {}'.format(batch_idx))
sys.stdout.flush()
if __name__ == '__main__':
main()

242
kws/bin/train.py Normal file
View File

@ -0,0 +1,242 @@
# Copyright (c) 2020 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from kws.dataset.dataset import Dataset
from kws.utils.checkpoint import load_checkpoint, save_checkpoint
from kws.model.kws_model import init_model
from kws.utils.executor import Executor
def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this local rank, -1 for cpu')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.rank',
dest='rank',
default=0,
type=int,
help='global rank for distributed training')
parser.add_argument('--ddp.world_size',
dest='world_size',
default=-1,
type=int,
help='''number of total processes/gpus for
distributed training''')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--ddp.init_method',
dest='init_method',
default=None,
help='ddp init method')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--cmvn_file', default=None, help='global cmvn file')
parser.add_argument('--norm_var',
action='store_true',
default=False,
help='norm var option')
parser.add_argument('--num_keywords',
default=1,
type=int,
help='number of keywords')
parser.add_argument('--min_duration',
default=50,
type=int,
help='min duration frames of the keyword')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
args = parser.parse_args()
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Set random seed
torch.manual_seed(777)
print(args)
with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
distributed = args.world_size > 1
if distributed:
logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
dist.init_process_group(args.dist_backend,
init_method=args.init_method,
world_size=args.world_size,
rank=args.rank)
train_conf = configs['dataset_conf']
cv_conf = copy.deepcopy(train_conf)
cv_conf['speed_perturb'] = False
cv_conf['spec_aug'] = False
cv_conf['shuffle'] = False
train_dataset = Dataset(args.train_data, train_conf)
cv_dataset = Dataset(args.cv_data, cv_conf)
train_data_loader = DataLoader(train_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
cv_data_loader = DataLoader(cv_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
output_dim = args.num_keywords
# Write model_dir/config.yaml for inference and export
configs['model']['input_dim'] = input_dim
configs['model']['output_dim'] = output_dim
if args.cmvn_file is not None:
configs['model']['cmvn'] = {}
configs['model']['cmvn']['norm_var'] = args.norm_var
configs['model']['cmvn']['cmvn_file'] = args.cmvn_file
if args.rank == 0:
saved_config_path = os.path.join(args.model_dir, 'config.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs)
fout.write(data)
# Init asr model from configs
model = init_model(configs['model'])
print(model)
num_params = sum(p.numel() for p in model.parameters())
print('the number of model params: {}'.format(num_params))
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
if args.rank == 0:
script_model = torch.jit.script(model)
script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor()
# If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None:
infos = load_checkpoint(model, args.checkpoint)
else:
infos = {}
start_epoch = infos.get('epoch', -1) + 1
cv_loss = infos.get('cv_loss', 0.0)
model_dir = args.model_dir
writer = None
if args.rank == 0:
os.makedirs(model_dir, exist_ok=True)
exp_id = os.path.basename(model_dir)
writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))
if distributed:
assert (torch.cuda.is_available())
# cuda model is required for nn.parallel.DistributedDataParallel
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(
model, find_unused_parameters=True)
device = torch.device("cuda")
else:
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=3,
min_lr=1e-6,
threshold=0.01,
)
training_config = configs['training_config']
training_config['min_duration'] = args.min_duration
num_epochs = training_config.get('max_epoch', 100)
final_epoch = None
if start_epoch == 0 and args.rank == 0:
save_model_path = os.path.join(model_dir, 'init.pt')
save_checkpoint(model, save_model_path)
# Start training loop
for epoch in range(start_epoch, num_epochs):
train_dataset.set_epoch(epoch)
training_config['epoch'] = epoch
lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
executor.train(model, optimizer, train_data_loader, device, writer,
training_config)
cv_loss = executor.cv(model, cv_data_loader, device, training_config)
logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
if args.rank == 0:
save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
save_checkpoint(model, save_model_path, {
'epoch': epoch,
'lr': lr,
'cv_loss': cv_loss,
})
writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
writer.add_scalar('epoch/lr', lr, epoch)
final_epoch = epoch
scheduler.step(cv_loss)
if final_epoch is not None and args.rank == 0:
final_model_path = os.path.join(model_dir, 'final.pt')
os.symlink('{}.pt'.format(final_epoch), final_model_path)
writer.close()
if __name__ == '__main__':
main()

162
kws/dataset/dataset.py Normal file
View File

@ -0,0 +1,162 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
import kws.dataset.processor as processor
from kws.utils.file_utils import read_lists
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = data.copy()
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
data = data[self.rank::self.world_size]
data = data[self.worker_id::self.num_workers]
return data
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
lists = self.sampler.sample(self.lists)
for src in lists:
# yield dict(src=src)
data = dict(src=src)
data.update(sampler_info)
yield data
def Dataset(data_list_file, conf, partition=True):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
shuffle at shards tar/raw file level. The second is global shuffle
at training samples level.
Args:
data_type(str): raw/shard
partition(bool): whether to do data partition in terms of rank
"""
lists = read_lists(data_list_file)
shuffle = conf.get('shuffle', True)
dataset = DataList(lists, shuffle=shuffle, partition=partition)
dataset = Processor(dataset, processor.parse_raw)
filter_conf = conf.get('filter_conf', {})
dataset = Processor(dataset, processor.filter, **filter_conf)
resample_conf = conf.get('resample_conf', {})
dataset = Processor(dataset, processor.resample, **resample_conf)
speed_perturb = conf.get('speed_perturb', False)
if speed_perturb:
dataset = Processor(dataset, processor.speed_perturb)
fbank_conf = conf.get('fbank_conf', {})
dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
spec_aug = conf.get('spec_aug', True)
if spec_aug:
spec_aug_conf = conf.get('spec_aug_conf', {})
dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
if shuffle:
shuffle_conf = conf.get('shuffle_conf', {})
dataset = Processor(dataset, processor.shuffle, **shuffle_conf)
batch_conf = conf.get('batch_conf', {})
dataset = Processor(dataset, processor.batch, **batch_conf)
dataset = Processor(dataset, processor.padding)
return dataset
if __name__ == '__main__':
import sys
dataset = Dataset(sys.argv[1], {})
for data in dataset:
print(data)

266
kws/dataset/processor.py Normal file
View File

@ -0,0 +1,266 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import json
import random
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence
def parse_raw(data):
""" Parse key/wav/txt from json line
Args:
data: Iterable[str], str is a json line has key/wav/txt
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'src' in sample
json_line = sample['src']
obj = json.loads(json_line)
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
try:
waveform, sample_rate = torchaudio.load(wav_file)
example = dict(key=key,
label=txt,
wav=waveform,
sample_rate=sample_rate)
yield example
except Exception as ex:
logging.warning('Failed to read {}'.format(wav_file))
def filter(data, max_length=10240, min_length=10):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
yield sample
def resample(data, resample_rate=16000):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
if sample_rate != resample_rate:
sample['sample_rate'] = resample_rate
sample['wav'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
yield sample
def speed_perturb(data, speeds=None):
""" Apply speed perturb to the data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
speeds(List[float]): optional speed
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
if speeds is None:
speeds = [0.9, 1.0, 1.1]
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
speed = random.choice(speeds)
if speed != 1.0:
wav, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, sample_rate,
[['speed', str(speed)], ['rate', str(sample_rate)]])
sample['wav'] = wav
yield sample
def compute_fbank(data,
num_mel_bins=23,
frame_length=25,
frame_shift=10,
dither=0.0):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
assert 'key' in sample
assert 'label' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
waveform = waveform * (1 << 15)
# Only keep key, feat, label
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
sample_frequency=sample_rate)
yield dict(key=sample['key'], label=sample['label'], feat=mat)
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10):
""" Do spec augmentation
Inplace operation
Args:
data: Iterable[{key, feat, label}]
num_t_mask: number of time mask to apply
num_f_mask: number of freq mask to apply
max_t: max width of time mask
max_f: max width of freq mask
Returns
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'feat' in sample
x = sample['feat']
assert isinstance(x, torch.Tensor)
y = x.clone().detach()
max_frames = y.size(0)
max_freq = y.size(1)
# time mask
for i in range(num_t_mask):
start = random.randint(0, max_frames - 1)
length = random.randint(1, max_t)
end = min(max_frames, start + length)
y[start:end, :] = 0
# freq mask
for i in range(num_f_mask):
start = random.randint(0, max_freq - 1)
length = random.randint(1, max_f)
end = min(max_freq, start + length)
y[:, start:end] = 0
sample['feat'] = y
yield sample
def shuffle(data, shuffle_size=1000):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def padding(data):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
feats_length = torch.tensor([x['feat'].size(0) for x in sample],
dtype=torch.int32)
order = torch.argsort(feats_length, descending=True)
feats_lengths = torch.tensor(
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['key'] for i in order]
sorted_labels = torch.tensor([sample[i]['label'] for i in order],
dtype=torch.int64)
padded_feats = pad_sequence(sorted_feats,
batch_first=True,
padding_value=0)
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)

47
kws/model/cmvn.py Normal file
View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) 2020 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class GlobalCMVN(torch.nn.Module):
def __init__(self,
mean: torch.Tensor,
istd: torch.Tensor,
norm_var: bool = True):
"""
Args:
mean (torch.Tensor): mean stats
istd (torch.Tensor): inverse std, std which is 1.0 / std
"""
super().__init__()
assert mean.shape == istd.shape
self.norm_var = norm_var
# The buffer can be accessed from this module using self.mean
self.register_buffer("mean", mean)
self.register_buffer("istd", istd)
def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): (batch, max_len, feat_dim)
Returns:
(torch.Tensor): normalized feature
"""
x = x - self.mean
if self.norm_var:
x = x * self.istd
return x

101
kws/model/kws_model.py Normal file
View File

@ -0,0 +1,101 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Optional
import torch
from kws.model.cmvn import GlobalCMVN
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
from kws.utils.cmvn import load_cmvn
class KwsModel(torch.nn.Module):
""" Our model consists of four parts:
1. global_cmvn: Optional, (idim, idim)
2. subsampling: subsampling the input, (idim, hdim)
3. body: body of the whole network, (hdim, hdim)
4. linear: a linear layer, (hdim, odim)
"""
def __init__(self, idim: int, odim: int, hdim: int,
global_cmvn: Optional[torch.nn.Module],
subsampling: torch.nn.Module, body: torch.nn.Module):
super().__init__()
self.idim = idim
self.odim = odim
self.hdim = hdim
self.global_cmvn = global_cmvn
self.subsampling = subsampling
self.body = body
self.linear = torch.nn.Linear(hdim, odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.global_cmvn is not None:
x = self.global_cmvn(x)
x = self.subsampling(x)
x, _ = self.body(x)
x = self.linear(x)
x = torch.sigmoid(x)
return x
def init_model(configs):
cmvn = configs.get('cmvn', {})
if cmvn['cmvn_file'] is not None:
mean, istd = load_cmvn(cmvn['cmvn_file'])
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float(), cmvn['norm_var'])
else:
global_cmvn = None
input_dim = configs['input_dim']
output_dim = configs['output_dim']
hidden_dim = configs['hidden_dim']
subsampling_type = configs['subsampling']['type']
if subsampling_type == 'linear':
subsampling = LinearSubsampling1(input_dim, hidden_dim)
elif subsampling_type == 'cnn1d_s1':
subsampling = Conv1dSubsampling1(input_dim, hidden_dim)
else:
print('Unknown subsampling type {}'.format(subsampling_type))
sys.exit(1)
body_type = configs['body']['type']
num_layers = configs['body']['num_layers']
if body_type == 'gru':
body = torch.nn.GRU(hidden_dim,
hidden_dim,
num_layers=num_layers,
batch_first=True)
elif body_type == 'tcn':
# Depthwise Separable
ds = configs['body'].get('ds', False)
if ds:
block_class = DsCnnBlock
else:
block_class = CnnBlock
kernel_size = configs['body'].get('kernel_size', 8)
dropout = configs['body'].get('drouput', 0.1)
body = TCN(num_layers, hidden_dim, kernel_size, dropout, block_class)
else:
print('Unknown body type {}'.format(body_type))
sys.exit(1)
kws_model = KwsModel(input_dim, output_dim, hidden_dim, global_cmvn,
subsampling, body)
return kws_model

83
kws/model/loss.py Normal file
View File

@ -0,0 +1,83 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from kws.utils.mask import padding_mask
def max_polling_loss(logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
""" Max-pooling loss
For keyword, select the frame with the highest posterior.
The keyword is triggered when any of the frames is triggered.
For none keyword, select the hardest frame, namely the frame
with lowest filler posterior(highest keyword posterior).
the keyword is not triggered when all frames are not triggered.
Attributes:
logits: (B, T, D), D is the number of keywords
target: (B)
lengths: (B)
min_duration: min duration of the keyword
Returns:
(float): loss of current batch
(float): accuracy of current batch
"""
mask = padding_mask(lengths)
num_utts = logits.size(0)
num_keywords = logits.size(2)
target = target.cpu()
loss = 0.0
for i in range(num_utts):
for j in range(num_keywords):
# Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
if target[i] == j:
# For the keyword, do max-polling
prob = logits[i, :, j]
m = mask[i].clone().detach()
m[:min_duration] = True
prob = prob.masked_fill(m, 0.0)
prob = torch.clamp(prob, 1e-8, 1.0)
max_prob = prob.max()
loss += -torch.log(max_prob)
else:
# For other keywords or filler, do min-polling
prob = 1 - logits[i, :, j]
prob = prob.masked_fill(mask[i], 1.0)
prob = torch.clamp(prob, 1e-8, 1.0)
min_prob = prob.min()
loss += -torch.log(min_prob)
loss = loss / num_utts
# Compute accuracy of current batch
mask = mask.unsqueeze(-1)
logits = logits.masked_fill(mask, 0.0)
max_logits, index = logits.max(1)
num_correct = 0
for i in range(num_utts):
max_p, idx = max_logits[i].max(0)
# Predict correct as the i'th keyword
if max_p > 0.5 and idx == target[i]:
num_correct += 1
# Predict correct as the filler, filler id < 0
if max_p < 0.5 and target[i] < 0:
num_correct += 1
acc = num_correct / num_utts
# acc = 0.0
return loss, acc

58
kws/model/subsampling.py Normal file
View File

@ -0,0 +1,58 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
# There is no right context or lookahead in our Subsampling design, so
# If there is CNN in Subsampling, it's a causal CNN.
class SubsamplingBase(torch.nn.Module):
def __init__(self):
super().__init__()
self.subsampling_rate = 1
class LinearSubsampling1(SubsamplingBase):
"""Linear transform the input without subsampling
"""
def __init__(self, idim: int, odim: int):
super().__init__()
self.out = torch.nn.Sequential(
torch.nn.Linear(idim, odim),
# torch.nn.BatchNorm1d(odim),
torch.nn.ReLU(),
)
self.subsampling_rate = 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.out(x)
return x
class Conv1dSubsampling1(SubsamplingBase):
"""Conv1d transform without subsampling
"""
def __init__(self, idim: int, odim: int):
super().__init__()
self.out = torch.nn.Sequential(
torch.nn.Conv1d(idim, odim, 3),
torch.nn.BatchNorm1d(odim),
torch.nn.ReLU(),
)
self.subsampling_rate = 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.out(x)
return x

148
kws/model/tcn.py Normal file
View File

@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class CnnBlock(nn.Module):
def __init__(self,
channel: int,
kernel_size: int,
dilation: int,
dropout: float = 0.1):
super().__init__()
# The CNN used here is causal convolution
self.padding = (kernel_size - 1) * dilation
self.cnn = nn.Conv1d(channel,
channel,
kernel_size,
stride=1,
dilation=dilation)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
"""
Args:
x(torch.Tensor): Input tensor (B, D, T)
Returns:
torch.Tensor(B, D, T)
"""
if cache is None:
y = F.pad(x, (self.padding, 0), value=0.0)
else:
y = torch.cat((cache, x), dim=2)
assert y.size(2) > self.padding
new_cache = y[:, :, -self.padding:]
y = self.cnn(y)
y = F.relu(y)
y = self.dropout(y)
y = y + x # residual connection
return y, new_cache
class DsCnnBlock(nn.Module):
""" Depthwise Separable Convolution
"""
def __init__(self,
channel: int,
kernel_size: int,
dilation: int,
dropout: float = 0.1):
super().__init__()
# The CNN used here is causal convolution
self.padding = (kernel_size - 1) * dilation
self.depthwise_cnn = nn.Conv1d(channel,
channel,
kernel_size,
stride=1,
dilation=dilation,
groups=channel)
self.pointwise_cnn = nn.Conv1d(channel,
channel,
kernel_size=1,
stride=1)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
"""
Args:
x(torch.Tensor): Input tensor (B, D, T)
Returns:
torch.Tensor(B, D, T)
"""
if cache is None:
y = F.pad(x, (self.padding, 0), value=0.0)
else:
y = torch.cat((cache, x), dim=2)
assert y.size(2) > self.padding
new_cache = y[:, :, -self.padding:]
y = self.depthwise_cnn(y)
y = self.pointwise_cnn(y)
y = F.relu(y)
y = self.dropout(y)
y = y + x # residual connection
return y, new_cache
class TCN(nn.Module):
def __init__(self,
num_layers: int,
channel: int,
kernel_size: int,
dropout: float = 0.1,
block_class=CnnBlock):
super().__init__()
layers = []
self.padding = 0
self.network = nn.ModuleList()
for i in range(num_layers):
dilation = 2**i
self.padding += (kernel_size - 1) * dilation
self.network.append(
block_class(channel, kernel_size, dilation, dropout))
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
"""
Args:
x (torch.Tensor): Input tensor (B, T, D)
Returns:
torch.Tensor(B, T, D)
torch.Tensor(B, D, C): C is the accumulated cache size
"""
x = x.transpose(1, 2) # (B, D, T)
out_caches = []
for block in self.network:
x, c = block(x)
out_caches.append(c)
x = x.transpose(1, 2) # (B, T, D)
new_cache = torch.cat(out_caches, dim=2)
return x, new_cache
if __name__ == '__main__':
tcn = TCN(4, 64, 8, block_class=CnnBlock)
print(tcn)
print(tcn.padding)
num_params = sum(p.numel() for p in tcn.parameters())
print('the number of model params: {}'.format(num_params))
x = torch.zeros(3, 15, 64)
y = tcn(x)

58
kws/utils/checkpoint.py Normal file
View File

@ -0,0 +1,58 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
import yaml
import torch
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
if torch.cuda.is_available():
logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
checkpoint = torch.load(path)
else:
logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
checkpoint = torch.load(path, map_location='cpu')
model.load_state_dict(checkpoint)
info_path = re.sub('.pt$', '.yaml', path)
configs = {}
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
return configs
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
'''
Args:
infos (dict or None): any info you want to save.
'''
logging.info('Checkpoint: save to checkpoint %s' % path)
if isinstance(model, torch.nn.DataParallel):
state_dict = model.module.state_dict()
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, path)
info_path = re.sub('.pt$', '.yaml', path)
if infos is None:
infos = {}
with open(info_path, 'w') as fout:
data = yaml.dump(infos)
fout.write(data)

44
kws/utils/cmvn.py Normal file
View File

@ -0,0 +1,44 @@
#!/usr/bin/env python3
# Copyright (c) 2020 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import numpy as np
def load_cmvn(json_cmvn_file):
""" Load the json format cmvn stats file and calculate cmvn
Args:
json_cmvn_file: cmvn stats file in json format
Returns:
a numpy array of [means, vars]
"""
with open(json_cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stat']
variance = cmvn_stats['var_stat']
count = cmvn_stats['frame_num']
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn

87
kws/utils/executor.py Normal file
View File

@ -0,0 +1,87 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
from torch.nn.utils import clip_grad_norm_
from kws.model.loss import max_polling_loss
class Executor:
def __init__(self):
self.step = 0
def train(self, model, optimizer, data_loader, device, writer, args):
''' Train one epoch
'''
model.train()
clip = args.get('grad_clip', 50.0)
log_interval = args.get('log_interval', 10)
epoch = args.get('epoch', 0)
min_duration = args.get('min_duration', 0)
num_total_batch = 0
total_loss = 0.0
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
logits = model(feats)
loss, acc = max_polling_loss(logits, target, feats_lengths,
min_duration)
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm):
optimizer.step()
if batch_idx % log_interval == 0:
logging.debug(
'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format(
epoch, batch_idx, loss.item(), acc))
def cv(self, model, data_loader, device, args):
''' Cross validation on
'''
model.eval()
log_interval = args.get('log_interval', 10)
epoch = args.get('epoch', 0)
# in order to avoid division by 0
num_seen_utts = 1
total_loss = 0.0
with torch.no_grad():
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
num_seen_utts += num_utts
logits = model(feats)
loss, acc = max_polling_loss(logits, target, feats_lengths)
if torch.isfinite(loss):
num_seen_utts += num_utts
total_loss += loss.item() * num_utts
if batch_idx % log_interval == 0:
logging.debug(
'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}'
.format(epoch, batch_idx, loss.item(), acc,
total_loss / num_seen_utts))
return total_loss / num_seen_utts

31
kws/utils/file_utils.py Normal file
View File

@ -0,0 +1,31 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def read_lists(list_file):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
lists.append(line.strip())
return lists
def read_symbol_table(symbol_table_file):
symbol_table = {}
with open(symbol_table_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
symbol_table[arr[0]] = int(arr[1])
return symbol_table

32
kws/utils/mask.py Normal file
View File

@ -0,0 +1,32 @@
# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
"""
Examples:
>>> lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
>>> mask = padding_mask(lengths)
>>> print(mask)
tensor([[False, False, True],
[False, False, True],
[False, False, False]])
"""
batch_size = lengths.size(0)
max_len = int(lengths.max().item())
seq = torch.arange(max_len, dtype=torch.int64, device=lengths.device)
seq = seq.expand(batch_size, max_len)
return seq >= lengths.unsqueeze(1)