From b1b64a4c6a9cd5dd462abcdf8eb2d89de8eb2849 Mon Sep 17 00:00:00 2001 From: dujing Date: Sat, 3 Jun 2023 14:50:08 +0800 Subject: [PATCH] QA run.sh, maxpooling training scripts is compatible. Ready to PR. --- examples/hi_xiaowen/s0/run.sh | 5 +++-- wekws/bin/export_onnx.py | 2 +- wekws/bin/score.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 49b96d6..f1ab06c 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -3,8 +3,8 @@ . ./path.sh -stage=0 -stop_stage=4 +stage=$1 +stop_stage=$2 num_keywords=2 config=conf/ds_tcn.yaml @@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then python wekws/bin/score.py \ --config $dir/config.yaml \ --test_data data/test/data.list \ + --gpu 0 \ --batch_size 256 \ --checkpoint $score_checkpoint \ --score_file $result_dir/score.txt \ diff --git a/wekws/bin/export_onnx.py b/wekws/bin/export_onnx.py index 4850d19..e45e4b2 100644 --- a/wekws/bin/export_onnx.py +++ b/wekws/bin/export_onnx.py @@ -41,7 +41,7 @@ def main(): configs = yaml.load(fin, Loader=yaml.FullLoader) feature_dim = configs['model']['input_dim'] model = init_model(configs['model']) - if configs['training_config']['criterion'] == 'ctc': + if configs['training_config'].get('criterion', 'max_pooling') == 'ctc': # if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search model.forward = model.forward_softmax print(model) diff --git a/wekws/bin/score.py b/wekws/bin/score.py index b2c0483..97e91e2 100644 --- a/wekws/bin/score.py +++ b/wekws/bin/score.py @@ -106,7 +106,7 @@ def main(): score_abs_path = os.path.abspath(args.score_file) with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout: for batch_idx, batch in enumerate(test_data_loader): - keys, feats, target, lengths = batch + keys, feats, target, lengths, target_lengths = batch feats = feats.to(device) lengths = lengths.to(device) logits, _ = model(feats)