From edfc6de7435a1f7b6621ae0044ccf33cd23da395 Mon Sep 17 00:00:00 2001 From: jingyong hou Date: Fri, 19 Nov 2021 15:31:11 +0800 Subject: [PATCH] add results of mdtc --- examples/hi_xiaowen/s0/README.md | 2 + examples/hi_xiaowen/s0/conf/mdtc_small.yaml | 51 +++++++++++++++++++++ examples/hi_xiaowen/s0/run.sh | 4 +- kws/model/loss.py | 4 +- 4 files changed, 57 insertions(+), 4 deletions(-) create mode 100644 examples/hi_xiaowen/s0/conf/mdtc_small.yaml diff --git a/examples/hi_xiaowen/s0/README.md b/examples/hi_xiaowen/s0/README.md index 15084d5..b88ad49 100644 --- a/examples/hi_xiaowen/s0/README.md +++ b/examples/hi_xiaowen/s0/README.md @@ -8,3 +8,5 @@ FRRs with FAR fixed at once per hour: | DS_TCN | 21 | 80 | 0.010807 | 0.014754 | | DS_TCN | 21 | 80(avg30) | 0.009867 | 0.014472 | | DS_TCN(spec_aug) | 21 | 80(avg30) | 0.029039 | 0.022648 | +| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 | +| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 | diff --git a/examples/hi_xiaowen/s0/conf/mdtc_small.yaml b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml new file mode 100644 index 0000000..937f1ce --- /dev/null +++ b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml @@ -0,0 +1,51 @@ +dataset_conf: + filter_conf: + max_length: 2048 + min_length: 0 + resample_conf: + resample_rate: 16000 + speed_perturb: false + feature_extraction_conf: + feature_type: 'mfcc' + num_ceps: 80 + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + feature_dither: 0.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 1 + num_f_mask: 1 + max_t: 20 + max_f: 40 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + batch_conf: + batch_size: 100 + +model: + hidden_dim: 32 + preprocessing: + type: none + backbone: + type: mdtc + num_stack: 3 + stack_size: 4 + kernel_size: 5 + hidden_dim: 32 + classifier: + type: linear + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 5e-5 + warm_up_step: 2500 + +training_config: + grad_clip: 5 + max_epoch: 100 + log_interval: 10 + criterion: max_pooling diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 27987f8..fe4b782 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -9,13 +9,13 @@ stage=0 stop_stage=4 num_keywords=2 -config=conf/mdtc.yaml +config=conf/mdtc_small.yaml norm_mean=false norm_var=false gpu_id=0 checkpoint= -dir=exp/mdtc +dir=exp/mdtc_small num_average=10 score_checkpoint=$dir/avg_${num_average}.pt diff --git a/kws/model/loss.py b/kws/model/loss.py index bd34951..d73722c 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -21,7 +21,7 @@ def max_polling_loss(logits: torch.Tensor, target: torch.Tensor, lengths: torch.Tensor, min_duration: int = 0): - """ Max-pooling loss + ''' Max-pooling loss For keyword, select the frame with the highest posterior. The keyword is triggered when any of the frames is triggered. For none keyword, select the hardest frame, namely the frame @@ -36,7 +36,7 @@ def max_polling_loss(logits: torch.Tensor, Returns: (float): loss of current batch (float): accuracy of current batch - """ + ''' mask = padding_mask(lengths) num_utts = logits.size(0) num_keywords = logits.size(2)