This commit is contained in:
Jingyong Hou 2021-11-16 10:10:08 +08:00
parent 8d6aee9307
commit 7936343720
4 changed files with 14 additions and 2 deletions

View File

@ -44,3 +44,4 @@ training_config:
grad_clip: 5
max_epoch: 100
log_interval: 10
criterion: RHE

View File

@ -22,3 +22,12 @@ class LastClassifier(nn.Module):
def forward(self, x: torch.Tensor):
x = x[:, -1, :]
return self.classifier(x)
class ElementClassifier(nn.Module):
"""Classify all the frames in an utterance"""
def __init__(self, classifier: nn.Module):
super(ElementClassifier, self).__init__()
self.classifier = classifier
def forward(self, x: torch.Tensor):
return self.classifier(x)

View File

@ -74,7 +74,7 @@ def downsample_training_sample_and_calculate_loss(logits,
def max_pooling_RHE_binary_CE(logits,
targets,
lengths,
RHE_thr=10000,
RHE_thr=100,
max_ratio=1):
"""Max-pooling loss with regional hard example mining
For each keyword utterance, select the frame with the highest posterior.

View File

@ -18,11 +18,13 @@ import torch
from torch.nn.utils import clip_grad_norm_
from kws.model.max_pooling import max_pooling_loss
from kws.model.max_pooling_RHE import max_pooling_RHE_binary_CE
from kws.model.ce import cross_entropy
criterion_dict = {'CE': cross_entropy,
'max_pooling': max_pooling_loss}
'max_pooling': max_pooling_loss,
'RHE': max_pooling_RHE_binary_CE}
class Executor:
def __init__(self):