143 lines
5.1 KiB
Python
143 lines
5.1 KiB
Python
# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
def RHE(indice: torch.Tensor, k: int):
    """Pick regional hard examples from a ranked list of indices.

    Regional hard example (RHE) mining from
    'Mining effective negative training samples for keyword spotting'.
    ``indice`` is walked from front to back. An index is kept only when
    it has not been suppressed by an index kept earlier, and keeping an
    index suppresses every position within ``k`` of it (itself included).
    The walk stops as soon as every position is suppressed.

    Args:
        indice: 1-D sequence of positions in visiting order. Its values
            index a mask of ``len(indice)`` entries, so presumably they
            form a permutation of ``range(len(indice))`` -- confirm
            against the caller.
        k: suppression radius around each kept index; ``k <= 0``
            disables the mining.

    Returns:
        (torch.Tensor): ``indice`` itself when ``k <= 0``, otherwise a
            new long tensor with the kept indices in visiting order.
    """
    if k <= 0:
        return indice

    total = len(indice)
    # 1 marks a position that may still be selected, 0 a suppressed one.
    still_free = torch.ones(total, dtype=torch.long)
    kept = []
    for position in range(total):
        center = indice[position]
        if still_free[center] != 1:
            continue
        kept.append(center)
        window_start = max(center - k, 0)
        window_end = min(center + k, total)
        still_free[window_start:window_end + 1] = 0
        if not still_free.any():
            # Every position is suppressed, nothing more can be kept.
            break
    return torch.tensor(kept).long()
|
|
|
|
|
|
def downsample_training_sample_and_calculate_loss(logits,
                                                  targets,
                                                  ratio: float = 10):
    """Down-sample the negatives and compute the binary cross-entropy.

    For every keyword all positive samples are kept, together with at
    most ``ratio`` times as many negative samples, hardest (highest
    posterior) first. Note that a keyword without any positive sample
    therefore contributes no negative sample either. The loss is the
    binary cross-entropy averaged over all selected samples:

        -(sum(log(p_pos)) + sum(log(1 - p_neg))) / num_selected

    Args:
        logits: one list per keyword; each item is a 0-D or 1-D tensor of
            posteriors. The values go straight into ``log``, so they are
            expected to already be probabilities in (0, 1).
        targets: one list per keyword, parallel to ``logits``; each item
            is a sequence of labels (>= 1 positive, < 1 negative), one
            label per posterior of the matching ``logits`` item.
        ratio: maximum number of negative samples per positive sample.

    Returns:
        (torch.Tensor): scalar loss of the selected samples, zero when
            nothing was selected.
    """
    log_likelihood = torch.zeros(())
    num_training = 0
    for keyword_logits, keyword_targets in zip(logits, targets):
        if len(keyword_logits) == 0:
            # torch.cat / np.concatenate reject empty lists.
            continue
        # Positive samples may arrive as 0-D tensors, which torch.cat
        # cannot concatenate; flatten every item to 1-D first.
        output = torch.cat([item.reshape(-1) for item in keyword_logits])
        # Build the labels on the device of the posteriors. The explicit
        # dtype also covers the float array numpy produces when one of
        # the label lists is empty.
        target = torch.as_tensor(np.concatenate(keyword_targets),
                                 dtype=torch.long,
                                 device=output.device)
        positive_index = target >= 1  # the label of positive label is 1
        negative_index = target < 1  # the label of negative label is 0
        selected_p_output = output[positive_index]
        all_n_output = output[negative_index]

        # Keep only the hardest negatives, at most `ratio` per positive.
        num_n = min(int(ratio * len(selected_p_output)), len(all_n_output))
        sorted_n_output, _ = torch.sort(all_n_output, descending=True)
        selected_n_output = sorted_n_output[:num_n]

        # 1 - p is floored at 1e-8 because a float32 posterior can
        # saturate to exactly 1.0, which would make the log -inf.
        log_likelihood = (
            log_likelihood + torch.log(selected_p_output).sum() +
            torch.log(torch.clamp(1.0 - selected_n_output, min=1e-8)).sum())
        num_training += len(selected_p_output) + len(selected_n_output)

    if num_training == 0:
        # Nothing selected: the accumulated sum is exactly zero, return it
        # instead of producing 0 / 0.
        return log_likelihood
    return -log_likelihood / num_training
