fix bug
This commit is contained in:
parent
8d6aee9307
commit
7936343720
@ -44,3 +44,4 @@ training_config:
|
||||
grad_clip: 5
|
||||
max_epoch: 100
|
||||
log_interval: 10
|
||||
criterion: RHE
|
||||
@ -22,3 +22,12 @@ class LastClassifier(nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x[:, -1, :]
|
||||
return self.classifier(x)
|
||||
|
||||
class ElementClassifier(nn.Module):
|
||||
"""Classify all the frames in an utterance"""
|
||||
def __init__(self, classifier: nn.Module):
|
||||
super(ElementClassifier, self).__init__()
|
||||
self.classifier = classifier
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.classifier(x)
|
||||
@ -74,7 +74,7 @@ def downsample_training_sample_and_calculate_loss(logits,
|
||||
def max_pooling_RHE_binary_CE(logits,
|
||||
targets,
|
||||
lengths,
|
||||
RHE_thr=10000,
|
||||
RHE_thr=100,
|
||||
max_ratio=1):
|
||||
"""Max-pooling loss with regional hard example mining
|
||||
For each keyword utterance, select the frame with the highest posterior.
|
||||
|
||||
@ -18,11 +18,13 @@ import torch
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
from kws.model.max_pooling import max_pooling_loss
|
||||
from kws.model.max_pooling_RHE import max_pooling_RHE_binary_CE
|
||||
from kws.model.ce import cross_entropy
|
||||
|
||||
|
||||
criterion_dict = {'CE': cross_entropy,
|
||||
'max_pooling': max_pooling_loss}
|
||||
'max_pooling': max_pooling_loss,
|
||||
'RHE': max_pooling_RHE_binary_CE}
|
||||
|
||||
class Executor:
|
||||
def __init__(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user