# Patch summary (free text before the first "diff --git" header):
#
#   * examples/hi_xiaowen/s0/conf/mdtc.yaml
#       Adds "criterion: RHE" at the end of the training_config section.
#   * kws/model/classifier.py
#       Adds ElementClassifier, an nn.Module that stores the wrapped
#       classifier and whose forward() passes x to it unchanged (the
#       LastClassifier.forward shown as context slices x[:, -1, :] first).
#   * kws/model/max_pooling_RHE.py
#       Changes the default of RHE_thr in max_pooling_RHE_binary_CE from
#       10000 to 100.
#   * kws/utils/executor.py
#       Imports max_pooling_RHE_binary_CE and registers it in criterion_dict
#       under the key 'RHE'.
#
# NOTE(review): the 'RHE' key presumably is what "criterion: RHE" in the YAML
#   selects; the lookup code is not part of this patch -- confirm.
# NOTE(review): mdtc.yaml and classifier.py both end with
#   "\ No newline at end of file"; consider adding a trailing newline.
# NOTE(review): only one blank line is added before "class ElementClassifier";
#   PEP 8 asks for two before a top-level class.
# NOTE(review): RHE_thr default drops by two orders of magnitude for every
#   caller relying on the default; the rationale is not shown here -- confirm.
# NOTE(review): line breaks and indentation inside the hunks were
#   reconstructed from a whitespace-collapsed copy; verify the context lines
#   against the target files before applying.
diff --git a/examples/hi_xiaowen/s0/conf/mdtc.yaml b/examples/hi_xiaowen/s0/conf/mdtc.yaml
index 4ec5f7e..905a7e9 100644
--- a/examples/hi_xiaowen/s0/conf/mdtc.yaml
+++ b/examples/hi_xiaowen/s0/conf/mdtc.yaml
@@ -44,3 +44,4 @@ training_config:
     grad_clip: 5
     max_epoch: 100
     log_interval: 10
+    criterion: RHE
\ No newline at end of file
diff --git a/kws/model/classifier.py b/kws/model/classifier.py
index cf81ab5..b4a0983 100644
--- a/kws/model/classifier.py
+++ b/kws/model/classifier.py
@@ -22,3 +22,12 @@ class LastClassifier(nn.Module):
     def forward(self, x: torch.Tensor):
         x = x[:, -1, :]
         return self.classifier(x)
+
+class ElementClassifier(nn.Module):
+    """Classify all the frames in an utterance"""
+    def __init__(self, classifier: nn.Module):
+        super(ElementClassifier, self).__init__()
+        self.classifier = classifier
+
+    def forward(self, x: torch.Tensor):
+        return self.classifier(x)
\ No newline at end of file
diff --git a/kws/model/max_pooling_RHE.py b/kws/model/max_pooling_RHE.py
index 7963ec1..8168af5 100644
--- a/kws/model/max_pooling_RHE.py
+++ b/kws/model/max_pooling_RHE.py
@@ -74,7 +74,7 @@ def downsample_training_sample_and_calculate_loss(logits,
 def max_pooling_RHE_binary_CE(logits,
                               targets,
                               lengths,
-                              RHE_thr=10000,
+                              RHE_thr=100,
                               max_ratio=1):
     """Max-pooling loss with regional hard example mining
     For each keyword utterance, select the frame with the highest posterior.
diff --git a/kws/utils/executor.py b/kws/utils/executor.py
index 5f394bf..cebca58 100644
--- a/kws/utils/executor.py
+++ b/kws/utils/executor.py
@@ -18,11 +18,13 @@ import torch
 from torch.nn.utils import clip_grad_norm_
 
 from kws.model.max_pooling import max_pooling_loss
+from kws.model.max_pooling_RHE import max_pooling_RHE_binary_CE
 from kws.model.ce import cross_entropy
 
 
 criterion_dict = {'CE': cross_entropy,
-                  'max_pooling': max_pooling_loss}
+                  'max_pooling': max_pooling_loss,
+                  'RHE': max_pooling_RHE_binary_CE}
 
 class Executor:
     def __init__(self):