diff --git a/examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml b/examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml new file mode 100644 index 0000000..ba7c716 --- /dev/null +++ b/examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml @@ -0,0 +1,50 @@ +dataset_conf: + filter_conf: + max_length: 2048 + min_length: 0 + resample_conf: + resample_rate: 16000 + speed_perturb: false + feature_extraction_conf: + feature_type: 'fbank' + num_mel_bins: 40 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 1 + num_f_mask: 1 + max_t: 20 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + batch_conf: + batch_size: 256 + +model: + hidden_dim: 256 + preprocessing: + type: linear + backbone: + type: tcn + ds: true + num_layers: 4 + kernel_size: 8 + dropout: 0.1 + activation: + type: identity + + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 0.0001 + +training_config: + grad_clip: 5 + max_epoch: 80 + log_interval: 10 + criterion: ctc + diff --git a/examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml b/examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml new file mode 100644 index 0000000..6eb753f --- /dev/null +++ b/examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml @@ -0,0 +1,50 @@ +dataset_conf: + filter_conf: + max_length: 2048 + min_length: 0 + resample_conf: + resample_rate: 16000 + speed_perturb: false + feature_extraction_conf: + feature_type: 'fbank' + num_mel_bins: 40 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 1 + num_f_mask: 1 + max_t: 20 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + batch_conf: + batch_size: 200 + +model: + hidden_dim: 256 + preprocessing: + type: linear + backbone: + type: tcn + ds: true + num_layers: 4 + kernel_size: 8 + dropout: 0.1 + activation: + type: identity + + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 0.0001 + +training_config: + grad_clip: 5 + max_epoch: 30 + log_interval: 100 + criterion: ctc + diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 9573302..49b96d6 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -111,6 +111,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --score_file $result_dir/score.txt \ --stats_file $result_dir/stats.${keyword}.txt done + + # plot det curve + python wekws/bin/plot_det_curve.py \ + --keywords_dict dict/words.txt \ + --stats_dir $result_dir \ + --figure_file $result_dir/det.png fi diff --git a/examples/hi_xiaowen/s0/run_ctc.sh b/examples/hi_xiaowen/s0/run_ctc.sh new file mode 100644 index 0000000..af848bb --- /dev/null +++ b/examples/hi_xiaowen/s0/run_ctc.sh @@ -0,0 +1,198 @@ +#!/bin/bash +# Copyright 2021 Binbin Zhang(binbzha@qq.com) +# 2023 dujing(thuduj12@163.com) + +. ./path.sh + +stage=$1 +stop_stage=$2 +num_keywords=2599 + +config=conf/ds_tcn_ctc.yaml +norm_mean=true +norm_var=true +gpus="4,5,6,7" + +checkpoint= +dir=exp/ds_tcn_ctc_ft + +num_average=30 +score_checkpoint=$dir/avg_${num_average}.pt + +download_dir=/data/megastore/Datasets/ASR/KWS/nihaoxiaowen # your data dir + +. tools/parse_options.sh || exit 1; +window_shift=50 + +#Whether to train base model. If set true, must put train+dev data in trainbase_dir +trainbase=true +trainbase_dir=data/base +trainbase_config=conf/ds_tcn_ctc_base.yaml +trainbase_exp=exp/base +if $trainbase; then + checkpoint=$trainbase_exp/final.pt +fi + +if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then + echo "Download and extracte all datasets" + local/mobvoi_data_download.sh --dl_dir $download_dir +fi + + +if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then + echo "Preparing datasets..." + mkdir -p dict + echo " -1" > dict/words.txt + echo "Hi_Xiaowen 0" >> dict/words.txt + echo "Nihao_Wenwen 1" >> dict/words.txt + + for folder in train dev test; do + mkdir -p data/$folder + for prefix in p n; do + mkdir -p data/${prefix}_$folder + json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json + local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \ + data/${prefix}_$folder + done + cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp + cat data/p_$folder/text data/n_$folder/text > data/$folder/text + rm -rf data/p_$folder data/n_$folder + done +fi + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then +# Here we Use Paraformer Large(https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) +# to transcribe the negative wavs, and upload the transcription to modelscope. + git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git + for folder in train dev test; do + if [ -f data/$folder/text ];then + mv data/$folder/text data/$folder/text.label + fi + cp mobvoi_kws_transcription/$folder.text data/$folder/text + done + + # and we also copy the tokens and lexicon that used in + # https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary + cp mobvoi_kws_transcription/tokens.txt data/tokens.txt + cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt + +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "Compute CMVN and Format datasets" + tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \ + --in_scp data/train/wav.scp \ + --out_cmvn data/train/global_cmvn + + for x in train dev test; do + tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur + + # Here we use tokens.txt and lexicon.txt to convert txt into index + tools/make_list.py data/$x/wav.scp data/$x/text \ + data/$x/wav.dur data/$x/data.list \ + --token_file data/tokens.txt \ + --lexicon_file data/lexicon.txt + done +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then + for x in train dev ; do + if [ ! -f $trainbase_dir/$x/wav.scp ] || [ ! -f $trainbase_dir/$x/text ]; then + echo "If You Want to Train Base KWS-CTC Model, You Should Prepare ASR Data by Yourself." + echo "The wav.scp and text in KALDI-format is Needed, You Should Put Them in $trainbase_dir/$x" + exit + fi + if [ ! -f $trainbase_dir/$x/wav.dur ]; then + tools/wav_to_duration.sh --nj 8 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur + fi + + # Here we use tokens.txt and lexicon.txt to convert txt into index + if [ ! -f $trainbase_dir/$x/data.list ]; then + tools/make_list.py $trainbase_dir/$x/wav.scp $trainbase_dir/$x/text \ + $trainbase_dir/$x/wav.dur $trainbase_dir/$x/data.list \ + --token_file data/tokens.txt \ + --lexicon_file data/lexicon.txt + fi + done + + echo "Start base training ..." + mkdir -p $trainbase_exp + cmvn_opts= + $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn" + $norm_var && cmvn_opts="$cmvn_opts --norm_var" + num_gpus=$(echo $gpus | awk -F ',' '{print NF}') + torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ + wekws/bin/train.py --gpus $gpus \ + --config $trainbase_config \ + --train_data $trainbase_dir/train/data.list \ + --cv_data $trainbase_dir/dev/data.list \ + --model_dir $trainbase_exp \ + --num_workers 8 \ + --num_keywords $num_keywords \ + --min_duration 50 \ + --seed 666 \ + $cmvn_opts +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "Start training ..." + mkdir -p $dir + cmvn_opts= + $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn" + $norm_var && cmvn_opts="$cmvn_opts --norm_var" + num_gpus=$(echo $gpus | awk -F ',' '{print NF}') + torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \ + wekws/bin/train.py --gpus $gpus \ + --config $config \ + --train_data data/train/data.list \ + --cv_data data/dev/data.list \ + --model_dir $dir \ + --num_workers 8 \ + --num_keywords $num_keywords \ + --min_duration 50 \ + --seed 666 \ + $cmvn_opts \ + ${checkpoint:+--checkpoint $checkpoint} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Do model average, Compute FRR/FAR ..." + python wekws/bin/average_model.py \ + --dst_model $score_checkpoint \ + --src_path $dir \ + --num ${num_average} \ + --val_best + result_dir=$dir/test_$(basename $score_checkpoint) + mkdir -p $result_dir + python wekws/bin/score_ctc.py \ + --config $dir/config.yaml \ + --test_data data/test/data.list \ + --batch_size 256 \ + --checkpoint $score_checkpoint \ + --score_file $result_dir/score.txt \ + --num_workers 8 \ + --keywords 嗨小问,你好问问 \ + --token_file data/tokens.txt \ + --lexicon_file data/lexicon.txt + + python wekws/bin/compute_det_ctc.py \ + --keyword 嗨小问,你好问问 \ + --test_data data/test/data.list \ + --window_shift $window_shift \ + --step 0.001 \ + --score_file $result_dir/score.txt +fi + + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') + onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') + python wekws/bin/export_jit.py \ + --config $dir/config.yaml \ + --checkpoint $score_checkpoint \ + --jit_model $dir/$jit_model + python wekws/bin/export_onnx.py \ + --config $dir/config.yaml \ + --checkpoint $score_checkpoint \ + --onnx_model $dir/$onnx_model +fi diff --git a/tools/make_list.py b/tools/make_list.py index 5825761..60928c2 100755 --- a/tools/make_list.py +++ b/tools/make_list.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# 2023 Jing Du(thuduj12@163.com) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +15,142 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse -import json +import argparse, logging +import json, re + +symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+' + +def split_mixed_label(input_str): + tokens = [] + s = input_str.lower() + while len(s) > 0: + match = re.match(r'[A-Za-z!?,<>()\']+', s) + if match is not None: + word = match.group(0) + else: + word = s[0:1] + tokens.append(word) + s = s.replace(word, '', 1).strip(' ') + return tokens + +def query_token_set(txt, symbol_table, lexicon_table): + tokens_str = tuple() + tokens_idx = tuple() + + parts = split_mixed_label(txt) + for part in parts: + if part == '!sil' or part == '(sil)' or part == '': + tokens_str = tokens_str + ('!sil', ) + elif part == '' or part == '': + tokens_str = tokens_str + ('', ) + elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '': + tokens_str = tokens_str + ('', ) + elif part in symbol_table: + tokens_str = tokens_str + (part, ) + elif part in lexicon_table: + for ch in lexicon_table[part]: + tokens_str = tokens_str + (ch, ) + else: + # case with symbols or meaningless english letter combination + part = re.sub(symbol_str, '', part) + for ch in part: + tokens_str = tokens_str + (ch, ) + + for ch in tokens_str: + if ch in symbol_table: + tokens_idx = tokens_idx + (symbol_table[ch], ) + elif ch == '!sil': + if 'sil' in symbol_table: + tokens_idx = tokens_idx + (symbol_table['sil'], ) + else: + tokens_idx = tokens_idx + (symbol_table[''], ) + elif ch == '': + if '' in symbol_table: + tokens_idx = tokens_idx + (symbol_table[''], ) + else: + tokens_idx = tokens_idx + (symbol_table[''], ) + else: + if '' in symbol_table: + tokens_idx = tokens_idx + (symbol_table[''], ) + logging.info( + f'\'{ch}\' is not in token set, replace with ') + else: + tokens_idx = tokens_idx + (symbol_table[''], ) + logging.info( + f'\'{ch}\' is not in token set, replace with ') + + return tokens_str, tokens_idx + + +def query_token_list(txt, symbol_table, lexicon_table): + tokens_str = [] + tokens_idx = [] + + parts = split_mixed_label(txt) + for part in parts: + if part == '!sil' or part == '(sil)' or part == '': + tokens_str.append('!sil') + elif part == '' or part == '': + tokens_str.append('') + elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '': + tokens_str.append('') + elif part in symbol_table: + tokens_str.append(part) + elif part in lexicon_table: + for ch in lexicon_table[part]: + tokens_str.append(ch) + else: + # case with symbols or meaningless english letter combination + part = re.sub(symbol_str, '', part) + for ch in part: + tokens_str.append(ch) + + for ch in tokens_str: + if ch in symbol_table: + tokens_idx.append(symbol_table[ch]) + elif ch == '!sil': + if 'sil' in symbol_table: + tokens_idx.append(symbol_table['sil']) + else: + tokens_idx.append(symbol_table['']) + elif ch == '': + if '' in symbol_table: + tokens_idx.append(symbol_table['']) + else: + tokens_idx.append(symbol_table['']) + else: + if '' in symbol_table: + tokens_idx.append(symbol_table['']) + logging.info( + f'\'{ch}\' is not in token set, replace with ') + else: + tokens_idx.append(symbol_table['']) + logging.info( + f'\'{ch}\' is not in token set, replace with ') + + return tokens_str, tokens_idx + +def read_token(token_file): + tokens_table = {} + with open(token_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + tokens_table[arr[0]] = int(arr[1]) - 1 + fin.close() + return tokens_table + + +def read_lexicon(lexicon_file): + lexicon_table = {} + with open(lexicon_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().replace('\t', ' ').split() + assert len(arr) >= 2 + lexicon_table[arr[0]] = arr[1:] + fin.close() + return lexicon_table + if __name__ == '__main__': parser = argparse.ArgumentParser(description='') @@ -23,6 +158,8 @@ if __name__ == '__main__': parser.add_argument('text_file', help='text file') parser.add_argument('duration_file', help='duration file') parser.add_argument('output_file', help='output list file') + parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') args = parser.parse_args() wav_table = {} @@ -39,16 +176,35 @@ if __name__ == '__main__': assert len(arr) == 2 duration_table[arr[0]] = float(arr[1]) + token_table = None + if args.token_file: + token_table = read_token(args.token_file) + lexicon_table = None + if args.lexicon_file: + lexicon_table = read_lexicon(args.lexicon_file) + with open(args.text_file, 'r', encoding='utf8') as fin, \ open(args.output_file, 'w', encoding='utf8') as fout: for line in fin: arr = line.strip().split(maxsplit=1) key = arr[0] - txt = int(arr[1]) + tokens = None + if token_table!=None and lexicon_table!=None : + if len(arr) < 2: # for some utterence, no text + txt = [1] # the /sil is indexed by 1 + tokens = ["sil"] + else: + tokens, txt = query_token_list(arr[1], token_table, lexicon_table) + else: + txt = int(arr[1]) assert key in wav_table wav = wav_table[key] assert key in duration_table duration = duration_table[key] - line = dict(key=key, txt=txt, duration=duration, wav=wav) + if tokens == None: + line = dict(key=key, txt=txt, duration=duration, wav=wav) + else: + line = dict(key=key, tok=tokens, txt=txt, duration=duration, wav=wav) + json_line = json.dumps(line, ensure_ascii=False) fout.write(json_line + '\n') diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py new file mode 100644 index 0000000..94db816 --- /dev/null +++ b/wekws/bin/compute_det_ctc.py @@ -0,0 +1,241 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# 2022 Shaoqing Yu(954793264@qq.com) +# 2023 Jing Du(thuduj12@163.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse, logging, glob +import json, re, os, numpy as np +import matplotlib.font_manager as fm +import matplotlib.pyplot as plt + +font = fm.FontProperties(size=15) + +def split_mixed_label(input_str): + tokens = [] + s = input_str.lower() + while len(s) > 0: + match = re.match(r'[A-Za-z!?,<>()\']+', s) + if match is not None: + word = match.group(0) + else: + word = s[0:1] + tokens.append(word) + s = s.replace(word, '', 1).strip(' ') + return tokens + + +def space_mixed_label(input_str): + splits = split_mixed_label(input_str) + space_str = ''.join(f'{sub} ' for sub in splits) + return space_str.strip() + +def load_label_and_score(keywords_list, label_file, score_file): + # score_table: {uttid: [keywordlist]} + score_table = {} + with open(score_file, 'r', encoding='utf8') as fin: + # read score file and store in table + for line in fin: + arr = line.strip().split() + key = arr[0] + is_detected = arr[1] + if is_detected == 'detected': + if key not in score_table: + score_table.update({ + key: { + 'kw': space_mixed_label(arr[2]), + 'confi': float(arr[3]) + } + }) + else: + if key not in score_table: + score_table.update({key: {'kw': 'unknown', 'confi': -1.0}}) + + label_lists = [] + with open(label_file, 'r', encoding='utf8') as fin: + for line in fin: + obj = json.loads(line.strip()) + label_lists.append(obj) + + # build empty structure for keyword-filler infos + keyword_filler_table = {} + for keyword in keywords_list: + keyword = space_mixed_label(keyword) + keyword_filler_table[keyword] = {} + keyword_filler_table[keyword]['keyword_table'] = {} + keyword_filler_table[keyword]['keyword_duration'] = 0.0 + keyword_filler_table[keyword]['filler_table'] = {} + keyword_filler_table[keyword]['filler_duration'] = 0.0 + + for obj in label_lists: + assert 'key' in obj + assert 'wav' in obj + assert 'tok' in obj # here we use the tokens + assert 'duration' in obj + + key = obj['key'] + # wav_file = obj['wav'] + txt = "".join(obj['tok']) + txt = space_mixed_label(txt) + txt_regstr_lrblk = ' ' + txt + ' ' + duration = obj['duration'] + assert key in score_table + + for keyword in keywords_list: + keyword = space_mixed_label(keyword) + keyword_regstr_lrblk = ' ' + keyword + ' ' + if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1: + if keyword == score_table[key]['kw']: + keyword_filler_table[keyword]['keyword_table'].update( + {key: score_table[key]['confi']}) + else: + # uttrance detected but not match this keyword + keyword_filler_table[keyword]['keyword_table'].update( + {key: -1.0}) + keyword_filler_table[keyword]['keyword_duration'] += duration + else: + if keyword == score_table[key]['kw']: + keyword_filler_table[keyword]['filler_table'].update( + {key: score_table[key]['confi']}) + else: + # uttrance if detected, which is not FA for this keyword + keyword_filler_table[keyword]['filler_table'].update( + {key: -1.0}) + keyword_filler_table[keyword]['filler_duration'] += duration + + return keyword_filler_table + +def load_stats_file(stats_file): + values = [] + with open(stats_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + threshold, fa_per_hour, frr = arr + values.append([float(fa_per_hour), float(frr) * 100]) + values.reverse() + return np.array(values) + +def plot_det(dets_dir, figure_file, det_title="DetCurve"): + xlim = '[0,2]' + # xstep = kwargs.get('xstep', '1') + ylim = '[15,30]' + # ystep = kwargs.get('ystep', '5') + + plt.figure(dpi=200) + plt.rcParams['xtick.direction'] = 'in' + plt.rcParams['ytick.direction'] = 'in' + plt.rcParams['font.size'] = 12 + + for file in glob.glob(f'{dets_dir}/*stats*.txt'): + logging.info(f'reading det data from {file}') + label = os.path.basename(file).split('.')[0] + values = load_stats_file(file) + plt.plot(values[:, 0], values[:, 1], label=label) + + xlim_splits = xlim.strip().replace('[', '').replace(']', '').split(',') + assert len(xlim_splits) == 2 + ylim_splits = ylim.strip().replace('[', '').replace(']', '').split(',') + assert len(ylim_splits) == 2 + + plt.xlim(float(xlim_splits[0]), float(xlim_splits[1])) + plt.ylim(float(ylim_splits[0]), float(ylim_splits[1])) + + # plt.xticks(range(0, xlim + x_step, x_step)) + # plt.yticks(range(0, ylim + y_step, y_step)) + plt.xlabel('False Alarm Per Hour') + plt.ylabel('False Rejection Rate (\\%)') + plt.title(det_title, fontproperties=font) + plt.grid(linestyle='--') + # plt.legend(loc='best', fontsize=6) + plt.legend(loc='upper right', fontsize=5) + # plt.show() + plt.savefig(figure_file) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='compute det curve') + parser.add_argument('--test_data', required=True, help='label file') + parser.add_argument('--keyword', type=str, default=None, help='keyword label') + parser.add_argument('--score_file', required=True, help='score file') + parser.add_argument('--step', type=float, default=0.01, + help='threshold step') + parser.add_argument('--window_shift', type=int, default=50, + help='window_shift is used to skip the frames after triggered') + parser.add_argument('--stats_dir', + required=False, + default=None, + help='false reject/alarm stats dir, default in score_file') + parser.add_argument('--det_curve_path', + required=False, + default=None, + help='det curve path, default is stats_dir/det.png') + + args = parser.parse_args() + window_shift = args.window_shift + keywords_list = args.keyword.strip().split(',') + keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file) + + for keyword in keywords_list: + keyword = space_mixed_label(keyword) + keyword_dur = keyword_filler_table[keyword]['keyword_duration'] + keyword_num = len(keyword_filler_table[keyword]['keyword_table']) + filler_dur = keyword_filler_table[keyword]['filler_duration'] + filler_num = len(keyword_filler_table[keyword]['filler_table']) + assert keyword_num > 0, 'Can\'t compute det for {} without positive sample' + assert filler_num > 0, 'Can\'t compute det for {} without negative sample' + + logging.info('Computing det for {}'.format(keyword)) + logging.info(' Keyword duration: {} Hours, wave number: {}'.format( + keyword_dur / 3600.0, keyword_num)) + logging.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0)) + + if args.stats_dir : + stats_dir = args.stats_dir + else: + stats_dir = os.path.dirname(args.score_file) + stats_file = os.path.join(stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt') + with open(stats_file, 'w', encoding='utf8') as fout: + threshold = 0.0 + while threshold <= 1.0: + num_false_reject = 0 + num_true_detect = 0 + # transverse the all keyword_table + for key, confi in keyword_filler_table[keyword][ + 'keyword_table'].items(): + if confi < threshold: + num_false_reject += 1 + else: + num_true_detect += 1 + + num_false_alarm = 0 + # transverse the all filler_table + for key, confi in keyword_filler_table[keyword][ + 'filler_table'].items(): + if confi >= threshold: + num_false_alarm += 1 + # print(f'false alarm: {keyword}, {key}, {confi}') + + false_reject_rate = num_false_reject / keyword_num + true_detect_rate = num_true_detect / keyword_num + + num_false_alarm = max(num_false_alarm, 1e-6) + false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0) + false_alarm_rate = num_false_alarm / filler_num + + fout.write('{:.3f} {:.6f} {:.6f}\n'.format( + threshold, false_alarm_per_hour, threshold)) + threshold += args.step + if args.det_curve_path : + det_curve_path = args.det_curve_path + else: + det_curve_path = os.path.join(stats_dir, 'det.png') + plot_det(stats_dir, det_curve_path) \ No newline at end of file diff --git a/wekws/bin/score_ctc.py b/wekws/bin/score_ctc.py new file mode 100644 index 0000000..a2221cd --- /dev/null +++ b/wekws/bin/score_ctc.py @@ -0,0 +1,206 @@ +# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com) +# 2022 Shaoqing Yu(954793264@qq.com) +# 2023 Jing Du(thuduj12@163.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import copy +import logging +import os, sys, math + +import torch +import yaml +from torch.utils.data import DataLoader + +from wekws.dataset.dataset import Dataset +from wekws.model.kws_model import init_model +from wekws.utils.checkpoint import load_checkpoint +from wekws.model.loss import ctc_prefix_beam_search +from tools.make_list import query_token_set, read_lexicon, read_token + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--batch_size', + default=16, + type=int, + help='batch size for inference') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--score_file', + required=True, + help='output score file') + parser.add_argument('--jit_model', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--keywords', type=str, default=None, help='the keywords, split with comma(,)') + parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') + + args = parser.parse_args() + return args + +def is_sublist(main_list, check_list): + if len(main_list) < len(check_list): + return -1 + + if len(main_list) == len(check_list): + return 0 if main_list == check_list else -1 + + for i in range(len(main_list) - len(check_list)): + if main_list[i] == check_list[0]: + for j in range(len(check_list)): + if main_list[i + j] != check_list[j]: + break + else: + return i + else: + return -1 + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + test_conf = copy.deepcopy(configs['dataset_conf']) + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['shuffle'] = False + test_conf['feature_extraction_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_size'] = args.batch_size + + test_dataset = Dataset(args.test_data, test_conf) + test_data_loader = DataLoader(test_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + + if args.jit_model: + model = torch.jit.load(args.checkpoint) + # For script model, only cpu is supported. + device = torch.device('cpu') + else: + # Init asr model from configs + model = init_model(configs['model']) + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + model.eval() + score_abs_path = os.path.abspath(args.score_file) + + token_table = read_token(args.token_file) + lexicon_table = read_lexicon(args.lexicon_file) + # 4. parse keywords tokens + assert args.keywords is not None, 'at least one keyword is needed' + keywords_str = args.keywords + keywords_list = keywords_str.strip().replace(' ', '').split(',') + keywords_token = {} + keywords_idxset = {0} + keywords_strset = {''} + keywords_tokenmap = {'': 0} + for keyword in keywords_list: + strs, indexes = query_token_set(keyword, token_table,lexicon_table) + keywords_token[keyword] = {} + keywords_token[keyword]['token_id'] = indexes + keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) + for i in indexes) + [keywords_strset.add(i) for i in strs] + [keywords_idxset.add(i) for i in indexes] + for txt, idx in zip(strs, indexes): + if keywords_tokenmap.get(txt, None) is None: + keywords_tokenmap[txt] = idx + + token_print = '' + for txt, idx in keywords_tokenmap.items(): + token_print += f'{txt}({idx}) ' + logging.info(f'Token set is: {token_print}') + + with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout: + for batch_idx, batch in enumerate(test_data_loader): + keys, feats, target, lengths, target_lengths = batch + feats = feats.to(device) + lengths = lengths.to(device) + logits, _ = model(feats) + logits = logits.softmax(2) # (batch_size, maxlen, vocab_size) + logits = logits.cpu() + for i in range(len(keys)): + key = keys[i] + score = logits[i][:lengths[i]] + hyps = ctc_prefix_beam_search(score, lengths[i], + keywords_idxset) + hit_keyword = None + hit_score = 1.0 + # start = 0; end = 0 + for one_hyp in hyps: + prefix_ids = one_hyp[0] + # path_score = one_hyp[1] + prefix_nodes = one_hyp[2] + assert len(prefix_ids) == len(prefix_nodes) + for word in keywords_token.keys(): + lab = keywords_token[word]['token_id'] + offset = is_sublist(prefix_ids, lab) + if offset != -1: + hit_keyword = word + # start = prefix_nodes[offset]['frame'] + # end = prefix_nodes[offset+len(lab)-1]['frame'] + for idx in range(offset, offset + len(lab)): + hit_score *= prefix_nodes[idx]['prob'] + break + if hit_keyword is not None: + hit_score = math.sqrt(hit_score) + break + + if hit_keyword is not None: + # fout.write('{} detected [{:.2f} {:.2f}] {} {:.3f}\n'\ + # .format(key, start*0.03, end*0.03, hit_keyword, hit_score)) + fout.write('{} {} {:.3f}\n'.format( + key, hit_keyword, hit_score)) + else: + fout.write('{} -1 -1\n'.format(key)) + + if batch_idx % 10 == 0: + print('Progress batch {}'.format(batch_idx)) + sys.stdout.flush() + + +if __name__ == '__main__': + main() diff --git a/wekws/dataset/processor.py b/wekws/dataset/processor.py index 988b3a2..c21de63 100644 --- a/wekws/dataset/processor.py +++ b/wekws/dataset/processor.py @@ -302,12 +302,24 @@ def padding(data): [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) sorted_feats = [sample[i]['feat'] for i in order] sorted_keys = [sample[i]['key'] for i in order] - sorted_labels = torch.tensor([sample[i]['label'] for i in order], - dtype=torch.int64) padded_feats = pad_sequence(sorted_feats, batch_first=True, padding_value=0) - yield (sorted_keys, padded_feats, sorted_labels, feats_lengths) + + if isinstance(sample[0]['label'], int): + padded_labels = torch.tensor([sample[i]['label'] for i in order], + dtype=torch.int32) + label_lengths = torch.tensor([1 for i in order], + dtype=torch.int32) + else: + sorted_labels = [ + torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order + ] + label_lengths = torch.tensor([len(sample[i]['label']) for i in order], + dtype=torch.int32) + padded_labels = pad_sequence( + sorted_labels, batch_first=True, padding_value=-1) + yield (sorted_keys, padded_feats, padded_labels, feats_lengths, label_lengths) def add_reverb(data, reverb_source, aug_prob): diff --git a/wekws/model/kws_model.py b/wekws/model/kws_model.py index 993cbd4..3df0849 100644 --- a/wekws/model/kws_model.py +++ b/wekws/model/kws_model.py @@ -162,6 +162,16 @@ def init_model(configs): classifier = LinearClassifier(hidden_dim, output_dim) activation = nn.Sigmoid() + # Here we add a possible "activation_type", one can choose to use other activation function. + # We use nn.Identity just for CTC loss + if "activation" in configs: + activation_type = configs["activation"]["type"] + if activation_type == 'identity': + activation = nn.Identity() + else: + print('Unknown activation type {}'.format(activation_type)) + sys.exit(1) + kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, preprocessing, backbone, classifier, activation) return kws_model diff --git a/wekws/model/loss.py b/wekws/model/loss.py index f8315ef..ce80da0 100644 --- a/wekws/model/loss.py +++ b/wekws/model/loss.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch +import torch, math, sys import torch.nn.functional as F +from collections import defaultdict +from typing import List, Optional, Tuple from wekws.utils.mask import padding_mask @@ -93,6 +95,65 @@ def acc_frame( correct = pred.eq(target.long().view_as(pred)).sum().item() return correct * 100.0 / logits.size(0) +def acc_utterance(logits: torch.Tensor, target: torch.Tensor, + logits_length: torch.Tensor, target_length: torch.Tensor): + if logits is None: + return 0 + + logits = logits.softmax(2) # (1, maxlen, vocab_size) + logits = logits.cpu() + target = target.cpu() + + total_word = 0 + total_ins = 0 + total_sub = 0 + total_del = 0 + calculator = Calculator() + for i in range(logits.size(0)): + score = logits[i][:logits_length[i]] + hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5) + lab = [str(item) for item in target[i][:target_length[i]].tolist()] + rec = [] + if len(hyps) > 0: + rec = [str(item) for item in hyps[0][0]] + result = calculator.calculate(lab, rec) + # print(f'result:{result}') + if result['all'] != 0: + total_word += result['all'] + total_ins += result['ins'] + total_sub += result['sub'] + total_del += result['del'] + + return float(total_word - total_ins - total_sub + - total_del) * 100.0 / total_word + +def ctc_loss(logits: torch.Tensor, + target: torch.Tensor, + logits_lengths: torch.Tensor, + target_lengths: torch.Tensor, + need_acc: bool = False): + """ CTC Loss + Args: + logits: (B, D), D is the number of keywords plus 1 (non-keyword) + target: (B) + logits_lengths: (B) + target_lengths: (B) + Returns: + (float): loss of current batch + """ + + acc = 0.0 + if need_acc: + acc = acc_utterance(logits, target, logits_lengths, target_lengths) + + # logits: (B, L, D) -> (L, B, D) + logits = logits.transpose(0, 1) + logits = logits.log_softmax(2) + loss = F.ctc_loss( + logits, target, logits_lengths, target_lengths, reduction='sum') + loss = loss / logits.size(1) # batch mean + + return loss, acc def cross_entropy(logits: torch.Tensor, target: torch.Tensor): """ Cross Entropy Loss @@ -114,12 +175,279 @@ def criterion(type: str, logits: torch.Tensor, target: torch.Tensor, lengths: torch.Tensor, - min_duration: int = 0): + target_lengths: torch.Tensor = None, + min_duration: int = 0, + validation: bool = False, ): if type == 'ce': loss, acc = cross_entropy(logits, target) return loss, acc elif type == 'max_pooling': loss, acc = max_pooling_loss(logits, target, lengths, min_duration) return loss, acc + elif type == 'ctc': + loss, acc = ctc_loss(logits, target, lengths, target_lengths, validation) + return loss, acc else: exit(1) + +def ctc_prefix_beam_search( + logits: torch.Tensor, + logits_lengths: torch.Tensor, + keywords_tokenset: set = None, + score_beam_size: int = 3, + path_beam_size: int = 20, +) -> Tuple[List[List[int]], torch.Tensor]: + """ CTC prefix beam search inner implementation + + Args: + logits (torch.Tensor): (1, max_len, vocab_size) + logits_lengths (torch.Tensor): (1, ) + keywords_tokenset (set): token set for filtering score + score_beam_size (int): beam size for score + path_beam_size (int): beam size for path + + Returns: + List[List[int]]: nbest results + """ + maxlen = logits.size(0) + # ctc_probs = logits.softmax(1) # (1, maxlen, vocab_size) + ctc_probs = logits + + cur_hyps = [(tuple(), (1.0, 0.0, []))] + + # 2. CTC beam search step by step + for t in range(0, maxlen): + probs = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (0.0, 0.0, [])) + + # 2.1 First beam prune: select topk best + top_k_probs, top_k_index = probs.topk( + score_beam_size) # (score_beam_size,) + + # filter prob score that is too small + filter_probs = [] + filter_index = [] + for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()): + if keywords_tokenset is not None: + if prob > 0.05 and idx in keywords_tokenset: + filter_probs.append(prob) + filter_index.append(idx) + else: + if prob > 0.05: + filter_probs.append(prob) + filter_index.append(idx) + + if len(filter_index) == 0: + continue + + for s in filter_index: + ps = probs[s].item() + + for prefix, (pb, pnb, cur_nodes) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == 0: # blank + n_pb, n_pnb, nodes = next_hyps[prefix] + n_pb = n_pb + pb * ps + pnb * ps + nodes = cur_nodes.copy() + next_hyps[prefix] = (n_pb, n_pnb, nodes) + elif s == last: + if not math.isclose(pnb, 0.0, abs_tol=0.000001): + # Update *ss -> *s; + n_pb, n_pnb, nodes = next_hyps[prefix] + n_pnb = n_pnb + pnb * ps + nodes = cur_nodes.copy() + if ps > nodes[-1]['prob']: # update frame and prob + nodes[-1]['prob'] = ps + nodes[-1]['frame'] = t + next_hyps[prefix] = (n_pb, n_pnb, nodes) + + if not math.isclose(pb, 0.0, abs_tol=0.000001): + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb, nodes = next_hyps[n_prefix] + n_pnb = n_pnb + pb * ps + nodes = cur_nodes.copy() + nodes.append(dict(token=s, frame=t, + prob=ps)) # to record token prob + next_hyps[n_prefix] = (n_pb, n_pnb, nodes) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb, nodes = next_hyps[n_prefix] + if nodes: + if ps > nodes[-1]['prob']: # update frame and prob + nodes[-1]['prob'] = ps + nodes[-1]['frame'] = t + else: + nodes = cur_nodes.copy() + nodes.append(dict(token=s, frame=t, + prob=ps)) # to record token prob + n_pnb = n_pnb + pb * ps + pnb * ps + next_hyps[n_prefix] = (n_pb, n_pnb, nodes) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True) + + cur_hyps = next_hyps[:path_beam_size] + + hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps] + return hyps + + +class Calculator: + + def __init__(self): + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec): + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab): + self.space.append([]) + for row in self.space: + for element in row: + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec): + row.append({'dist': 0, 'error': 'non'}) + for i in range(len(lab)): + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)): + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + # Computing edit distance + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == rec_token: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del': # deletion + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, '') + i = i - 1 + elif self.space[i][j]['error'] == 'ins': # insertion + if len(rec[j]) > 0: + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, '') + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non': # starting point + break + else: # shouldn't reach here + print( + 'this should not happen , i = {i} , j = {j} , error = {error}' + .format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data: + if token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self): + return list(self.data.keys()) \ No newline at end of file diff --git a/wekws/utils/executor.py b/wekws/utils/executor.py index eb1e26a..2283e71 100644 --- a/wekws/utils/executor.py +++ b/wekws/utils/executor.py @@ -34,17 +34,20 @@ class Executor: min_duration = args.get('min_duration', 0) for batch_idx, batch in enumerate(data_loader): - key, feats, target, feats_lengths = batch + key, feats, target, feats_lengths, label_lengths = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) + label_lengths = label_lengths.to(device) num_utts = feats_lengths.size(0) if num_utts == 0: continue logits, _ = model(feats) loss_type = args.get('criterion', 'max_pooling') loss, acc = criterion(loss_type, logits, target, feats_lengths, - min_duration) + target_lengths=label_lengths, + min_duration=min_duration, + validation=False) optimizer.zero_grad() loss.backward() grad_norm = clip_grad_norm_(model.parameters(), clip) @@ -67,16 +70,20 @@ class Executor: total_acc = 0.0 with torch.no_grad(): for batch_idx, batch in enumerate(data_loader): - key, feats, target, feats_lengths = batch + key, feats, target, feats_lengths, label_lengths = batch feats = feats.to(device) target = target.to(device) feats_lengths = feats_lengths.to(device) + label_lengths = label_lengths.to(device) num_utts = feats_lengths.size(0) if num_utts == 0: continue logits, _ = model(feats) loss, acc = criterion(args.get('criterion', 'max_pooling'), - logits, target, feats_lengths) + logits, target, feats_lengths, + target_lengths=label_lengths, + min_duration=0, + validation=True) if torch.isfinite(loss): num_seen_utts += num_utts total_loss += loss.item() * num_utts