[example] support hey_snips_kws_4.0 dataset (#38)
* [example] support hey_snips_kws_4.0 dataset * format * format
This commit is contained in:
parent
1eda27647b
commit
4a875776e5
6
examples/hey_snips/s0/README.md
Normal file
6
examples/hey_snips/s0/README.md
Normal file
@ -0,0 +1,6 @@
|
||||
FRRs with FAR fixed at once per hour:
|
||||
|
||||
| model | params(K) | epoch | hey_snips |
|
||||
|------------------|-----------|------------|------------|
|
||||
| DS_TCN | 21 | 80(avg30) | 0.019771 |
|
||||
| MDTC_Small | 31 | 100(avg10) | 0.008699 |
|
||||
45
examples/hey_snips/s0/conf/ds_tcn.yaml
Normal file
45
examples/hey_snips/s0/conf/ds_tcn.yaml
Normal file
@ -0,0 +1,45 @@
|
||||
dataset_conf:
|
||||
filter_conf:
|
||||
max_length: 2048
|
||||
min_length: 0
|
||||
resample_conf:
|
||||
resample_rate: 16000
|
||||
speed_perturb: false
|
||||
feature_extraction_conf:
|
||||
feature_type: 'fbank'
|
||||
num_mel_bins: 40
|
||||
frame_shift: 10
|
||||
frame_length: 25
|
||||
dither: 1.0
|
||||
spec_aug: true
|
||||
spec_aug_conf:
|
||||
num_t_mask: 1
|
||||
num_f_mask: 1
|
||||
max_t: 50
|
||||
max_f: 30
|
||||
shuffle: true
|
||||
shuffle_conf:
|
||||
shuffle_size: 1500
|
||||
batch_conf:
|
||||
batch_size: 256
|
||||
|
||||
model:
|
||||
hidden_dim: 64
|
||||
preprocessing:
|
||||
type: linear
|
||||
backbone:
|
||||
type: tcn
|
||||
ds: true
|
||||
num_layers: 4
|
||||
kernel_size: 8
|
||||
dropout: 0.1
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
|
||||
training_config:
|
||||
grad_clip: 5
|
||||
max_epoch: 80
|
||||
log_interval: 10
|
||||
|
||||
50
examples/hey_snips/s0/conf/mdtc_small.yaml
Normal file
50
examples/hey_snips/s0/conf/mdtc_small.yaml
Normal file
@ -0,0 +1,50 @@
|
||||
dataset_conf:
|
||||
filter_conf:
|
||||
max_length: 2048
|
||||
min_length: 0
|
||||
resample_conf:
|
||||
resample_rate: 16000
|
||||
speed_perturb: false
|
||||
feature_extraction_conf:
|
||||
feature_type: 'mfcc'
|
||||
num_ceps: 80
|
||||
num_mel_bins: 80
|
||||
frame_shift: 10
|
||||
frame_length: 25
|
||||
dither: 1.0
|
||||
feature_dither: 0.0
|
||||
spec_aug: true
|
||||
spec_aug_conf:
|
||||
num_t_mask: 1
|
||||
num_f_mask: 1
|
||||
max_t: 20
|
||||
max_f: 40
|
||||
shuffle: true
|
||||
shuffle_conf:
|
||||
shuffle_size: 1500
|
||||
batch_conf:
|
||||
batch_size: 100
|
||||
|
||||
model:
|
||||
hidden_dim: 32
|
||||
preprocessing:
|
||||
type: none
|
||||
backbone:
|
||||
type: mdtc
|
||||
num_stack: 3
|
||||
stack_size: 4
|
||||
kernel_size: 5
|
||||
hidden_dim: 32
|
||||
causal: True
|
||||
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
weight_decay: 0.00005
|
||||
|
||||
training_config:
|
||||
grad_clip: 5
|
||||
max_epoch: 100
|
||||
log_interval: 10
|
||||
criterion: max_pooling
|
||||
1
examples/hey_snips/s0/kws
Symbolic link
1
examples/hey_snips/s0/kws
Symbolic link
@ -0,0 +1 @@
|
||||
../../../kws
|
||||
46
examples/hey_snips/s0/local/prepare_data.py
Executable file
46
examples/hey_snips/s0/local/prepare_data.py
Executable file
@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2018-2020 Yiming Wang
|
||||
# 2018-2020 Daniel Povey
|
||||
# 2021 Binbin Zhang
|
||||
# Menglong Xu
|
||||
# Apache 2.0
|
||||
""" This script prepares the snips data into kaldi format.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="""Prepare data.""")
|
||||
parser.add_argument('wav_dir',
|
||||
type=str,
|
||||
help='dir containing all the wav files')
|
||||
parser.add_argument('path', type=str, help='path to the json file')
|
||||
parser.add_argument('out_dir', type=str, help='out dir')
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
utt_id, label = [], []
|
||||
for entry in data:
|
||||
if entry['duration'] > 0:
|
||||
utt_id.append(entry['id'])
|
||||
keyword_id = 0 if entry['is_hotword'] == 1 else -1
|
||||
label.append(keyword_id)
|
||||
|
||||
abs_dir = os.path.abspath(args.wav_dir)
|
||||
wav_path = os.path.join(args.out_dir, 'wav.scp')
|
||||
text_path = os.path.join(args.out_dir, 'text')
|
||||
with open(wav_path, 'w', encoding='utf-8') as f_wav, \
|
||||
open(text_path, 'w', encoding='utf-8') as f_text:
|
||||
for utt, l in zip(utt_id, label):
|
||||
f_wav.write('{} {}\n'.format(utt,
|
||||
os.path.join(abs_dir, utt + ".wav")))
|
||||
f_text.write('{} {}\n'.format(utt, l))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
30
examples/hey_snips/s0/local/snips_data_extract.sh
Executable file
30
examples/hey_snips/s0/local/snips_data_extract.sh
Executable file
@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright 2018-2020 Yiming Wang
|
||||
# 2018-2020 Daniel Povey
|
||||
# 2021 Binbin Zhang
|
||||
# Menglong Xu
|
||||
|
||||
[ -f ./path.sh ] && . ./path.sh
|
||||
|
||||
dl_dir=data/download
|
||||
|
||||
. tools/parse_options.sh || exit 1;
|
||||
|
||||
mkdir -p $dl_dir
|
||||
|
||||
# Fill the following form:
|
||||
# https://forms.gle/JtmFYM7xK1SaMfZYA
|
||||
# to download the dataset
|
||||
dataset=hey_snips_kws_4.0.tar.gz
|
||||
src_path=$dl_dir
|
||||
|
||||
if [ -d $dl_dir/$(basename "$dataset" .tar.gz) ]; then
|
||||
echo "Not extracting $(basename "$dataset" .tar.gz) as it is already there."
|
||||
else
|
||||
echo "Extracting $dataset..."
|
||||
tar -xvzf $src_path/$dataset -C $dl_dir || exit 1;
|
||||
echo "Done extracting $dataset."
|
||||
fi
|
||||
|
||||
exit 0
|
||||
5
examples/hey_snips/s0/path.sh
Executable file
5
examples/hey_snips/s0/path.sh
Executable file
@ -0,0 +1,5 @@
|
||||
export PATH=$PWD:$PATH
|
||||
|
||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=../../../:$PYTHONPATH
|
||||
125
examples/hey_snips/s0/run.sh
Executable file
125
examples/hey_snips/s0/run.sh
Executable file
@ -0,0 +1,125 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Binbin Zhang
|
||||
# Menglong Xu
|
||||
|
||||
. ./path.sh
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
stage=0
|
||||
stop_stage=4
|
||||
num_keywords=1
|
||||
|
||||
config=conf/ds_tcn.yaml
|
||||
norm_mean=true
|
||||
norm_var=true
|
||||
gpu_id=0
|
||||
|
||||
checkpoint=
|
||||
dir=exp/ds_tcn
|
||||
|
||||
num_average=30
|
||||
score_checkpoint=$dir/avg_${num_average}.pt
|
||||
|
||||
download_dir=./data/local # your data dir
|
||||
|
||||
. tools/parse_options.sh || exit 1;
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||
echo "Extracte all datasets"
|
||||
local/snips_data_extract.sh --dl_dir $download_dir
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||
echo "Preparing datasets..."
|
||||
mkdir -p dict
|
||||
echo "<filler> -1" > dict/words.txt
|
||||
echo "Hey_Snips 0" >> dict/words.txt
|
||||
|
||||
for folder in train dev test; do
|
||||
mkdir -p data/$folder
|
||||
json_path=$download_dir/hey_snips_research_6k_en_train_eval_clean_ter/$folder.json
|
||||
local/prepare_data.py $download_dir/hey_snips_research_6k_en_train_eval_clean_ter/audio_files $json_path \
|
||||
data/$folder
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||
echo "Compute CMVN and Format datasets"
|
||||
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
|
||||
--in_scp data/train/wav.scp \
|
||||
--out_cmvn data/train/global_cmvn
|
||||
|
||||
for x in train dev test; do
|
||||
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
|
||||
tools/make_list.py data/$x/wav.scp data/$x/text \
|
||||
data/$x/wav.dur data/$x/data.list
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
echo "Start training ..."
|
||||
mkdir -p $dir
|
||||
cmvn_opts=
|
||||
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
|
||||
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
|
||||
python kws/bin/train.py --gpu $gpu_id \
|
||||
--config $config \
|
||||
--train_data data/train/data.list \
|
||||
--cv_data data/dev/data.list \
|
||||
--model_dir $dir \
|
||||
--num_workers 8 \
|
||||
--num_keywords $num_keywords \
|
||||
--min_duration 50 \
|
||||
--seed 777 \
|
||||
$cmvn_opts \
|
||||
${checkpoint:+--checkpoint $checkpoint}
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
# Do model average
|
||||
python kws/bin/average_model.py \
|
||||
--dst_model $score_checkpoint \
|
||||
--src_path $dir \
|
||||
--num ${num_average} \
|
||||
--val_best
|
||||
|
||||
# Compute posterior score
|
||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||
mkdir -p $result_dir
|
||||
python kws/bin/score.py --gpu $gpu_id \
|
||||
--config $dir/config.yaml \
|
||||
--test_data data/test/data.list \
|
||||
--batch_size 256 \
|
||||
--checkpoint $score_checkpoint \
|
||||
--score_file $result_dir/score.txt \
|
||||
--num_workers 8
|
||||
fi
|
||||
|
||||
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
# Compute detection error tradeoff
|
||||
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||
first_keyword=0
|
||||
last_keyword=$(($num_keywords+$first_keyword-1))
|
||||
for keyword in $(seq $first_keyword $last_keyword); do
|
||||
python kws/bin/compute_det.py \
|
||||
--keyword $keyword \
|
||||
--test_data data/test/data.list \
|
||||
--score_file $result_dir/score.txt \
|
||||
--stats_file $result_dir/stats.${keyword}.txt
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
python kws/bin/export_jit.py --config $dir/config.yaml \
|
||||
--checkpoint $score_checkpoint \
|
||||
--output_file $dir/final.zip \
|
||||
--output_quant_file $dir/final.quant.zip
|
||||
fi
|
||||
|
||||
1
examples/hey_snips/s0/tools
Symbolic link
1
examples/hey_snips/s0/tools
Symbolic link
@ -0,0 +1 @@
|
||||
../../../tools
|
||||
Loading…
x
Reference in New Issue
Block a user