[examples] update to use torchrun launch (#60)

This commit is contained in:
Menglong Xu 2022-02-11 14:51:00 +08:00 committed by GitHub
parent db2685d1a4
commit d805c55560
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,8 +4,6 @@
. ./path.sh . ./path.sh
export CUDA_VISIBLE_DEVICES="0"
stage=-1 stage=-1
stop_stage=4 stop_stage=4
num_keywords=11 num_keywords=11
@ -13,7 +11,7 @@ num_keywords=11
config=conf/mdtc.yaml config=conf/mdtc.yaml
norm_mean=false norm_mean=false
norm_var=false norm_var=false
gpu_id=0 gpus="0"
checkpoint= checkpoint=
dir=exp/mdtc dir=exp/mdtc
@ -22,7 +20,7 @@ num_average=10
score_checkpoint=$dir/avg_${num_average}.pt score_checkpoint=$dir/avg_${num_average}.pt
# your data dir # your data dir
download_dir=/mnt/mnt-data-3/jingyong.hou/data download_dir=./data/local
speech_command_dir=$download_dir/speech_commands_v1 speech_command_dir=$download_dir/speech_commands_v1
. tools/parse_options.sh || exit 1; . tools/parse_options.sh || exit 1;
@ -68,7 +66,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
cmvn_opts= cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn" $norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var" $norm_var && cmvn_opts="$cmvn_opts --norm_var"
python kws/bin/train.py --gpu $gpu_id \ num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
kws/bin/train.py --gpus $gpus \
--config $config \ --config $config \
--train_data data/train/data.list \ --train_data data/train/data.list \
--cv_data data/valid/data.list \ --cv_data data/valid/data.list \