diff --git a/kws/bin/average_model.py b/kws/bin/average_model.py new file mode 100644 index 0000000..8a9a477 --- /dev/null +++ b/kws/bin/average_model.py @@ -0,0 +1,89 @@ +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# Author: di.wu@mobvoi.com (DI WU) +import os +import argparse +import glob + +import yaml +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser(description='average model') + parser.add_argument('--dst_model', required=True, help='averaged model') + parser.add_argument('--src_path', + required=True, + help='src model path for average') + parser.add_argument('--val_best', + action="store_true", + help='averaged model') + parser.add_argument('--num', + default=5, + type=int, + help='nums for averaged model') + parser.add_argument('--min_epoch', + default=0, + type=int, + help='min epoch used for averaging model') + parser.add_argument( + '--max_epoch', + default=65536, # Big enough + type=int, + help='max epoch used for averaging model') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + checkpoints = [] + val_scores = [] + if args.val_best: + yamls = glob.glob('{}/[!config]*.yaml'.format(args.src_path)) + for y in yamls: + with open(y, 'r') as f: + dic_yaml = yaml.load(f, Loader=yaml.FullLoader) + print(y, dic_yaml) + loss = dic_yaml['cv_loss'] + epoch = dic_yaml['epoch'] + if epoch >= args.min_epoch and epoch <= args.max_epoch: + val_scores += [[epoch, loss]] + val_scores = np.array(val_scores) + sort_idx = np.argsort(val_scores[:, -1]) + sorted_val_scores = val_scores[sort_idx][::1] + print("best val scores = " + str(sorted_val_scores[:args.num, 1])) + print("selected epochs = " + + str(sorted_val_scores[:args.num, 0].astype(np.int64))) + path_list = [ + args.src_path + '/{}.pt'.format(int(epoch)) + for epoch in sorted_val_scores[:args.num, 0] + ] + else: + path_list = glob.glob('{}/[!avg][!final]*.pt'.format(args.src_path)) + path_list = sorted(path_list, key=os.path.getmtime) + path_list = path_list[-args.num:] + print(path_list) + avg = None + num = args.num + assert num == len(path_list) + for path in path_list: + print('Processing {}'.format(path)) + states = torch.load(path, map_location=torch.device('cpu')) + if avg is None: + avg = states + else: + for k in avg.keys(): + avg[k] += states[k] + # average + for k in avg.keys(): + if avg[k] is not None: + # pytorch 1.6 use true_divide instead of /= + avg[k] = torch.true_divide(avg[k], num) + print('Saving to {}'.format(args.dst_model)) + torch.save(avg, args.dst_model) + + +if __name__ == '__main__': + main() diff --git a/kws/bin/compute_det.py b/kws/bin/compute_det.py new file mode 100644 index 0000000..53e3079 --- /dev/null +++ b/kws/bin/compute_det.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
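For reference, the averaging performed by kws/bin/average_model.py above reduces to an element-wise mean over the selected checkpoints' state dicts. A minimal sketch, assuming each .pt file stores a plain state_dict of tensors; the paths are hypothetical:

import torch

paths = ['exp/10.pt', 'exp/11.pt']    # hypothetical checkpoint paths
avg = None
for path in paths:
    states = torch.load(path, map_location='cpu')
    if avg is None:
        avg = {k: v.clone() for k, v in states.items()}
    else:
        for k in avg:
            avg[k] += states[k]
for k in avg:
    # true_divide keeps integer buffers (e.g. BatchNorm's num_batches_tracked)
    # from tripping the in-place division check in newer PyTorch versions
    avg[k] = torch.true_divide(avg[k], len(paths))
torch.save(avg, 'exp/avg_2.pt')       # hypothetical output path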
+ +import argparse +import json + + +def load_label_and_score(keyword, label_file, score_file): + score_table = {} + with open(score_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + key = arr[0] + score = float(arr[keyword + 1]) + score_table[key] = score + keyword_table = {} + filler_table = {} + filler_duration = 0.0 + with open(label_file, 'r', encoding='utf8') as fin: + for line in fin: + obj = json.loads(line.strip()) + assert 'key' in obj + assert 'txt' in obj + assert 'duration' in obj + key = obj['key'] + index = obj['txt'] + duration = obj['duration'] + assert key in score_table + if index == keyword: + keyword_table[key] = score_table[key] + else: + filler_table[key] = score_table[key] + filler_duration += duration + return keyword_table, filler_table, filler_duration + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='compute det curve') + parser.add_argument('--test_data', required=True, help='label file') + parser.add_argument('--keyword', type=int, default=0, help='score file') + parser.add_argument('--score_file', required=True, help='score file') + parser.add_argument('--step', type=float, default=0.01, help='score file') + parser.add_argument('--stats_file', + required=True, + help='false reject/alarm stats file') + args = parser.parse_args() + + keyword_table, filler_table, filler_duration = load_label_and_score( + args.keyword, args.test_data, args.score_file) + print('Filler total duration Hours: {}'.format(filler_duration / 3600.0)) + + with open(args.stats_file, 'w', encoding='utf8') as fout: + threshold = 0.0 + while threshold <= 1.0: + num_false_reject = 0 + for key, score in keyword_table.items(): + if score < threshold: + num_false_reject += 1 + num_false_alarm = 0 + for key, score in filler_table.items(): + if score >= threshold: + num_false_alarm += 1 + false_reject_rate = num_false_reject / len(keyword_table) + num_false_alarm = max(num_false_alarm, 1e-6) + false_alarm_per_hour = num_false_alarm / (filler_duration / 3600.0) + fout.write('{:.6f} {:.6f} {:.6f}\n'.format(threshold, + false_alarm_per_hour, + false_reject_rate)) + threshold += args.step diff --git a/kws/bin/export_jit.py b/kws/bin/export_jit.py new file mode 100644 index 0000000..7c1846e --- /dev/null +++ b/kws/bin/export_jit.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
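For reference, each line that kws/bin/compute_det.py above writes contains a threshold, the false alarms per hour of filler audio, and the false reject rate over keyword utterances. A tiny numeric check with made-up scores (all names and values are illustrative):

keyword_scores = {'kw_utt1': 0.9, 'kw_utt2': 0.4}   # hypothetical keyword utts
filler_scores = {'fil_utt1': 0.2, 'fil_utt2': 0.7}  # hypothetical filler utts
filler_hours = 0.5
threshold = 0.5
frr = sum(s < threshold for s in keyword_scores.values()) / len(keyword_scores)
fa_per_hour = sum(s >= threshold for s in filler_scores.values()) / filler_hours
print(frr, fa_per_hour)   # 0.5 2.0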
+ +from __future__ import print_function + +import argparse +import os + +import torch +import yaml + +from kws.model.kws_model import init_model +from kws.utils.checkpoint import load_checkpoint + + +def get_args(): + parser = argparse.ArgumentParser(description='export your script model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--output_file', required=True, help='output file') + parser.add_argument('--output_quant_file', + default=None, + help='output quantized model file') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + # No need gpu for model export + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + model = init_model(configs['model']) + print(model) + + load_checkpoint(model, args.checkpoint) + # Export jit torch script model + + script_model = torch.jit.script(model) + script_model.save(args.output_file) + print('Export model successfully, see {}'.format(args.output_file)) + + # Export quantized jit torch script model + if args.output_quant_file: + quantized_model = torch.quantization.quantize_dynamic( + model, {torch.nn.Linear}, dtype=torch.qint8 + ) + print(quantized_model) + script_quant_model = torch.jit.script(quantized_model) + script_quant_model.save(args.output_quant_file) + print('Export quantized model successfully, ' + 'see {}'.format(args.output_quant_file)) + + +if __name__ == '__main__': + main() diff --git a/kws/bin/score.py b/kws/bin/score.py new file mode 100644 index 0000000..0e8c18f --- /dev/null +++ b/kws/bin/score.py @@ -0,0 +1,104 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
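The TorchScript file written by kws/bin/export_jit.py above can be loaded without the Python class definitions, which is handy for a quick sanity check. A minimal sketch, assuming an already exported model; the path and feature shape are illustrative:

import torch

model = torch.jit.load('exp/kws.zip', map_location='cpu')  # hypothetical path
model.eval()
feats = torch.randn(1, 100, 23)   # (batch, frames, num_mel_bins); illustrative
with torch.no_grad():
    posterior = model(feats)      # per-frame keyword posteriors after sigmoid
print(posterior.shape)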
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader + +from kws.dataset.dataset import Dataset +from kws.model.kws_model import init_model +from kws.utils.checkpoint import load_checkpoint +from kws.utils.mask import padding_mask + + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--batch_size', + default=16, + type=int, + help='batch size for inference') + parser.add_argument('--score_file', + required=True, + help='output score file') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + test_conf = copy.deepcopy(configs['dataset_conf']) + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['shuffle'] = False + test_conf['fbank_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_size'] = args.batch_size + + test_dataset = Dataset(args.test_data, test_conf) + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + # Init asr model from configs + model = init_model(configs['model']) + + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + model.eval() + with torch.no_grad(), open(args.score_file, 'w', encoding='utf8') as fout: + for batch_idx, batch in enumerate(test_data_loader): + keys, feats, target, lengths = batch + feats = feats.to(device) + lengths = lengths.to(device) + mask = padding_mask(lengths).unsqueeze(2) + logits = model(feats) + logits = logits.masked_fill(mask, 0.0) + max_logits, _ = logits.max(dim=1) + max_logits = max_logits.cpu() + for i in range(len(keys)): + key = keys[i] + score = max_logits[i] + score = ' '.join([str(x) for x in score.tolist()]) + fout.write('{} {}\n'.format(key, score)) + if batch_idx % 10 == 0: + print('Progress batch {}'.format(batch_idx)) + sys.stdout.flush() + + +if __name__ == '__main__': + main() diff --git a/kws/bin/train.py b/kws/bin/train.py new file mode 100644 index 0000000..150d333 --- /dev/null +++ b/kws/bin/train.py @@ -0,0 +1,242 @@ +# Copyright (c) 2020 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
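The per-utterance score written by kws/bin/score.py above is the maximum sigmoid posterior over the valid frames of each keyword, with padded frames masked to zero first. A small check with fabricated posteriors:

import torch

from kws.utils.mask import padding_mask

logits = torch.tensor([[[0.1], [0.8], [0.3]],
                       [[0.6], [0.2], [0.9]]])   # (B=2, T=3, 1 keyword)
lengths = torch.tensor([3, 2])                   # frame 3 of utt 2 is padding
mask = padding_mask(lengths).unsqueeze(2)
scores, _ = logits.masked_fill(mask, 0.0).max(dim=1)
print(scores)   # tensor([[0.8000], [0.6000]]); the 0.9 is ignored as padding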
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os + +import torch +import torch.distributed as dist +import torch.optim as optim +import yaml +from tensorboardX import SummaryWriter +from torch.utils.data import DataLoader + +from kws.dataset.dataset import Dataset +from kws.utils.checkpoint import load_checkpoint, save_checkpoint +from kws.model.kws_model import init_model +from kws.utils.executor import Executor + + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this local rank, -1 for cpu') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.rank', + dest='rank', + default=0, + type=int, + help='global rank for distributed training') + parser.add_argument('--ddp.world_size', + dest='world_size', + default=-1, + type=int, + help='''number of total processes/gpus for + distributed training''') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--ddp.init_method', + dest='init_method', + default=None, + help='ddp init method') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--cmvn_file', default=None, help='global cmvn file') + parser.add_argument('--norm_var', + action='store_true', + default=False, + help='norm var option') + parser.add_argument('--num_keywords', + default=1, + type=int, + help='number of keywords') + parser.add_argument('--min_duration', + default=50, + type=int, + help='min duration frames of the keyword') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + + args = parser.parse_args() + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + # Set random seed + torch.manual_seed(777) + print(args) + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + distributed = args.world_size > 1 + if distributed: + logging.info('training on multiple gpus, this gpu {}'.format(args.gpu)) + dist.init_process_group(args.dist_backend, + init_method=args.init_method, + world_size=args.world_size, + rank=args.rank) + + train_conf = configs['dataset_conf'] + cv_conf = copy.deepcopy(train_conf) + cv_conf['speed_perturb'] = False + cv_conf['spec_aug'] = False + cv_conf['shuffle'] = False + + train_dataset = Dataset(args.train_data, train_conf) + cv_dataset = Dataset(args.cv_data, cv_conf) + + train_data_loader = DataLoader(train_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + cv_data_loader = DataLoader(cv_dataset, + batch_size=None, + pin_memory=args.pin_memory, + 
num_workers=args.num_workers, + prefetch_factor=args.prefetch) + + input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] + output_dim = args.num_keywords + + # Write model_dir/config.yaml for inference and export + configs['model']['input_dim'] = input_dim + configs['model']['output_dim'] = output_dim + if args.cmvn_file is not None: + configs['model']['cmvn'] = {} + configs['model']['cmvn']['norm_var'] = args.norm_var + configs['model']['cmvn']['cmvn_file'] = args.cmvn_file + if args.rank == 0: + saved_config_path = os.path.join(args.model_dir, 'config.yaml') + with open(saved_config_path, 'w') as fout: + data = yaml.dump(configs) + fout.write(data) + + # Init asr model from configs + model = init_model(configs['model']) + print(model) + num_params = sum(p.numel() for p in model.parameters()) + print('the number of model params: {}'.format(num_params)) + + # !!!IMPORTANT!!! + # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + if args.rank == 0: + script_model = torch.jit.script(model) + script_model.save(os.path.join(args.model_dir, 'init.zip')) + executor = Executor() + # If specify checkpoint, load some info from checkpoint + if args.checkpoint is not None: + infos = load_checkpoint(model, args.checkpoint) + else: + infos = {} + start_epoch = infos.get('epoch', -1) + 1 + cv_loss = infos.get('cv_loss', 0.0) + + model_dir = args.model_dir + writer = None + if args.rank == 0: + os.makedirs(model_dir, exist_ok=True) + exp_id = os.path.basename(model_dir) + writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id)) + + if distributed: + assert (torch.cuda.is_available()) + # cuda model is required for nn.parallel.DistributedDataParallel + model.cuda() + model = torch.nn.parallel.DistributedDataParallel( + model, find_unused_parameters=True) + device = torch.device("cuda") + else: + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='min', + factor=0.5, + patience=3, + min_lr=1e-6, + threshold=0.01, + ) + + training_config = configs['training_config'] + training_config['min_duration'] = args.min_duration + num_epochs = training_config.get('max_epoch', 100) + final_epoch = None + if start_epoch == 0 and args.rank == 0: + save_model_path = os.path.join(model_dir, 'init.pt') + save_checkpoint(model, save_model_path) + + # Start training loop + for epoch in range(start_epoch, num_epochs): + train_dataset.set_epoch(epoch) + training_config['epoch'] = epoch + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) + executor.train(model, optimizer, train_data_loader, device, writer, + training_config) + cv_loss = executor.cv(model, cv_data_loader, device, training_config) + logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss)) + + if args.rank == 0: + save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) + save_checkpoint(model, save_model_path, { + 'epoch': epoch, + 'lr': lr, + 'cv_loss': cv_loss, + }) + writer.add_scalar('epoch/cv_loss', cv_loss, epoch) + writer.add_scalar('epoch/lr', lr, epoch) + final_epoch = epoch + scheduler.step(cv_loss) + + if final_epoch is not None and args.rank == 0: + final_model_path = os.path.join(model_dir, 'final.pt') + os.symlink('{}.pt'.format(final_epoch), final_model_path) + 
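+        # the symlink target '{epoch}.pt' is relative to model_dir, so
+        # final.pt keeps resolving even if the experiment directory is moved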
writer.close() + + +if __name__ == '__main__': + main() diff --git a/kws/dataset/dataset.py b/kws/dataset/dataset.py new file mode 100644 index 0000000..ffff668 --- /dev/null +++ b/kws/dataset/dataset.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset + +import kws.dataset.processor as processor +from kws.utils.file_utils import read_lists + + +class Processor(IterableDataset): + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. + """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = data.copy() + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + data = data[self.rank::self.world_size] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + lists = self.sampler.sample(self.lists) + for src in lists: + # yield dict(src=src) + data = dict(src=src) + data.update(sampler_info) + yield data + + +def Dataset(data_list_file, conf, partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
+ + Args: + data_type(str): raw/shard + partition(bool): whether to do data partition in terms of rank + """ + lists = read_lists(data_list_file) + shuffle = conf.get('shuffle', True) + dataset = DataList(lists, shuffle=shuffle, partition=partition) + dataset = Processor(dataset, processor.parse_raw) + filter_conf = conf.get('filter_conf', {}) + dataset = Processor(dataset, processor.filter, **filter_conf) + + resample_conf = conf.get('resample_conf', {}) + dataset = Processor(dataset, processor.resample, **resample_conf) + + speed_perturb = conf.get('speed_perturb', False) + if speed_perturb: + dataset = Processor(dataset, processor.speed_perturb) + + fbank_conf = conf.get('fbank_conf', {}) + dataset = Processor(dataset, processor.compute_fbank, **fbank_conf) + + spec_aug = conf.get('spec_aug', True) + if spec_aug: + spec_aug_conf = conf.get('spec_aug_conf', {}) + dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) + + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = Processor(dataset, processor.shuffle, **shuffle_conf) + + batch_conf = conf.get('batch_conf', {}) + dataset = Processor(dataset, processor.batch, **batch_conf) + dataset = Processor(dataset, processor.padding) + return dataset + + +if __name__ == '__main__': + import sys + dataset = Dataset(sys.argv[1], {}) + for data in dataset: + print(data) diff --git a/kws/dataset/processor.py b/kws/dataset/processor.py new file mode 100644 index 0000000..65aba99 --- /dev/null +++ b/kws/dataset/processor.py @@ -0,0 +1,266 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import json +import random + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +from torch.nn.utils.rnn import pad_sequence + + +def parse_raw(data): + """ Parse key/wav/txt from json line + + Args: + data: Iterable[str], str is a json line has key/wav/txt + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'src' in sample + json_line = sample['src'] + obj = json.loads(json_line) + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + key = obj['key'] + wav_file = obj['wav'] + txt = obj['txt'] + try: + waveform, sample_rate = torchaudio.load(wav_file) + example = dict(key=key, + label=txt, + wav=waveform, + sample_rate=sample_rate) + yield example + except Exception as ex: + logging.warning('Failed to read {}'.format(wav_file)) + + +def filter(data, max_length=10240, min_length=10): + """ Filter sample according to feature and label length + Inplace operation. 
+ + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + yield sample + + +def resample(data, resample_rate=16000): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + yield sample + + +def speed_perturb(data, speeds=None): + """ Apply speed perturb to the data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + speeds(List[float]): optional speed + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + if speeds is None: + speeds = [0.9, 1.0, 1.1] + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + speed = random.choice(speeds) + if speed != 1.0: + wav, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform, sample_rate, + [['speed', str(speed)], ['rate', str(sample_rate)]]) + sample['wav'] = wav + + yield sample + + +def compute_fbank(data, + num_mel_bins=23, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.fbank(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sample_frequency=sample_rate) + yield dict(key=sample['key'], label=sample['label'], feat=mat) + + +def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10): + """ Do spec augmentation + Inplace operation + + Args: + data: Iterable[{key, feat, label}] + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + max_freq = y.size(1) + # time mask + for i in range(num_t_mask): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + y[start:end, :] = 0 + # freq mask + for i in range(num_f_mask): + start = random.randint(0, max_freq - 1) + length = random.randint(1, max_f) + end = min(max_freq, start + length) + y[:, start:end] = 0 + sample['feat'] = y + yield sample + + +def shuffle(data, 
shuffle_size=1000): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def padding(data): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + feats_length = torch.tensor([x['feat'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor( + [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['key'] for i in order] + sorted_labels = torch.tensor([sample[i]['label'] for i in order], + dtype=torch.int64) + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + yield (sorted_keys, padded_feats, sorted_labels, feats_lengths) diff --git a/kws/model/cmvn.py b/kws/model/cmvn.py new file mode 100644 index 0000000..2e211aa --- /dev/null +++ b/kws/model/cmvn.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2020 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class GlobalCMVN(torch.nn.Module): + def __init__(self, + mean: torch.Tensor, + istd: torch.Tensor, + norm_var: bool = True): + """ + Args: + mean (torch.Tensor): mean stats + istd (torch.Tensor): inverse std, std which is 1.0 / std + """ + super().__init__() + assert mean.shape == istd.shape + self.norm_var = norm_var + # The buffer can be accessed from this module using self.mean + self.register_buffer("mean", mean) + self.register_buffer("istd", istd) + + def forward(self, x: torch.Tensor): + """ + Args: + x (torch.Tensor): (batch, max_len, feat_dim) + + Returns: + (torch.Tensor): normalized feature + """ + x = x - self.mean + if self.norm_var: + x = x * self.istd + return x diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py new file mode 100644 index 0000000..314822c --- /dev/null +++ b/kws/model/kws_model.py @@ -0,0 +1,101 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Optional + +import torch + +from kws.model.cmvn import GlobalCMVN +from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1 +from kws.model.tcn import TCN, CnnBlock, DsCnnBlock +from kws.utils.cmvn import load_cmvn + + +class KwsModel(torch.nn.Module): + """ Our model consists of four parts: + 1. global_cmvn: Optional, (idim, idim) + 2. subsampling: subsampling the input, (idim, hdim) + 3. body: body of the whole network, (hdim, hdim) + 4. linear: a linear layer, (hdim, odim) + """ + def __init__(self, idim: int, odim: int, hdim: int, + global_cmvn: Optional[torch.nn.Module], + subsampling: torch.nn.Module, body: torch.nn.Module): + super().__init__() + self.idim = idim + self.odim = odim + self.hdim = hdim + self.global_cmvn = global_cmvn + self.subsampling = subsampling + self.body = body + self.linear = torch.nn.Linear(hdim, odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.global_cmvn is not None: + x = self.global_cmvn(x) + x = self.subsampling(x) + x, _ = self.body(x) + x = self.linear(x) + x = torch.sigmoid(x) + return x + + +def init_model(configs): + cmvn = configs.get('cmvn', {}) + if cmvn['cmvn_file'] is not None: + mean, istd = load_cmvn(cmvn['cmvn_file']) + global_cmvn = GlobalCMVN( + torch.from_numpy(mean).float(), + torch.from_numpy(istd).float(), cmvn['norm_var']) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + output_dim = configs['output_dim'] + hidden_dim = configs['hidden_dim'] + + subsampling_type = configs['subsampling']['type'] + if subsampling_type == 'linear': + subsampling = LinearSubsampling1(input_dim, hidden_dim) + elif subsampling_type == 'cnn1d_s1': + subsampling = Conv1dSubsampling1(input_dim, hidden_dim) + else: + print('Unknown subsampling type {}'.format(subsampling_type)) + sys.exit(1) + + body_type = configs['body']['type'] + num_layers = configs['body']['num_layers'] + if body_type == 'gru': + body = torch.nn.GRU(hidden_dim, + hidden_dim, + num_layers=num_layers, + batch_first=True) + elif body_type == 'tcn': + # Depthwise Separable + ds = configs['body'].get('ds', False) + if ds: + block_class = DsCnnBlock + else: + block_class = CnnBlock + kernel_size = configs['body'].get('kernel_size', 8) + dropout = configs['body'].get('drouput', 0.1) + body = TCN(num_layers, hidden_dim, kernel_size, dropout, block_class) + else: + print('Unknown body type {}'.format(body_type)) + sys.exit(1) + + kws_model = KwsModel(input_dim, output_dim, hidden_dim, global_cmvn, + subsampling, body) + return kws_model diff --git a/kws/model/loss.py b/kws/model/loss.py new file mode 100644 index 0000000..e871ef1 --- /dev/null +++ b/kws/model/loss.py @@ -0,0 +1,83 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn.functional as F + +from kws.utils.mask import padding_mask + + +def max_polling_loss(logits: torch.Tensor, + target: torch.Tensor, + lengths: torch.Tensor, + min_duration: int = 0): + """ Max-pooling loss + For keyword, select the frame with the highest posterior. + The keyword is triggered when any of the frames is triggered. + For none keyword, select the hardest frame, namely the frame + with lowest filler posterior(highest keyword posterior). + the keyword is not triggered when all frames are not triggered. + + Attributes: + logits: (B, T, D), D is the number of keywords + target: (B) + lengths: (B) + min_duration: min duration of the keyword + Returns: + (float): loss of current batch + (float): accuracy of current batch + """ + mask = padding_mask(lengths) + num_utts = logits.size(0) + num_keywords = logits.size(2) + + target = target.cpu() + loss = 0.0 + for i in range(num_utts): + for j in range(num_keywords): + # Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p)) + if target[i] == j: + # For the keyword, do max-polling + prob = logits[i, :, j] + m = mask[i].clone().detach() + m[:min_duration] = True + prob = prob.masked_fill(m, 0.0) + prob = torch.clamp(prob, 1e-8, 1.0) + max_prob = prob.max() + loss += -torch.log(max_prob) + else: + # For other keywords or filler, do min-polling + prob = 1 - logits[i, :, j] + prob = prob.masked_fill(mask[i], 1.0) + prob = torch.clamp(prob, 1e-8, 1.0) + min_prob = prob.min() + loss += -torch.log(min_prob) + loss = loss / num_utts + + # Compute accuracy of current batch + mask = mask.unsqueeze(-1) + logits = logits.masked_fill(mask, 0.0) + max_logits, index = logits.max(1) + num_correct = 0 + for i in range(num_utts): + max_p, idx = max_logits[i].max(0) + # Predict correct as the i'th keyword + if max_p > 0.5 and idx == target[i]: + num_correct += 1 + # Predict correct as the filler, filler id < 0 + if max_p < 0.5 and target[i] < 0: + num_correct += 1 + acc = num_correct / num_utts + # acc = 0.0 + return loss, acc diff --git a/kws/model/subsampling.py b/kws/model/subsampling.py new file mode 100644 index 0000000..044ef24 --- /dev/null +++ b/kws/model/subsampling.py @@ -0,0 +1,58 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +# There is no right context or lookahead in our Subsampling design, so +# If there is CNN in Subsampling, it's a causal CNN. 
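A small numeric check of the max-pooling objective defined in kws/model/loss.py above: for the labelled keyword the highest valid frame posterior is pushed towards 1, while for every other output the highest frame posterior is pushed towards 0. The values below are illustrative:

import torch

from kws.model.loss import max_polling_loss

logits = torch.tensor([[[0.1], [0.9], [0.2]]])   # (B=1, T=3, 1 keyword)
target = torch.tensor([0])                       # utterance is keyword 0
lengths = torch.tensor([3])
loss, acc = max_polling_loss(logits, target, lengths)
print(loss.item(), acc)   # approx -log(0.9) = 0.105, acc = 1.0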
+ + +class SubsamplingBase(torch.nn.Module): + def __init__(self): + super().__init__() + self.subsampling_rate = 1 + + +class LinearSubsampling1(SubsamplingBase): + """Linear transform the input without subsampling + """ + def __init__(self, idim: int, odim: int): + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + # torch.nn.BatchNorm1d(odim), + torch.nn.ReLU(), + ) + self.subsampling_rate = 1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.out(x) + return x + + +class Conv1dSubsampling1(SubsamplingBase): + """Conv1d transform without subsampling + """ + def __init__(self, idim: int, odim: int): + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Conv1d(idim, odim, 3), + torch.nn.BatchNorm1d(odim), + torch.nn.ReLU(), + ) + self.subsampling_rate = 1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.out(x) + return x diff --git a/kws/model/tcn.py b/kws/model/tcn.py new file mode 100644 index 0000000..7dffa5b --- /dev/null +++ b/kws/model/tcn.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CnnBlock(nn.Module): + def __init__(self, + channel: int, + kernel_size: int, + dilation: int, + dropout: float = 0.1): + super().__init__() + # The CNN used here is causal convolution + self.padding = (kernel_size - 1) * dilation + self.cnn = nn.Conv1d(channel, + channel, + kernel_size, + stride=1, + dilation=dilation) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + """ + Args: + x(torch.Tensor): Input tensor (B, D, T) + Returns: + torch.Tensor(B, D, T) + """ + if cache is None: + y = F.pad(x, (self.padding, 0), value=0.0) + else: + y = torch.cat((cache, x), dim=2) + assert y.size(2) > self.padding + new_cache = y[:, :, -self.padding:] + + y = self.cnn(y) + y = F.relu(y) + y = self.dropout(y) + y = y + x # residual connection + return y, new_cache + + +class DsCnnBlock(nn.Module): + """ Depthwise Separable Convolution + """ + def __init__(self, + channel: int, + kernel_size: int, + dilation: int, + dropout: float = 0.1): + super().__init__() + # The CNN used here is causal convolution + self.padding = (kernel_size - 1) * dilation + self.depthwise_cnn = nn.Conv1d(channel, + channel, + kernel_size, + stride=1, + dilation=dilation, + groups=channel) + self.pointwise_cnn = nn.Conv1d(channel, + channel, + kernel_size=1, + stride=1) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + """ + Args: + x(torch.Tensor): Input tensor (B, D, T) + Returns: + torch.Tensor(B, D, T) + """ + if cache is None: + y = F.pad(x, (self.padding, 0), value=0.0) + else: + y = torch.cat((cache, x), dim=2) + assert y.size(2) > self.padding + new_cache = y[:, :, -self.padding:] + + y = self.depthwise_cnn(y) + y = 
self.pointwise_cnn(y) + y = F.relu(y) + y = self.dropout(y) + y = y + x # residual connection + return y, new_cache + + +class TCN(nn.Module): + def __init__(self, + num_layers: int, + channel: int, + kernel_size: int, + dropout: float = 0.1, + block_class=CnnBlock): + super().__init__() + layers = [] + self.padding = 0 + self.network = nn.ModuleList() + for i in range(num_layers): + dilation = 2**i + self.padding += (kernel_size - 1) * dilation + self.network.append( + block_class(channel, kernel_size, dilation, dropout)) + + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + """ + Args: + x (torch.Tensor): Input tensor (B, T, D) + + Returns: + torch.Tensor(B, T, D) + torch.Tensor(B, D, C): C is the accumulated cache size + """ + x = x.transpose(1, 2) # (B, D, T) + out_caches = [] + for block in self.network: + x, c = block(x) + out_caches.append(c) + x = x.transpose(1, 2) # (B, T, D) + new_cache = torch.cat(out_caches, dim=2) + return x, new_cache + + +if __name__ == '__main__': + tcn = TCN(4, 64, 8, block_class=CnnBlock) + print(tcn) + print(tcn.padding) + num_params = sum(p.numel() for p in tcn.parameters()) + print('the number of model params: {}'.format(num_params)) + x = torch.zeros(3, 15, 64) + y = tcn(x) diff --git a/kws/utils/checkpoint.py b/kws/utils/checkpoint.py new file mode 100644 index 0000000..4bbd3e2 --- /dev/null +++ b/kws/utils/checkpoint.py @@ -0,0 +1,58 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +import re + +import yaml +import torch + + +def load_checkpoint(model: torch.nn.Module, path: str) -> dict: + if torch.cuda.is_available(): + logging.info('Checkpoint: loading from checkpoint %s for GPU' % path) + checkpoint = torch.load(path) + else: + logging.info('Checkpoint: loading from checkpoint %s for CPU' % path) + checkpoint = torch.load(path, map_location='cpu') + model.load_state_dict(checkpoint) + info_path = re.sub('.pt$', '.yaml', path) + configs = {} + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + return configs + + +def save_checkpoint(model: torch.nn.Module, path: str, infos=None): + ''' + Args: + infos (dict or None): any info you want to save. 
+ ''' + logging.info('Checkpoint: save to checkpoint %s' % path) + if isinstance(model, torch.nn.DataParallel): + state_dict = model.module.state_dict() + elif isinstance(model, torch.nn.parallel.DistributedDataParallel): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save(state_dict, path) + info_path = re.sub('.pt$', '.yaml', path) + if infos is None: + infos = {} + with open(info_path, 'w') as fout: + data = yaml.dump(infos) + fout.write(data) diff --git a/kws/utils/cmvn.py b/kws/utils/cmvn.py new file mode 100644 index 0000000..1d2ecbd --- /dev/null +++ b/kws/utils/cmvn.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# Copyright (c) 2020 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math + +import numpy as np + + +def load_cmvn(json_cmvn_file): + """ Load the json format cmvn stats file and calculate cmvn + + Args: + json_cmvn_file: cmvn stats file in json format + + Returns: + a numpy array of [means, vars] + """ + with open(json_cmvn_file) as f: + cmvn_stats = json.load(f) + + means = cmvn_stats['mean_stat'] + variance = cmvn_stats['var_stat'] + count = cmvn_stats['frame_num'] + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn = np.array([means, variance]) + return cmvn diff --git a/kws/utils/executor.py b/kws/utils/executor.py new file mode 100644 index 0000000..2f8fe47 --- /dev/null +++ b/kws/utils/executor.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 Binbin Zhang +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
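For reference, load_cmvn in kws/utils/cmvn.py above turns accumulated statistics into a per-dimension mean and inverse standard deviation, mean = mean_stat / frame_num and istd = 1 / sqrt(var_stat / frame_num - mean**2), which GlobalCMVN then applies as (x - mean) * istd. A one-dimensional check with made-up statistics:

import math

mean_stat, var_stat, frame_num = 10.0, 60.0, 10   # illustrative stats
mean = mean_stat / frame_num                      # 1.0
var = max(var_stat / frame_num - mean * mean, 1.0e-20)
istd = 1.0 / math.sqrt(var)                       # ~0.447
print(mean, istd)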
+ +import logging + +import torch +from torch.nn.utils import clip_grad_norm_ + +from kws.model.loss import max_polling_loss + + +class Executor: + def __init__(self): + self.step = 0 + + def train(self, model, optimizer, data_loader, device, writer, args): + ''' Train one epoch + ''' + model.train() + clip = args.get('grad_clip', 50.0) + log_interval = args.get('log_interval', 10) + epoch = args.get('epoch', 0) + min_duration = args.get('min_duration', 0) + + num_total_batch = 0 + total_loss = 0.0 + for batch_idx, batch in enumerate(data_loader): + key, feats, target, feats_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + num_utts = feats_lengths.size(0) + if num_utts == 0: + continue + logits = model(feats) + loss, acc = max_polling_loss(logits, target, feats_lengths, + min_duration) + loss.backward() + grad_norm = clip_grad_norm_(model.parameters(), clip) + if torch.isfinite(grad_norm): + optimizer.step() + if batch_idx % log_interval == 0: + logging.debug( + 'TRAIN Batch {}/{} loss {:.8f} acc {:.8f}'.format( + epoch, batch_idx, loss.item(), acc)) + + def cv(self, model, data_loader, device, args): + ''' Cross validation on + ''' + model.eval() + log_interval = args.get('log_interval', 10) + epoch = args.get('epoch', 0) + # in order to avoid division by 0 + num_seen_utts = 1 + total_loss = 0.0 + with torch.no_grad(): + for batch_idx, batch in enumerate(data_loader): + key, feats, target, feats_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + num_utts = feats_lengths.size(0) + if num_utts == 0: + continue + num_seen_utts += num_utts + logits = model(feats) + loss, acc = max_polling_loss(logits, target, feats_lengths) + if torch.isfinite(loss): + num_seen_utts += num_utts + total_loss += loss.item() * num_utts + if batch_idx % log_interval == 0: + logging.debug( + 'CV Batch {}/{} loss {:.8f} acc {:.8f} history loss {:.8f}' + .format(epoch, batch_idx, loss.item(), acc, + total_loss / num_seen_utts)) + return total_loss / num_seen_utts diff --git a/kws/utils/file_utils.py b/kws/utils/file_utils.py new file mode 100644 index 0000000..06f9d55 --- /dev/null +++ b/kws/utils/file_utils.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
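For context, the lists returned by read_lists below are consumed by kws/dataset/processor.parse_raw, which expects one JSON object per line carrying 'key', 'wav' and 'txt'; kws/bin/compute_det.py additionally reads a 'duration' field from the test list. A hypothetical line could be written like this (field values are illustrative; the loss above treats negative 'txt' labels as filler):

import json

sample = {'key': 'utt_001',               # hypothetical utterance id
          'wav': '/path/to/utt_001.wav',  # hypothetical wav path
          'txt': 0,                       # keyword index; negative for filler
          'duration': 1.4}                # seconds, used by compute_det.py
with open('data.list', 'w', encoding='utf8') as fout:
    fout.write(json.dumps(sample) + '\n')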
+
+def read_lists(list_file):
+    lists = []
+    with open(list_file, 'r', encoding='utf8') as fin:
+        for line in fin:
+            lists.append(line.strip())
+    return lists
+
+
+def read_symbol_table(symbol_table_file):
+    symbol_table = {}
+    with open(symbol_table_file, 'r', encoding='utf8') as fin:
+        for line in fin:
+            arr = line.strip().split()
+            assert len(arr) == 2
+            symbol_table[arr[0]] = int(arr[1])
+    return symbol_table
diff --git a/kws/utils/mask.py b/kws/utils/mask.py
new file mode 100644
index 0000000..78b2df6
--- /dev/null
+++ b/kws/utils/mask.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2021 Binbin Zhang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def padding_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """
+    Examples:
+        >>> lengths = torch.tensor([2, 2, 3], dtype=torch.int32)
+        >>> mask = padding_mask(lengths)
+        >>> print(mask)
+        tensor([[False, False, True],
+                [False, False, True],
+                [False, False, False]])
+    """
+    batch_size = lengths.size(0)
+    max_len = int(lengths.max().item())
+    seq = torch.arange(max_len, dtype=torch.int64, device=lengths.device)
+    seq = seq.expand(batch_size, max_len)
+    return seq >= lengths.unsqueeze(1)
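Closing note on the TCN cache in kws/model/tcn.py: the concatenated cache returned by forward() has width equal to the accumulated causal left context, the sum of (kernel_size - 1) * 2**i over the layers, which is the state a streaming runtime would carry between chunks. Quick check for the 4-layer, kernel-size-8 configuration used in tcn.py's __main__ block:

kernel_size, num_layers = 8, 4
left_context = sum((kernel_size - 1) * 2 ** i for i in range(num_layers))
print(left_context)   # 105 frames, matching TCN(4, 64, 8).padding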