Merge pull request #9 from wenet-e2e/dev-jingyonghou
add results of mdtc
This commit is contained in:
commit
f7fd62db7d
@ -8,3 +8,5 @@ FRRs with FAR fixed at once per hour:
|
||||
| DS_TCN | 21 | 80 | 0.010807 | 0.014754 |
|
||||
| DS_TCN | 21 | 80(avg30) | 0.009867 | 0.014472 |
|
||||
| DS_TCN(spec_aug) | 21 | 80(avg30) | 0.029039 | 0.022648 |
|
||||
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
||||
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
||||
|
||||
51
examples/hi_xiaowen/s0/conf/mdtc_small.yaml
Normal file
51
examples/hi_xiaowen/s0/conf/mdtc_small.yaml
Normal file
@ -0,0 +1,51 @@
|
||||
dataset_conf:
|
||||
filter_conf:
|
||||
max_length: 2048
|
||||
min_length: 0
|
||||
resample_conf:
|
||||
resample_rate: 16000
|
||||
speed_perturb: false
|
||||
feature_extraction_conf:
|
||||
feature_type: 'mfcc'
|
||||
num_ceps: 80
|
||||
num_mel_bins: 80
|
||||
frame_shift: 10
|
||||
frame_length: 25
|
||||
dither: 1.0
|
||||
feature_dither: 0.0
|
||||
spec_aug: true
|
||||
spec_aug_conf:
|
||||
num_t_mask: 1
|
||||
num_f_mask: 1
|
||||
max_t: 20
|
||||
max_f: 40
|
||||
shuffle: true
|
||||
shuffle_conf:
|
||||
shuffle_size: 1500
|
||||
batch_conf:
|
||||
batch_size: 100
|
||||
|
||||
model:
|
||||
hidden_dim: 32
|
||||
preprocessing:
|
||||
type: none
|
||||
backbone:
|
||||
type: mdtc
|
||||
num_stack: 3
|
||||
stack_size: 4
|
||||
kernel_size: 5
|
||||
hidden_dim: 32
|
||||
classifier:
|
||||
type: linear
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001
|
||||
weight_decay: 5e-5
|
||||
warm_up_step: 2500
|
||||
|
||||
training_config:
|
||||
grad_clip: 5
|
||||
max_epoch: 100
|
||||
log_interval: 10
|
||||
criterion: max_pooling
|
||||
@ -9,13 +9,13 @@ stage=0
|
||||
stop_stage=4
|
||||
num_keywords=2
|
||||
|
||||
config=conf/mdtc.yaml
|
||||
config=conf/mdtc_small.yaml
|
||||
norm_mean=false
|
||||
norm_var=false
|
||||
gpu_id=0
|
||||
|
||||
checkpoint=
|
||||
dir=exp/mdtc
|
||||
dir=exp/mdtc_small
|
||||
|
||||
num_average=10
|
||||
score_checkpoint=$dir/avg_${num_average}.pt
|
||||
|
||||
@ -21,7 +21,7 @@ def max_polling_loss(logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
min_duration: int = 0):
|
||||
""" Max-pooling loss
|
||||
''' Max-pooling loss
|
||||
For keyword, select the frame with the highest posterior.
|
||||
The keyword is triggered when any of the frames is triggered.
|
||||
For none keyword, select the hardest frame, namely the frame
|
||||
@ -36,7 +36,7 @@ def max_polling_loss(logits: torch.Tensor,
|
||||
Returns:
|
||||
(float): loss of current batch
|
||||
(float): accuracy of current batch
|
||||
"""
|
||||
'''
|
||||
mask = padding_mask(lengths)
|
||||
num_utts = logits.size(0)
|
||||
num_keywords = logits.size(2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user