diff --git a/examples/hi_xiaowen/s0/README.md b/examples/hi_xiaowen/s0/README.md index 15084d5..b88ad49 100644 --- a/examples/hi_xiaowen/s0/README.md +++ b/examples/hi_xiaowen/s0/README.md @@ -8,3 +8,5 @@ FRRs with FAR fixed at once per hour: | DS_TCN | 21 | 80 | 0.010807 | 0.014754 | | DS_TCN | 21 | 80(avg30) | 0.009867 | 0.014472 | | DS_TCN(spec_aug) | 21 | 80(avg30) | 0.029039 | 0.022648 | +| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 | +| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 | diff --git a/examples/hi_xiaowen/s0/conf/mdtc_small.yaml b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml new file mode 100644 index 0000000..937f1ce --- /dev/null +++ b/examples/hi_xiaowen/s0/conf/mdtc_small.yaml @@ -0,0 +1,51 @@ +dataset_conf: + filter_conf: + max_length: 2048 + min_length: 0 + resample_conf: + resample_rate: 16000 + speed_perturb: false + feature_extraction_conf: + feature_type: 'mfcc' + num_ceps: 80 + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + feature_dither: 0.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 1 + num_f_mask: 1 + max_t: 20 + max_f: 40 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + batch_conf: + batch_size: 100 + +model: + hidden_dim: 32 + preprocessing: + type: none + backbone: + type: mdtc + num_stack: 3 + stack_size: 4 + kernel_size: 5 + hidden_dim: 32 + classifier: + type: linear + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 5e-5 + warm_up_step: 2500 + +training_config: + grad_clip: 5 + max_epoch: 100 + log_interval: 10 + criterion: max_pooling diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index 27987f8..fe4b782 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -9,13 +9,13 @@ stage=0 stop_stage=4 num_keywords=2 -config=conf/mdtc.yaml +config=conf/mdtc_small.yaml norm_mean=false norm_var=false gpu_id=0 checkpoint= -dir=exp/mdtc +dir=exp/mdtc_small num_average=10 score_checkpoint=$dir/avg_${num_average}.pt diff --git a/kws/model/loss.py b/kws/model/loss.py index bd34951..d73722c 100644 --- a/kws/model/loss.py +++ b/kws/model/loss.py @@ -21,7 +21,7 @@ def max_polling_loss(logits: torch.Tensor, target: torch.Tensor, lengths: torch.Tensor, min_duration: int = 0): - """ Max-pooling loss + ''' Max-pooling loss For keyword, select the frame with the highest posterior. The keyword is triggered when any of the frames is triggered. For none keyword, select the hardest frame, namely the frame @@ -36,7 +36,7 @@ def max_polling_loss(logits: torch.Tensor, Returns: (float): loss of current batch (float): accuracy of current batch - """ + ''' mask = padding_mask(lengths) num_utts = logits.size(0) num_keywords = logits.size(2)