Merge pull request #9 from wenet-e2e/dev-jingyonghou

Add results of MDTC
This commit is contained in:
Binbin Zhang 2021-11-19 17:01:23 +08:00 committed by GitHub
commit f7fd62db7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 4 deletions

View File

@ -8,3 +8,5 @@ FRRs with FAR fixed at once per hour:
| DS_TCN | 21 | 80 | 0.010807 | 0.014754 |
| DS_TCN | 21 | 80(avg30) | 0.009867 | 0.014472 |
| DS_TCN(spec_aug) | 21 | 80(avg30) | 0.029039 | 0.022648 |
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |

View File

@ -0,0 +1,51 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'mfcc'
num_ceps: 80
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
feature_dither: 0.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 40
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 100
model:
hidden_dim: 32
preprocessing:
type: none
backbone:
type: mdtc
num_stack: 3
stack_size: 4
kernel_size: 5
hidden_dim: 32
classifier:
type: linear
optim: adam
optim_conf:
lr: 0.001
weight_decay: 5e-5
warm_up_step: 2500
training_config:
grad_clip: 5
max_epoch: 100
log_interval: 10
criterion: max_pooling

View File

@ -9,13 +9,13 @@ stage=0
stop_stage=4
num_keywords=2
config=conf/mdtc.yaml
config=conf/mdtc_small.yaml
norm_mean=false
norm_var=false
gpu_id=0
checkpoint=
dir=exp/mdtc
dir=exp/mdtc_small
num_average=10
score_checkpoint=$dir/avg_${num_average}.pt

View File

@ -21,7 +21,7 @@ def max_polling_loss(logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
""" Max-pooling loss
''' Max-pooling loss
For keyword, select the frame with the highest posterior.
The keyword is triggered when any of the frames is triggered.
For non-keyword, select the hardest frame, namely the frame
@ -36,7 +36,7 @@ def max_polling_loss(logits: torch.Tensor,
Returns:
(float): loss of current batch
(float): accuracy of current batch
"""
'''
mask = padding_mask(lengths)
num_utts = logits.size(0)
num_keywords = logits.size(2)