|
|
|
|
|
|
def max_pooling_RHE_binary_CE(logits,
                              targets,
                              lengths,
                              RHE_thr=100,
                              max_ratio=1):
    """Max-pooling loss with regional hard example mining

    For each keyword utterance, select the frame with the highest posterior
    of that keyword. The keyword is triggered when any of the frames is
    triggered. For every keyword an utterance does not contain, select
    several hard examples using the RHE algorithm. The collected samples
    are then down-sampled according to ``max_ratio`` and turned into a loss.

    Args:
        logits: (B, T, D), D is the number of keywords
        targets: (B), keyword id of each utterance; a value that matches
            none of ``0 .. D-1`` marks a non-keyword utterance
        lengths: (B), number of valid frames of each utterance
        RHE_thr: how many neighbor logits we remove on each side each time
            we find a hard example
        max_ratio: passed as ``ratio`` to
            ``downsample_training_sample_and_calculate_loss``

    Returns:
        loss of current batch, as returned by
            ``downsample_training_sample_and_calculate_loss``
        (float): accuracy of current batch
    """
    num_hit = 0
    # Here we clamp the sigmoid output to prevent NaN problem
    # when we calculate loss
    probs = torch.clamp(torch.sigmoid(logits), 1e-8, 1.0 - 1e-8)
    num_utts = probs.size(0)
    num_keyword = probs.size(2)

    new_logits = [[] for _ in range(num_keyword)]
    new_targets = [[] for _ in range(num_keyword)]

    for i in range(num_utts):
        end_idx = lengths[i]
        frames = probs[i, :end_idx]  # (T_i, D), valid frames only
        # Frame ranking per keyword (highest posterior first); the sort is
        # shared by all keywords, so it is built once on first use.
        ranked_index = None
        is_keyword_utt = False
        hit = False
        for j in range(num_keyword):
            if targets[i] == j:
                is_keyword_utt = True
                # Max pooling over time for keyword j only. argmax() over
                # the whole (T_i, D) slice returns a flattened index, which
                # is not a frame index when D > 1.
                max_idx = frames[:, j].argmax()
                peak = frames[max_idx, j]
                # Keep the sample 1-D so that it can be concatenated with
                # the 1-D negative selections.
                new_logits[j].append(peak.reshape(1))
                new_targets[j].append([1])
                hit = bool(peak >= 0.5)
            else:
                if ranked_index is None:
                    _, sorted_index = torch.sort(frames, dim=0)
                    ranked_index = torch.flip(sorted_index, dims=[0])
                selected_indexes = RHE(ranked_index[:, j], RHE_thr)
                new_logits[j].append(frames[selected_indexes, j])
                new_targets[j].append([0] * len(selected_indexes))
        if not is_keyword_utt:
            # A non-keyword utterance is correct when all the binary
            # probabilities are smaller than 0.5. Count it once per
            # utterance, not once per keyword.
            hit = bool(frames.max() < 0.5)
        num_hit += int(hit)

    # Here we select training samples according to max_ratio
    loss = downsample_training_sample_and_calculate_loss(
        new_logits,
        new_targets,
        ratio=max_ratio,
    )
    acc = num_hit / num_utts
    return loss, acc
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test of RHE on a fixed ranking. Expected output per radius:
    #   0   -> [3, 2, 0, 7, 5, 8, 1, 4, 6]
    #   1   -> [3, 0, 7, 5]
    #   2   -> [3, 0, 7]
    #   3   -> [3, 7]
    #   100 -> [3]
    ranking = torch.tensor([3, 2, 0, 7, 5, 8, 1, 4, 6])
    for radius in (0, 1, 2, 3, 100):
        print(RHE(ranking, radius))
|