Merge pull request #9 from wenet-e2e/dev-jingyonghou
add results of mdtc
This commit is contained in:
commit
f7fd62db7d
@ -8,3 +8,5 @@ FRRs with FAR fixed at once per hour:
|
|||||||
| DS_TCN | 21 | 80 | 0.010807 | 0.014754 |
|
| DS_TCN | 21 | 80 | 0.010807 | 0.014754 |
|
||||||
| DS_TCN | 21 | 80(avg30) | 0.009867 | 0.014472 |
|
| DS_TCN | 21 | 80(avg30) | 0.009867 | 0.014472 |
|
||||||
| DS_TCN(spec_aug) | 21 | 80(avg30) | 0.029039 | 0.022648 |
|
| DS_TCN(spec_aug) | 21 | 80(avg30) | 0.029039 | 0.022648 |
|
||||||
|
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
|
||||||
|
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
|
||||||
|
|||||||
51
examples/hi_xiaowen/s0/conf/mdtc_small.yaml
Normal file
51
examples/hi_xiaowen/s0/conf/mdtc_small.yaml
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
dataset_conf:
|
||||||
|
filter_conf:
|
||||||
|
max_length: 2048
|
||||||
|
min_length: 0
|
||||||
|
resample_conf:
|
||||||
|
resample_rate: 16000
|
||||||
|
speed_perturb: false
|
||||||
|
feature_extraction_conf:
|
||||||
|
feature_type: 'mfcc'
|
||||||
|
num_ceps: 80
|
||||||
|
num_mel_bins: 80
|
||||||
|
frame_shift: 10
|
||||||
|
frame_length: 25
|
||||||
|
dither: 1.0
|
||||||
|
feature_dither: 0.0
|
||||||
|
spec_aug: true
|
||||||
|
spec_aug_conf:
|
||||||
|
num_t_mask: 1
|
||||||
|
num_f_mask: 1
|
||||||
|
max_t: 20
|
||||||
|
max_f: 40
|
||||||
|
shuffle: true
|
||||||
|
shuffle_conf:
|
||||||
|
shuffle_size: 1500
|
||||||
|
batch_conf:
|
||||||
|
batch_size: 100
|
||||||
|
|
||||||
|
model:
|
||||||
|
hidden_dim: 32
|
||||||
|
preprocessing:
|
||||||
|
type: none
|
||||||
|
backbone:
|
||||||
|
type: mdtc
|
||||||
|
num_stack: 3
|
||||||
|
stack_size: 4
|
||||||
|
kernel_size: 5
|
||||||
|
hidden_dim: 32
|
||||||
|
classifier:
|
||||||
|
type: linear
|
||||||
|
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.001
|
||||||
|
weight_decay: 5e-5
|
||||||
|
warm_up_step: 2500
|
||||||
|
|
||||||
|
training_config:
|
||||||
|
grad_clip: 5
|
||||||
|
max_epoch: 100
|
||||||
|
log_interval: 10
|
||||||
|
criterion: max_pooling
|
||||||
@ -9,13 +9,13 @@ stage=0
|
|||||||
stop_stage=4
|
stop_stage=4
|
||||||
num_keywords=2
|
num_keywords=2
|
||||||
|
|
||||||
config=conf/mdtc.yaml
|
config=conf/mdtc_small.yaml
|
||||||
norm_mean=false
|
norm_mean=false
|
||||||
norm_var=false
|
norm_var=false
|
||||||
gpu_id=0
|
gpu_id=0
|
||||||
|
|
||||||
checkpoint=
|
checkpoint=
|
||||||
dir=exp/mdtc
|
dir=exp/mdtc_small
|
||||||
|
|
||||||
num_average=10
|
num_average=10
|
||||||
score_checkpoint=$dir/avg_${num_average}.pt
|
score_checkpoint=$dir/avg_${num_average}.pt
|
||||||
|
|||||||
@ -21,7 +21,7 @@ def max_polling_loss(logits: torch.Tensor,
|
|||||||
target: torch.Tensor,
|
target: torch.Tensor,
|
||||||
lengths: torch.Tensor,
|
lengths: torch.Tensor,
|
||||||
min_duration: int = 0):
|
min_duration: int = 0):
|
||||||
""" Max-pooling loss
|
''' Max-pooling loss
|
||||||
For keyword, select the frame with the highest posterior.
|
For keyword, select the frame with the highest posterior.
|
||||||
The keyword is triggered when any of the frames is triggered.
|
The keyword is triggered when any of the frames is triggered.
|
||||||
For none keyword, select the hardest frame, namely the frame
|
For none keyword, select the hardest frame, namely the frame
|
||||||
@ -36,7 +36,7 @@ def max_polling_loss(logits: torch.Tensor,
|
|||||||
Returns:
|
Returns:
|
||||||
(float): loss of current batch
|
(float): loss of current batch
|
||||||
(float): accuracy of current batch
|
(float): accuracy of current batch
|
||||||
"""
|
'''
|
||||||
mask = padding_mask(lengths)
|
mask = padding_mask(lengths)
|
||||||
num_utts = logits.size(0)
|
num_utts = logits.size(0)
|
||||||
num_keywords = logits.size(2)
|
num_keywords = logits.size(2)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user