QA run.sh, maxpooling training scripts is compatible. Ready to PR.

This commit is contained in:
dujing 2023-06-03 14:50:08 +08:00
parent 0ac7a417b1
commit b1b64a4c6a
3 changed files with 5 additions and 4 deletions

View File

@ -3,8 +3,8 @@
. ./path.sh
stage=0
stop_stage=4
stage=$1
stop_stage=$2
num_keywords=2
config=conf/ds_tcn.yaml
@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
python wekws/bin/score.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \

View File

@ -41,7 +41,7 @@ def main():
configs = yaml.load(fin, Loader=yaml.FullLoader)
feature_dim = configs['model']['input_dim']
model = init_model(configs['model'])
if configs['training_config']['criterion'] == 'ctc':
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
# if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search
model.forward = model.forward_softmax
print(model)

View File

@ -106,7 +106,7 @@ def main():
score_abs_path = os.path.abspath(args.score_file)
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, lengths = batch
keys, feats, target, lengths, target_lengths = batch
feats = feats.to(device)
lengths = lengths.to(device)
logits, _ = model(feats)