[example] support hey_snips_kws_4.0 dataset (#38)
* [example] support hey_snips_kws_4.0 dataset
* format
* format
parent 1eda27647b
commit 4a875776e5
6
examples/hey_snips/s0/README.md
Normal file
@@ -0,0 +1,6 @@
FRRs with FAR fixed at once per hour:

| model      | params(K) | epoch      | hey_snips |
|------------|-----------|------------|-----------|
| DS_TCN     | 21        | 80(avg30)  | 0.019771  |
| MDTC_Small | 31        | 100(avg10) | 0.008699  |
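The figures above are false reject rates (FRR) read off the detection error tradeoff at the operating point where the false alarm rate (FAR) equals one alarm per hour of negative audio. A minimal sketch of that lookup, assuming per-threshold false-alarm and false-reject counts are already available (the function and variable names below are illustrative, not part of the recipe):

# Illustrative only: report the FRR at the lowest threshold whose FAR does not
# exceed the target of one false alarm per hour of negative audio.
def frr_at_far(stats, num_positives, negative_hours, target_far_per_hour=1.0):
    """stats: iterable of (threshold, false_alarms, false_rejects), ascending threshold."""
    for threshold, false_alarms, false_rejects in stats:
        far = false_alarms / negative_hours   # alarms per hour of negative audio
        frr = false_rejects / num_positives   # fraction of missed keyword utterances
        if far <= target_far_per_hour:
            return threshold, frr
    return None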
45
examples/hey_snips/s0/conf/ds_tcn.yaml
Normal file
@@ -0,0 +1,45 @@
dataset_conf:
    filter_conf:
        max_length: 2048
        min_length: 0
    resample_conf:
        resample_rate: 16000
    speed_perturb: false
    feature_extraction_conf:
        feature_type: 'fbank'
        num_mel_bins: 40
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 1
        num_f_mask: 1
        max_t: 50
        max_f: 30
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    batch_conf:
        batch_size: 256

model:
    hidden_dim: 64
    preprocessing:
        type: linear
    backbone:
        type: tcn
        ds: true
        num_layers: 4
        kernel_size: 8
        dropout: 0.1

optim: adam
optim_conf:
    lr: 0.001

training_config:
    grad_clip: 5
    max_epoch: 80
    log_interval: 10
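The feature_extraction_conf above describes 40-dimensional Kaldi-style fbank features with a 25 ms window and 10 ms shift. A minimal sketch of extracting matching features for a single wav with torchaudio, as an illustration of the configured front end rather than the toolkit's own dataset pipeline (the wav path is a placeholder):

import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, sample_rate = torchaudio.load('example.wav')  # placeholder path
waveform = waveform * (1 << 15)                          # Kaldi-style 16-bit sample range
feats = kaldi.fbank(waveform,
                    num_mel_bins=40,             # num_mel_bins
                    frame_length=25,             # frame_length (ms)
                    frame_shift=10,              # frame_shift (ms)
                    dither=1.0,                  # dither
                    sample_frequency=sample_rate)
print(feats.shape)  # (num_frames, 40)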
50
examples/hey_snips/s0/conf/mdtc_small.yaml
Normal file
@@ -0,0 +1,50 @@
dataset_conf:
    filter_conf:
        max_length: 2048
        min_length: 0
    resample_conf:
        resample_rate: 16000
    speed_perturb: false
    feature_extraction_conf:
        feature_type: 'mfcc'
        num_ceps: 80
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 1.0
    feature_dither: 0.0
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 1
        num_f_mask: 1
        max_t: 20
        max_f: 40
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    batch_conf:
        batch_size: 100

model:
    hidden_dim: 32
    preprocessing:
        type: none
    backbone:
        type: mdtc
        num_stack: 3
        stack_size: 4
        kernel_size: 5
        hidden_dim: 32
        causal: True


optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 0.00005

training_config:
    grad_clip: 5
    max_epoch: 100
    log_interval: 10
    criterion: max_pooling
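As a rough sanity check on the MDTC settings above, the temporal context of a stack of dilated causal convolutions can be estimated from num_stack, stack_size, and kernel_size. The sketch below assumes dilations double within each stack (1, 2, 4, 8) and that each layer adds (kernel_size - 1) * dilation frames of left context; the actual MDTC implementation in kws/ may differ in detail:

# Values from mdtc_small.yaml; the dilation schedule is an assumption.
num_stack, stack_size, kernel_size, frame_shift_ms = 3, 4, 5, 10

dilations = [2 ** i for i in range(stack_size)] * num_stack   # assumed: 1, 2, 4, 8 per stack
receptive_field = 1 + sum((kernel_size - 1) * d for d in dilations)
print(receptive_field)                           # 181 frames under these assumptions
print(receptive_field * frame_shift_ms / 1000)   # roughly 1.81 s of left context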
1
examples/hey_snips/s0/kws
Symbolic link
@@ -0,0 +1 @@
../../../kws
46
examples/hey_snips/s0/local/prepare_data.py
Executable file
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

# Copyright 2018-2020  Yiming Wang
#           2018-2020  Daniel Povey
#                2021  Binbin Zhang
#                      Menglong Xu
# Apache 2.0
""" This script prepares the snips data into kaldi format.
"""

import argparse
import os
import json


def main():
    parser = argparse.ArgumentParser(description="""Prepare data.""")
    parser.add_argument('wav_dir',
                        type=str,
                        help='dir containing all the wav files')
    parser.add_argument('path', type=str, help='path to the json file')
    parser.add_argument('out_dir', type=str, help='out dir')
    args = parser.parse_args()

    with open(args.path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    utt_id, label = [], []
    for entry in data:
        if entry['duration'] > 0:
            utt_id.append(entry['id'])
            keyword_id = 0 if entry['is_hotword'] == 1 else -1
            label.append(keyword_id)

    abs_dir = os.path.abspath(args.wav_dir)
    wav_path = os.path.join(args.out_dir, 'wav.scp')
    text_path = os.path.join(args.out_dir, 'text')
    with open(wav_path, 'w', encoding='utf-8') as f_wav, \
            open(text_path, 'w', encoding='utf-8') as f_text:
        for utt, l in zip(utt_id, label):
            f_wav.write('{} {}\n'.format(utt,
                                         os.path.join(abs_dir, utt + ".wav")))
            f_text.write('{} {}\n'.format(utt, l))


if __name__ == "__main__":
    main()
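For reference, prepare_data.py expects the hey_snips metadata JSON to be a list of entries carrying at least 'id', 'is_hotword', and 'duration', and writes one wav.scp and one text line per kept utterance. A hypothetical entry and the lines it would produce (the utterance id and paths are made up for illustration):

# Hypothetical entry from train.json (field names as read by prepare_data.py):
entry = {"id": "utt_0001", "is_hotword": 1, "duration": 1.2}

# With wav_dir=/data/audio_files and out_dir=data/train, the script would write:
#   data/train/wav.scp : utt_0001 /data/audio_files/utt_0001.wav
#   data/train/text    : utt_0001 0        (0 = keyword index, -1 = filler)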
30
examples/hey_snips/s0/local/snips_data_extract.sh
Executable file
@@ -0,0 +1,30 @@
#!/bin/bash

# Copyright 2018-2020  Yiming Wang
#           2018-2020  Daniel Povey
#                2021  Binbin Zhang
#                      Menglong Xu

[ -f ./path.sh ] && . ./path.sh

dl_dir=data/download

. tools/parse_options.sh || exit 1;

mkdir -p $dl_dir

# Fill the following form:
# https://forms.gle/JtmFYM7xK1SaMfZYA
# to download the dataset
dataset=hey_snips_kws_4.0.tar.gz
src_path=$dl_dir

if [ -d $dl_dir/$(basename "$dataset" .tar.gz) ]; then
  echo "Not extracting $(basename "$dataset" .tar.gz) as it is already there."
else
  echo "Extracting $dataset..."
  tar -xvzf $src_path/$dataset -C $dl_dir || exit 1;
  echo "Done extracting $dataset."
fi

exit 0
5
examples/hey_snips/s0/path.sh
Executable file
@@ -0,0 +1,5 @@
export PATH=$PWD:$PATH

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=../../../:$PYTHONPATH
125
examples/hey_snips/s0/run.sh
Executable file
@@ -0,0 +1,125 @@
#!/bin/bash
# Copyright 2021  Binbin Zhang
#                 Menglong Xu

. ./path.sh

export CUDA_VISIBLE_DEVICES="0"

stage=0
stop_stage=4
num_keywords=1

config=conf/ds_tcn.yaml
norm_mean=true
norm_var=true
gpu_id=0

checkpoint=
dir=exp/ds_tcn

num_average=30
score_checkpoint=$dir/avg_${num_average}.pt

download_dir=./data/local # your data dir

. tools/parse_options.sh || exit 1;

set -euo pipefail

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "Extract all datasets"
  local/snips_data_extract.sh --dl_dir $download_dir
fi


if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "Preparing datasets..."
  mkdir -p dict
  echo "<filler> -1" > dict/words.txt
  echo "Hey_Snips 0" >> dict/words.txt

  for folder in train dev test; do
    mkdir -p data/$folder
    json_path=$download_dir/hey_snips_research_6k_en_train_eval_clean_ter/$folder.json
    local/prepare_data.py $download_dir/hey_snips_research_6k_en_train_eval_clean_ter/audio_files $json_path \
      data/$folder
  done
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Compute CMVN and Format datasets"
  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
    --in_scp data/train/wav.scp \
    --out_cmvn data/train/global_cmvn

  for x in train dev test; do
    tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
    tools/make_list.py data/$x/wav.scp data/$x/text \
      data/$x/wav.dur data/$x/data.list
  done
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Start training ..."
  mkdir -p $dir
  cmvn_opts=
  $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
  python kws/bin/train.py --gpu $gpu_id \
    --config $config \
    --train_data data/train/data.list \
    --cv_data data/dev/data.list \
    --model_dir $dir \
    --num_workers 8 \
    --num_keywords $num_keywords \
    --min_duration 50 \
    --seed 777 \
    $cmvn_opts \
    ${checkpoint:+--checkpoint $checkpoint}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  # Do model average
  python kws/bin/average_model.py \
    --dst_model $score_checkpoint \
    --src_path $dir \
    --num ${num_average} \
    --val_best

  # Compute posterior score
  result_dir=$dir/test_$(basename $score_checkpoint)
  mkdir -p $result_dir
  python kws/bin/score.py --gpu $gpu_id \
    --config $dir/config.yaml \
    --test_data data/test/data.list \
    --batch_size 256 \
    --checkpoint $score_checkpoint \
    --score_file $result_dir/score.txt \
    --num_workers 8
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  # Compute detection error tradeoff
  result_dir=$dir/test_$(basename $score_checkpoint)
  first_keyword=0
  last_keyword=$(($num_keywords+$first_keyword-1))
  for keyword in $(seq $first_keyword $last_keyword); do
    python kws/bin/compute_det.py \
      --keyword $keyword \
      --test_data data/test/data.list \
      --score_file $result_dir/score.txt \
      --stats_file $result_dir/stats.${keyword}.txt
  done
fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  python kws/bin/export_jit.py --config $dir/config.yaml \
    --checkpoint $score_checkpoint \
    --output_file $dir/final.zip \
    --output_quant_file $dir/final.quant.zip
fi
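Stage 3 averages the parameters of the best num_average checkpoints before scoring. A minimal sketch of what such averaging amounts to, assuming plain PyTorch state_dicts on disk; this is only the general idea, not the kws/bin/average_model.py implementation:

import torch

def average_checkpoints(paths):
    # Load every state_dict onto CPU and average each parameter tensor elementwise.
    states = [torch.load(p, map_location='cpu') for p in paths]
    avg = {}
    for key in states[0]:
        avg[key] = sum(s[key].float() for s in states) / len(states)
    return avg

# e.g. torch.save(average_checkpoints(['exp/ds_tcn/78.pt', 'exp/ds_tcn/79.pt']), 'avg_2.pt')
# (checkpoint file names above are illustrative)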
1
examples/hey_snips/s0/tools
Symbolic link
@@ -0,0 +1 @@
../../../tools