This commit is contained in:
Jingyong Hou 2021-11-16 10:10:08 +08:00
parent 8d6aee9307
commit 7936343720
4 changed files with 14 additions and 2 deletions

View File

@ -44,3 +44,4 @@ training_config:
grad_clip: 5 grad_clip: 5
max_epoch: 100 max_epoch: 100
log_interval: 10 log_interval: 10
criterion: RHE

View File

@ -22,3 +22,12 @@ class LastClassifier(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
x = x[:, -1, :] x = x[:, -1, :]
return self.classifier(x) return self.classifier(x)
class ElementClassifier(nn.Module):
"""Classify all the frames in an utterance"""
def __init__(self, classifier: nn.Module):
super(ElementClassifier, self).__init__()
self.classifier = classifier
def forward(self, x: torch.Tensor):
return self.classifier(x)

View File

@ -74,7 +74,7 @@ def downsample_training_sample_and_calculate_loss(logits,
def max_pooling_RHE_binary_CE(logits, def max_pooling_RHE_binary_CE(logits,
targets, targets,
lengths, lengths,
RHE_thr=10000, RHE_thr=100,
max_ratio=1): max_ratio=1):
"""Max-pooling loss with regional hard example mining """Max-pooling loss with regional hard example mining
For each keyword utterance, select the frame with the highest posterior. For each keyword utterance, select the frame with the highest posterior.

View File

@ -18,11 +18,13 @@ import torch
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from kws.model.max_pooling import max_pooling_loss from kws.model.max_pooling import max_pooling_loss
from kws.model.max_pooling_RHE import max_pooling_RHE_binary_CE
from kws.model.ce import cross_entropy from kws.model.ce import cross_entropy
criterion_dict = {'CE': cross_entropy, criterion_dict = {'CE': cross_entropy,
'max_pooling': max_pooling_loss} 'max_pooling': max_pooling_loss,
'RHE': max_pooling_RHE_binary_CE}
class Executor: class Executor:
def __init__(self): def __init__(self):