From 4a875776e5386227daffb0fd0f359305f1728f87 Mon Sep 17 00:00:00 2001 From: Menglong Xu <32296227+mlxu995@users.noreply.github.com> Date: Wed, 8 Dec 2021 23:46:05 +0800 Subject: [PATCH] [example] support hey_snips_kws_4.0 dataset (#38) * [example] support hey_snips_kws_4.0 dataset * format * format --- examples/hey_snips/s0/README.md | 6 + examples/hey_snips/s0/conf/ds_tcn.yaml | 45 +++++++ examples/hey_snips/s0/conf/mdtc_small.yaml | 50 +++++++ examples/hey_snips/s0/kws | 1 + examples/hey_snips/s0/local/prepare_data.py | 46 +++++++ .../hey_snips/s0/local/snips_data_extract.sh | 30 +++++ examples/hey_snips/s0/path.sh | 5 + examples/hey_snips/s0/run.sh | 125 ++++++++++++++++++ examples/hey_snips/s0/tools | 1 + 9 files changed, 309 insertions(+) create mode 100644 examples/hey_snips/s0/README.md create mode 100644 examples/hey_snips/s0/conf/ds_tcn.yaml create mode 100644 examples/hey_snips/s0/conf/mdtc_small.yaml create mode 120000 examples/hey_snips/s0/kws create mode 100755 examples/hey_snips/s0/local/prepare_data.py create mode 100755 examples/hey_snips/s0/local/snips_data_extract.sh create mode 100755 examples/hey_snips/s0/path.sh create mode 100755 examples/hey_snips/s0/run.sh create mode 120000 examples/hey_snips/s0/tools diff --git a/examples/hey_snips/s0/README.md b/examples/hey_snips/s0/README.md new file mode 100644 index 0000000..26744ff --- /dev/null +++ b/examples/hey_snips/s0/README.md @@ -0,0 +1,6 @@ +FRRs with FAR fixed at once per hour: + +| model | params(K) | epoch | hey_snips | +|------------------|-----------|------------|------------| +| DS_TCN | 21 | 80(avg30) | 0.019771 | +| MDTC_Small | 31 | 100(avg10) | 0.008699 | \ No newline at end of file diff --git a/examples/hey_snips/s0/conf/ds_tcn.yaml b/examples/hey_snips/s0/conf/ds_tcn.yaml new file mode 100644 index 0000000..88b0c78 --- /dev/null +++ b/examples/hey_snips/s0/conf/ds_tcn.yaml @@ -0,0 +1,45 @@ +dataset_conf: + filter_conf: + max_length: 2048 + min_length: 0 + resample_conf: + 
resample_rate: 16000 + speed_perturb: false + feature_extraction_conf: + feature_type: 'fbank' + num_mel_bins: 40 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 1 + num_f_mask: 1 + max_t: 50 + max_f: 30 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + batch_conf: + batch_size: 256 + +model: + hidden_dim: 64 + preprocessing: + type: linear + backbone: + type: tcn + ds: true + num_layers: 4 + kernel_size: 8 + dropout: 0.1 + +optim: adam +optim_conf: + lr: 0.001 + +training_config: + grad_clip: 5 + max_epoch: 80 + log_interval: 10 + diff --git a/examples/hey_snips/s0/conf/mdtc_small.yaml b/examples/hey_snips/s0/conf/mdtc_small.yaml new file mode 100644 index 0000000..06932b9 --- /dev/null +++ b/examples/hey_snips/s0/conf/mdtc_small.yaml @@ -0,0 +1,50 @@ +dataset_conf: + filter_conf: + max_length: 2048 + min_length: 0 + resample_conf: + resample_rate: 16000 + speed_perturb: false + feature_extraction_conf: + feature_type: 'mfcc' + num_ceps: 80 + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + feature_dither: 0.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 1 + num_f_mask: 1 + max_t: 20 + max_f: 40 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + batch_conf: + batch_size: 100 + +model: + hidden_dim: 32 + preprocessing: + type: none + backbone: + type: mdtc + num_stack: 3 + stack_size: 4 + kernel_size: 5 + hidden_dim: 32 + causal: True + + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 0.00005 + +training_config: + grad_clip: 5 + max_epoch: 100 + log_interval: 10 + criterion: max_pooling diff --git a/examples/hey_snips/s0/kws b/examples/hey_snips/s0/kws new file mode 120000 index 0000000..7a3e8e1 --- /dev/null +++ b/examples/hey_snips/s0/kws @@ -0,0 +1 @@ +../../../kws \ No newline at end of file diff --git a/examples/hey_snips/s0/local/prepare_data.py b/examples/hey_snips/s0/local/prepare_data.py new file mode 100755 index 0000000..bdce1c1 --- /dev/null +++ 
def main():
    """Convert one Hey Snips JSON manifest into Kaldi-style list files.

    Command-line arguments (all positional):
        wav_dir: directory containing all the wav files.
        path: path to the JSON manifest describing the utterances.
        out_dir: directory that receives ``wav.scp`` and ``text``; it is
            created if it does not exist yet.

    For every manifest entry whose ``duration`` is greater than zero, one
    line is written to each output file:

    * ``wav.scp``: ``<id> <absolute wav_dir>/<id>.wav``
    * ``text``:    ``<id> <label>`` where the label is ``0`` when
      ``is_hotword`` equals 1 and ``-1`` otherwise.

    Raises:
        KeyError: if a manifest entry lacks ``duration``, ``id`` or
            ``is_hotword``.
        OSError: if the manifest cannot be read or the outputs cannot be
            written.
    """
    parser = argparse.ArgumentParser(description="""Prepare data.""")
    parser.add_argument('wav_dir',
                        type=str,
                        help='dir containing all the wav files')
    parser.add_argument('path', type=str, help='path to the json file')
    parser.add_argument('out_dir', type=str, help='out dir')
    args = parser.parse_args()

    with open(args.path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Keep id and label together so the two output files cannot drift apart.
    # Entries with a non-positive duration are dropped.
    utterances = [(entry['id'], 0 if entry['is_hotword'] == 1 else -1)
                  for entry in data
                  if entry['duration'] > 0]

    # The script previously failed with FileNotFoundError when out_dir was
    # missing; creating it here makes the script usable on its own.
    os.makedirs(args.out_dir, exist_ok=True)

    abs_dir = os.path.abspath(args.wav_dir)
    wav_path = os.path.join(args.out_dir, 'wav.scp')
    text_path = os.path.join(args.out_dir, 'text')
    with open(wav_path, 'w', encoding='utf-8') as f_wav, \
         open(text_path, 'w', encoding='utf-8') as f_text:
        for utt, label in utterances:
            f_wav.write('{} {}\n'.format(utt,
                                         os.path.join(abs_dir, utt + ".wav")))
            f_text.write('{} {}\n'.format(utt, label))
#!/bin/bash
# Copyright 2021 Binbin Zhang
#                Menglong Xu

# Hey Snips keyword-spotting recipe driver.
#
# Stages (select with --stage / --stop_stage):
#   -1  extract the downloaded dataset archive
#    0  write dict/words.txt and per-split wav.scp / text
#    1  compute global CMVN on train and build data.list for every split
#    2  train
#    3  average checkpoints and score the test set
#    4  compute detection-error-tradeoff stats per keyword
#    5  export the averaged model (not run by default: stop_stage=4)
#
# Every variable assigned before the parse_options.sh line below can be
# overridden on the command line as --name value.

. ./path.sh || exit 1

export CUDA_VISIBLE_DEVICES="0"

stage=0
stop_stage=4
num_keywords=1

config=conf/ds_tcn.yaml
norm_mean=true
norm_var=true
gpu_id=0

checkpoint=
dir=exp/ds_tcn

num_average=30
# Left empty on purpose: the default depends on $dir and $num_average, which
# may themselves be overridden, so it is resolved after option parsing.
score_checkpoint=

download_dir=./data/local # your data dir

. tools/parse_options.sh || exit 1;

# Deriving this before parse_options.sh meant that "--dir exp/foo
# --num_average 10" still averaged/scored into exp/ds_tcn/avg_30.pt.
# An explicit --score_checkpoint still takes precedence.
score_checkpoint=${score_checkpoint:-$dir/avg_${num_average}.pt}

set -euo pipefail

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "Extracting all datasets"
  local/snips_data_extract.sh --dl_dir "$download_dir"
fi


if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "Preparing datasets..."
  mkdir -p dict
  # NOTE(review): the token in front of "-1" is empty here; it looks like a
  # filler symbol may have been lost -- confirm against the upstream recipe.
  echo " -1" > dict/words.txt
  echo "Hey_Snips 0" >> dict/words.txt

  snips_dir=$download_dir/hey_snips_research_6k_en_train_eval_clean_ter
  for folder in train dev test; do
    mkdir -p data/$folder
    json_path=$snips_dir/$folder.json
    local/prepare_data.py "$snips_dir/audio_files" "$json_path" \
      data/$folder
  done
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Compute CMVN and Format datasets"
  tools/compute_cmvn_stats.py --num_workers 16 --train_config "$config" \
    --in_scp data/train/wav.scp \
    --out_cmvn data/train/global_cmvn

  for x in train dev test; do
    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
    tools/make_list.py data/$x/wav.scp data/$x/text \
      data/$x/wav.dur data/$x/data.list
  done
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Start training ..."
  mkdir -p "$dir"
  # norm_mean / norm_var hold the literal commands "true" or "false", so each
  # line below appends its option only when the flag is true.
  cmvn_opts=
  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
  # $cmvn_opts is deliberately unquoted: it must split into separate words.
  # --checkpoint is passed only when $checkpoint is non-empty.
  python kws/bin/train.py --gpu "$gpu_id" \
    --config "$config" \
    --train_data data/train/data.list \
    --cv_data data/dev/data.list \
    --model_dir "$dir" \
    --num_workers 8 \
    --num_keywords "$num_keywords" \
    --min_duration 50 \
    --seed 777 \
    $cmvn_opts \
    ${checkpoint:+--checkpoint "$checkpoint"}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  # Do model average
  python kws/bin/average_model.py \
    --dst_model "$score_checkpoint" \
    --src_path "$dir" \
    --num "${num_average}" \
    --val_best

  # Compute posterior score
  result_dir=$dir/test_$(basename "$score_checkpoint")
  mkdir -p "$result_dir"
  python kws/bin/score.py --gpu "$gpu_id" \
    --config "$dir/config.yaml" \
    --test_data data/test/data.list \
    --batch_size 256 \
    --checkpoint "$score_checkpoint" \
    --score_file "$result_dir/score.txt" \
    --num_workers 8
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  # Compute detection error tradeoff
  # result_dir is recomputed so this stage can run without stage 3.
  result_dir=$dir/test_$(basename "$score_checkpoint")
  first_keyword=0
  last_keyword=$((num_keywords + first_keyword - 1))
  for keyword in $(seq $first_keyword $last_keyword); do
    python kws/bin/compute_det.py \
      --keyword "$keyword" \
      --test_data data/test/data.list \
      --score_file "$result_dir/score.txt" \
      --stats_file "$result_dir/stats.${keyword}.txt"
  done
fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  python kws/bin/export_jit.py --config "$dir/config.yaml" \
    --checkpoint "$score_checkpoint" \
    --output_file "$dir/final.zip" \
    --output_quant_file "$dir/final.quant.zip"
fi