fix bug
This commit is contained in:
parent
8d6aee9307
commit
7936343720
@ -44,3 +44,4 @@ training_config:
|
|||||||
grad_clip: 5
|
grad_clip: 5
|
||||||
max_epoch: 100
|
max_epoch: 100
|
||||||
log_interval: 10
|
log_interval: 10
|
||||||
|
criterion: RHE
|
||||||
@ -22,3 +22,12 @@ class LastClassifier(nn.Module):
|
|||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
x = x[:, -1, :]
|
x = x[:, -1, :]
|
||||||
return self.classifier(x)
|
return self.classifier(x)
|
||||||
|
|
||||||
|
class ElementClassifier(nn.Module):
|
||||||
|
"""Classify all the frames in an utterance"""
|
||||||
|
def __init__(self, classifier: nn.Module):
|
||||||
|
super(ElementClassifier, self).__init__()
|
||||||
|
self.classifier = classifier
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return self.classifier(x)
|
||||||
@ -74,7 +74,7 @@ def downsample_training_sample_and_calculate_loss(logits,
|
|||||||
def max_pooling_RHE_binary_CE(logits,
|
def max_pooling_RHE_binary_CE(logits,
|
||||||
targets,
|
targets,
|
||||||
lengths,
|
lengths,
|
||||||
RHE_thr=10000,
|
RHE_thr=100,
|
||||||
max_ratio=1):
|
max_ratio=1):
|
||||||
"""Max-pooling loss with regional hard example mining
|
"""Max-pooling loss with regional hard example mining
|
||||||
For each keyword utterance, select the frame with the highest posterior.
|
For each keyword utterance, select the frame with the highest posterior.
|
||||||
|
|||||||
@ -18,11 +18,13 @@ import torch
|
|||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
|
||||||
from kws.model.max_pooling import max_pooling_loss
|
from kws.model.max_pooling import max_pooling_loss
|
||||||
|
from kws.model.max_pooling_RHE import max_pooling_RHE_binary_CE
|
||||||
from kws.model.ce import cross_entropy
|
from kws.model.ce import cross_entropy
|
||||||
|
|
||||||
|
|
||||||
criterion_dict = {'CE': cross_entropy,
|
criterion_dict = {'CE': cross_entropy,
|
||||||
'max_pooling': max_pooling_loss}
|
'max_pooling': max_pooling_loss,
|
||||||
|
'RHE': max_pooling_RHE_binary_CE}
|
||||||
|
|
||||||
class Executor:
|
class Executor:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user