QA run.sh, maxpooling training scripts is compatible. Ready to PR.
This commit is contained in:
parent
0ac7a417b1
commit
b1b64a4c6a
@ -3,8 +3,8 @@
|
||||
|
||||
. ./path.sh
|
||||
|
||||
stage=0
|
||||
stop_stage=4
|
||||
stage=$1
|
||||
stop_stage=$2
|
||||
num_keywords=2
|
||||
|
||||
config=conf/ds_tcn.yaml
|
||||
@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
python wekws/bin/score.py \
|
||||
--config $dir/config.yaml \
|
||||
--test_data data/test/data.list \
|
||||
--gpu 0 \
|
||||
--batch_size 256 \
|
||||
--checkpoint $score_checkpoint \
|
||||
--score_file $result_dir/score.txt \
|
||||
|
||||
@ -41,7 +41,7 @@ def main():
|
||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||
feature_dim = configs['model']['input_dim']
|
||||
model = init_model(configs['model'])
|
||||
if configs['training_config']['criterion'] == 'ctc':
|
||||
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
|
||||
# if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search
|
||||
model.forward = model.forward_softmax
|
||||
print(model)
|
||||
|
||||
@ -106,7 +106,7 @@ def main():
|
||||
score_abs_path = os.path.abspath(args.score_file)
|
||||
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
|
||||
for batch_idx, batch in enumerate(test_data_loader):
|
||||
keys, feats, target, lengths = batch
|
||||
keys, feats, target, lengths, target_lengths = batch
|
||||
feats = feats.to(device)
|
||||
lengths = lengths.to(device)
|
||||
logits, _ = model(feats)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user