[example] support hey_snips_kws_4.0 dataset (#38)

* [example] support hey_snips_kws_4.0 dataset

* format

* format
This commit is contained in:
Menglong Xu 2021-12-08 23:46:05 +08:00 committed by GitHub
parent 1eda27647b
commit 4a875776e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 309 additions and 0 deletions

View File

@ -0,0 +1,6 @@
FRRs with FAR fixed at once per hour:
| model | params(K) | epoch | hey_snips |
|------------------|-----------|------------|------------|
| DS_TCN | 21 | 80(avg30) | 0.019771 |
| MDTC_Small | 31 | 100(avg10) | 0.008699 |

View File

@ -0,0 +1,45 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 50
max_f: 30
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256
model:
hidden_dim: 64
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
optim: adam
optim_conf:
lr: 0.001
training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10

View File

@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'mfcc'
num_ceps: 80
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
feature_dither: 0.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 40
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 100
model:
hidden_dim: 32
preprocessing:
type: none
backbone:
type: mdtc
num_stack: 3
stack_size: 4
kernel_size: 5
hidden_dim: 32
causal: True
optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.00005
training_config:
grad_clip: 5
max_epoch: 100
log_interval: 10
criterion: max_pooling

1
examples/hey_snips/s0/kws Symbolic link
View File

@ -0,0 +1 @@
../../../kws

View File

@ -0,0 +1,46 @@
#!/usr/bin/env python3
# Copyright 2018-2020 Yiming Wang
# 2018-2020 Daniel Povey
# 2021 Binbin Zhang
# Menglong Xu
# Apache 2.0
""" This script prepares the snips data into kaldi format.
"""
import argparse
import os
import json
def main():
parser = argparse.ArgumentParser(description="""Prepare data.""")
parser.add_argument('wav_dir',
type=str,
help='dir containing all the wav files')
parser.add_argument('path', type=str, help='path to the json file')
parser.add_argument('out_dir', type=str, help='out dir')
args = parser.parse_args()
with open(args.path, 'r', encoding='utf-8') as f:
data = json.load(f)
utt_id, label = [], []
for entry in data:
if entry['duration'] > 0:
utt_id.append(entry['id'])
keyword_id = 0 if entry['is_hotword'] == 1 else -1
label.append(keyword_id)
abs_dir = os.path.abspath(args.wav_dir)
wav_path = os.path.join(args.out_dir, 'wav.scp')
text_path = os.path.join(args.out_dir, 'text')
with open(wav_path, 'w', encoding='utf-8') as f_wav, \
open(text_path, 'w', encoding='utf-8') as f_text:
for utt, l in zip(utt_id, label):
f_wav.write('{} {}\n'.format(utt,
os.path.join(abs_dir, utt + ".wav")))
f_text.write('{} {}\n'.format(utt, l))
if __name__ == "__main__":
main()

View File

@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2018-2020 Yiming Wang
# 2018-2020 Daniel Povey
# 2021 Binbin Zhang
# Menglong Xu
[ -f ./path.sh ] && . ./path.sh
dl_dir=data/download
. tools/parse_options.sh || exit 1;
mkdir -p $dl_dir
# Fill the following form:
# https://forms.gle/JtmFYM7xK1SaMfZYA
# to download the dataset
dataset=hey_snips_kws_4.0.tar.gz
src_path=$dl_dir
if [ -d $dl_dir/$(basename "$dataset" .tar.gz) ]; then
echo "Not extracting $(basename "$dataset" .tar.gz) as it is already there."
else
echo "Extracting $dataset..."
tar -xvzf $src_path/$dataset -C $dl_dir || exit 1;
echo "Done extracting $dataset."
fi
exit 0

5
examples/hey_snips/s0/path.sh Executable file
View File

@ -0,0 +1,5 @@
export PATH=$PWD:$PATH
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:$PYTHONPATH

125
examples/hey_snips/s0/run.sh Executable file
View File

@ -0,0 +1,125 @@
#!/bin/bash
# Copyright 2021 Binbin Zhang
# Menglong Xu
. ./path.sh
export CUDA_VISIBLE_DEVICES="0"
stage=0
stop_stage=4
num_keywords=1
config=conf/ds_tcn.yaml
norm_mean=true
norm_var=true
gpu_id=0
checkpoint=
dir=exp/ds_tcn
num_average=30
score_checkpoint=$dir/avg_${num_average}.pt
download_dir=./data/local # your data dir
. tools/parse_options.sh || exit 1;
set -euo pipefail
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Extracte all datasets"
local/snips_data_extract.sh --dl_dir $download_dir
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Preparing datasets..."
mkdir -p dict
echo "<filler> -1" > dict/words.txt
echo "Hey_Snips 0" >> dict/words.txt
for folder in train dev test; do
mkdir -p data/$folder
json_path=$download_dir/hey_snips_research_6k_en_train_eval_clean_ter/$folder.json
local/prepare_data.py $download_dir/hey_snips_research_6k_en_train_eval_clean_ter/audio_files $json_path \
data/$folder
done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data/train/wav.scp \
--out_cmvn data/train/global_cmvn
for x in train dev test; do
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
tools/make_list.py data/$x/wav.scp data/$x/text \
data/$x/wav.dur data/$x/data.list
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Start training ..."
mkdir -p $dir
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
python kws/bin/train.py --gpu $gpu_id \
--config $config \
--train_data data/train/data.list \
--cv_data data/dev/data.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 777 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# Do model average
python kws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
# Compute posterior score
result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir
python kws/bin/score.py --gpu $gpu_id \
--config $dir/config.yaml \
--test_data data/test/data.list \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# Compute detection error tradeoff
result_dir=$dir/test_$(basename $score_checkpoint)
first_keyword=0
last_keyword=$(($num_keywords+$first_keyword-1))
for keyword in $(seq $first_keyword $last_keyword); do
python kws/bin/compute_det.py \
--keyword $keyword \
--test_data data/test/data.list \
--score_file $result_dir/score.txt \
--stats_file $result_dir/stats.${keyword}.txt
done
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
python kws/bin/export_jit.py --config $dir/config.yaml \
--checkpoint $score_checkpoint \
--output_file $dir/final.zip \
--output_quant_file $dir/final.quant.zip
fi

1
examples/hey_snips/s0/tools Symbolic link
View File

@ -0,0 +1 @@
../../../tools