QA run.sh, maxpooling training scripts is compatible. Ready to PR.
This commit is contained in:
parent
0ac7a417b1
commit
b1b64a4c6a
@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
. ./path.sh
|
. ./path.sh
|
||||||
|
|
||||||
stage=0
|
stage=$1
|
||||||
stop_stage=4
|
stop_stage=$2
|
||||||
num_keywords=2
|
num_keywords=2
|
||||||
|
|
||||||
config=conf/ds_tcn.yaml
|
config=conf/ds_tcn.yaml
|
||||||
@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
python wekws/bin/score.py \
|
python wekws/bin/score.py \
|
||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--test_data data/test/data.list \
|
--test_data data/test/data.list \
|
||||||
|
--gpu 0 \
|
||||||
--batch_size 256 \
|
--batch_size 256 \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--score_file $result_dir/score.txt \
|
--score_file $result_dir/score.txt \
|
||||||
|
|||||||
@ -41,7 +41,7 @@ def main():
|
|||||||
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
configs = yaml.load(fin, Loader=yaml.FullLoader)
|
||||||
feature_dim = configs['model']['input_dim']
|
feature_dim = configs['model']['input_dim']
|
||||||
model = init_model(configs['model'])
|
model = init_model(configs['model'])
|
||||||
if configs['training_config']['criterion'] == 'ctc':
|
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
|
||||||
# if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search
|
# if we use ctc_loss, the logits need to be convert into probs before ctc_prefix_beam_search
|
||||||
model.forward = model.forward_softmax
|
model.forward = model.forward_softmax
|
||||||
print(model)
|
print(model)
|
||||||
|
|||||||
@ -106,7 +106,7 @@ def main():
|
|||||||
score_abs_path = os.path.abspath(args.score_file)
|
score_abs_path = os.path.abspath(args.score_file)
|
||||||
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
|
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
|
||||||
for batch_idx, batch in enumerate(test_data_loader):
|
for batch_idx, batch in enumerate(test_data_loader):
|
||||||
keys, feats, target, lengths = batch
|
keys, feats, target, lengths, target_lengths = batch
|
||||||
feats = feats.to(device)
|
feats = feats.to(device)
|
||||||
lengths = lengths.to(device)
|
lengths = lengths.to(device)
|
||||||
logits, _ = model(feats)
|
logits, _ = model(feats)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user