[examples] update to use torchrun launch (#50)

This commit is contained in:
Menglong Xu 2021-12-15 21:03:59 +08:00 committed by GitHub
parent 566baca343
commit 6a58993390
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,8 +4,6 @@
. ./path.sh . ./path.sh
export CUDA_VISIBLE_DEVICES="0"
stage=0 stage=0
stop_stage=4 stop_stage=4
num_keywords=1 num_keywords=1
@ -13,7 +11,7 @@ num_keywords=1
config=conf/ds_tcn.yaml config=conf/ds_tcn.yaml
norm_mean=true norm_mean=true
norm_var=true norm_var=true
gpu_id=0 gpus="0"
checkpoint= checkpoint=
dir=exp/ds_tcn dir=exp/ds_tcn
@ -68,7 +66,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
cmvn_opts= cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn" $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var" $norm_var && cmvn_opts="$cmvn_opts --norm_var"
python kws/bin/train.py --gpu $gpu_id \ num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
kws/bin/train.py --gpus $gpus \
--config $config \ --config $config \
--train_data data/train/data.list \ --train_data data/train/data.list \
--cv_data data/dev/data.list \ --cv_data data/dev/data.list \
@ -82,28 +82,21 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# Do model average echo "Do model average, Compute FRR/FAR ..."
python kws/bin/average_model.py \ python kws/bin/average_model.py \
--dst_model $score_checkpoint \ --dst_model $score_checkpoint \
--src_path $dir \ --src_path $dir \
--num ${num_average} \ --num ${num_average} \
--val_best --val_best
# Compute posterior score
result_dir=$dir/test_$(basename $score_checkpoint) result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir mkdir -p $result_dir
python kws/bin/score.py --gpu $gpu_id \ python kws/bin/score.py \
--config $dir/config.yaml \ --config $dir/config.yaml \
--test_data data/test/data.list \ --test_data data/test/data.list \
--batch_size 256 \ --batch_size 256 \
--checkpoint $score_checkpoint \ --checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \ --score_file $result_dir/score.txt \
--num_workers 8 --num_workers 8
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# Compute detection error tradeoff
result_dir=$dir/test_$(basename $score_checkpoint)
first_keyword=0 first_keyword=0
last_keyword=$(($num_keywords+$first_keyword-1)) last_keyword=$(($num_keywords+$first_keyword-1))
for keyword in $(seq $first_keyword $last_keyword); do for keyword in $(seq $first_keyword $last_keyword); do