[examples] update to use torchrun launch (#60)
This commit is contained in:
parent
db2685d1a4
commit
d805c55560
@ -4,8 +4,6 @@
|
||||
|
||||
. ./path.sh
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
stage=-1
|
||||
stop_stage=4
|
||||
num_keywords=11
|
||||
@ -13,7 +11,7 @@ num_keywords=11
|
||||
config=conf/mdtc.yaml
|
||||
norm_mean=false
|
||||
norm_var=false
|
||||
gpu_id=0
|
||||
gpus="0"
|
||||
|
||||
checkpoint=
|
||||
dir=exp/mdtc
|
||||
@ -22,7 +20,7 @@ num_average=10
|
||||
score_checkpoint=$dir/avg_${num_average}.pt
|
||||
|
||||
# your data dir
|
||||
download_dir=/mnt/mnt-data-3/jingyong.hou/data
|
||||
download_dir=./data/local
|
||||
speech_command_dir=$download_dir/speech_commands_v1
|
||||
. tools/parse_options.sh || exit 1;
|
||||
|
||||
@ -68,7 +66,9 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
cmvn_opts=
|
||||
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
|
||||
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
|
||||
python kws/bin/train.py --gpu $gpu_id \
|
||||
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
|
||||
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
|
||||
kws/bin/train.py --gpus $gpus \
|
||||
--config $config \
|
||||
--train_data data/train/data.list \
|
||||
--cv_data data/valid/data.list \
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user