wekws/wekws/model/loss.py
2023-05-19 11:38:16 +08:00

453 lines
16 KiB
Python

# Copyright (c) 2021 Binbin Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch, math, sys
import torch.nn.functional as F
from collections import defaultdict
from typing import List, Optional, Tuple
from wekws.utils.mask import padding_mask
def max_pooling_loss(logits: torch.Tensor,
                     target: torch.Tensor,
                     lengths: torch.Tensor,
                     min_duration: int = 0):
    """Max-pooling binary cross-entropy loss for keyword spotting.

    Every (utterance, keyword) pair contributes one BCE term computed on a
    single selected frame:

    * positive pair (``target[i] == j``): the frame with the highest keyword
      posterior, ignoring padded frames and the first ``min_duration``
      frames; term ``-log(p)``.
    * negative pair (other keyword or filler): the frame with the highest
      keyword posterior among non-padded frames, i.e. the hardest negative;
      term ``-log(1 - p)``.

    Probabilities are clamped to [1e-8, 1] before taking the log.

    Args:
        logits: (B, T, D) per-frame keyword posteriors, D keywords.
        target: (B) keyword index per utterance; negative ids mean filler.
        lengths: (B) valid lengths, turned into a padding mask by
            ``padding_mask``.
        min_duration: number of leading frames excluded when max-pooling a
            positive keyword.

    Returns:
        loss: sum of all BCE terms divided by B.
        acc (float): fraction of utterances whose best posterior exceeds 0.5
            on the right keyword, or stays below 0.5 for a filler target.
    """
    pad_mask = padding_mask(lengths)
    batch_size = logits.size(0)
    keyword_count = logits.size(2)
    target = target.cpu()

    loss = 0.0
    for utt in range(batch_size):
        for kw in range(keyword_count):
            if target[utt] == kw:
                # Positive: the keyword only has to fire on one frame, so
                # push up the best frame (max-pooling).
                frame_mask = pad_mask[utt].clone().detach()
                frame_mask[:min_duration] = True
                kw_prob = logits[utt, :, kw].masked_fill(frame_mask, 0.0)
                kw_prob = torch.clamp(kw_prob, 1e-8, 1.0)
                loss += -torch.log(kw_prob.max())
            else:
                # Negative: every frame has to stay quiet, so penalise the
                # hardest one (min-pooling of 1 - p).
                not_kw_prob = 1 - logits[utt, :, kw]
                not_kw_prob = not_kw_prob.masked_fill(pad_mask[utt], 1.0)
                not_kw_prob = torch.clamp(not_kw_prob, 1e-8, 1.0)
                loss += -torch.log(not_kw_prob.min())
    loss = loss / batch_size

    # Accuracy: zero padded frames, take the best frame per keyword, then
    # the best keyword per utterance.
    masked_logits = logits.masked_fill(pad_mask.unsqueeze(-1), 0.0)
    per_keyword_max = masked_logits.max(1)[0]
    num_correct = 0
    for utt in range(batch_size):
        best_prob, best_kw = per_keyword_max[utt].max(0)
        if best_prob > 0.5 and best_kw == target[utt]:
            num_correct += 1  # right keyword triggered
        elif best_prob < 0.5 and target[utt] < 0:
            num_correct += 1  # filler left untriggered
    acc = num_correct / batch_size
    return loss, acc
def acc_frame(
    logits: torch.Tensor,
    target: torch.Tensor,
):
    """Percentage of rows whose arg-max over dim 1 matches the target.

    Args:
        logits: (B, D) class scores, or None.
        target: (B) class indices; cast to long before comparison.

    Returns:
        Accuracy in percent, or 0 when ``logits`` is None.
    """
    if logits is None:
        return 0
    predicted = logits.max(dim=1, keepdim=True).indices
    expected = target.long().view_as(predicted)
    num_correct = (predicted == expected).sum().item()
    return 100.0 * num_correct / logits.size(0)
def acc_utterance(logits: torch.Tensor, target: torch.Tensor,
                  logits_length: torch.Tensor, target_length: torch.Tensor):
    """Token accuracy (percent) of CTC decoding against the references.

    Each utterance is decoded with ``ctc_prefix_beam_search`` (score beam 3,
    path beam 5). The best hypothesis is aligned with its reference by
    ``Calculator``.

    Args:
        logits: (B, T, D) network outputs, softmax-normalised here (presumably
            raw logits), or None.
        target: (B, S) padded reference token ids.
        logits_length: (B) valid frames per utterance.
        target_length: (B) valid reference tokens per utterance.

    Returns:
        100 * (N - ins - sub - del) / N, where N is the number of reference
        tokens over all utterances with a non-empty reference. Returns 0 when
        ``logits`` is None or N == 0. The value can be negative when
        insertions dominate.
    """
    if logits is None:
        return 0
    # Metric only: no need to keep the autograd graph alive.
    probs = logits.detach().softmax(2).cpu()  # (B, T, D)
    target = target.cpu()
    total_word = 0
    total_ins = 0
    total_sub = 0
    total_del = 0
    calculator = Calculator()
    for i in range(probs.size(0)):
        score = probs[i][:logits_length[i]]
        hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5)
        lab = [str(item) for item in target[i][:target_length[i]].tolist()]
        rec = [str(item) for item in hyps[0][0]] if hyps else []
        result = calculator.calculate(lab, rec)
        # Utterances with an empty reference are skipped entirely, including
        # any insertions they produced.
        if result['all'] != 0:
            total_word += result['all']
            total_ins += result['ins']
            total_sub += result['sub']
            total_del += result['del']
    if total_word == 0:
        # No reference tokens anywhere: avoid ZeroDivisionError.
        return 0.0
    return float(total_word - total_ins - total_sub
                 - total_del) * 100.0 / total_word
def ctc_loss(logits: torch.Tensor,
             target: torch.Tensor,
             logits_lengths: torch.Tensor,
             target_lengths: torch.Tensor,
             need_acc: bool = False):
    """CTC loss averaged over the batch, with optional utterance accuracy.

    Args:
        logits: (B, T, D) network outputs. They are log-softmaxed here and
            passed to ``F.ctc_loss`` with its default blank index 0.
        target: target token ids in a layout ``F.ctc_loss`` accepts. Must be
            (B, S) padded when ``need_acc`` is set.
        logits_lengths: (B) valid frames per utterance.
        target_lengths: (B) valid target tokens per utterance.
        need_acc: also compute ``acc_utterance`` (runs a beam search, so it
            costs extra time).

    Returns:
        (loss, acc): summed CTC loss divided by B, and the accuracy in
        percent (0.0 unless ``need_acc``).
    """
    acc = (acc_utterance(logits, target, logits_lengths, target_lengths)
           if need_acc else 0.0)
    # F.ctc_loss wants time-major log-probs: (B, T, D) -> (T, B, D).
    log_probs = logits.transpose(0, 1).log_softmax(2)
    total = F.ctc_loss(log_probs, target, logits_lengths, target_lengths,
                       reduction='sum')
    batch_size = log_probs.size(1)
    return total / batch_size, acc
def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
    """Cross-entropy loss together with arg-max accuracy.

    Args:
        logits: (B, D) unnormalised scores. D is the number of keywords
            plus 1 (non-keyword).
        target: (B) class indices.

    Returns:
        (loss, acc): mean cross-entropy over the batch and the accuracy in
        percent from ``acc_frame``.
    """
    return F.cross_entropy(logits, target), acc_frame(logits, target)
def criterion(type: str,
              logits: torch.Tensor,
              target: torch.Tensor,
              lengths: torch.Tensor,
              target_lengths: torch.Tensor = None,
              min_duration: int = 0,
              validation: bool = False):
    """Compute the loss selected by ``type``.

    Args:
        type: 'ce', 'max_pooling' or 'ctc'.
        logits: model outputs, in the shape the selected loss expects.
        target: labels for the selected loss.
        lengths: per-utterance valid lengths (used by 'max_pooling' and
            'ctc').
        target_lengths: per-utterance target lengths (used by 'ctc').
        min_duration: forwarded to ``max_pooling_loss``.
        validation: for 'ctc', also compute the (costly) utterance accuracy.

    Returns:
        (loss, acc) from the selected loss function.

    Raises:
        SystemExit: exit status 1, with a message naming the bad value, for
            an unsupported ``type``.
    """
    if type == 'ce':
        return cross_entropy(logits, target)
    if type == 'max_pooling':
        return max_pooling_loss(logits, target, lengths, min_duration)
    if type == 'ctc':
        return ctc_loss(logits, target, lengths, target_lengths, validation)
    # Previously a bare exit(1): same SystemExit / status 1, but now with a
    # message. sys.exit also does not depend on the `site` module.
    sys.exit(f"unsupported criterion type: {type!r}; "
             f"expected 'ce', 'max_pooling' or 'ctc'")
def ctc_prefix_beam_search(
    logits: torch.Tensor,
    logits_lengths: torch.Tensor,
    keywords_tokenset: Optional[set] = None,
    score_beam_size: int = 3,
    path_beam_size: int = 20,
) -> List[Tuple[Tuple[int, ...], float, List[dict]]]:
    """CTC prefix beam search that also records a peak frame per token.

    Args:
        logits: (T, V) per-frame probabilities for a single utterance. They
            are used as probabilities directly; no softmax is applied here.
        logits_lengths: not used; every row of ``logits`` is decoded.
        keywords_tokenset: if given, only these token ids are expanded. The
            blank id 0 must be included for blanks to be considered.
        score_beam_size: candidate tokens per frame (top-k), clamped to V.
        path_beam_size: number of prefixes kept after each frame.

    Returns:
        List of ``(prefix, score, nodes)`` sorted by descending score.
        ``prefix`` is the tuple of emitted token ids, ``score`` is
        ``pb + pnb``, and ``nodes`` holds one ``dict(token, frame, prob)``
        per prefix token.
    """
    maxlen = logits.size(0)
    ctc_probs = logits
    # Hypothesis: (prefix, (pb, pnb, nodes)). pb / pnb are the probabilities
    # of the prefix ending in blank / non-blank.
    cur_hyps = [(tuple(), (1.0, 0.0, []))]
    # CTC beam search step by step
    for t in range(0, maxlen):
        probs = ctc_probs[t]  # (vocab_size,)
        # key: prefix, value: (pb, pnb, nodes), default (0.0, 0.0, [])
        next_hyps = defaultdict(lambda: (0.0, 0.0, []))
        # First beam prune: top-k tokens of this frame. Clamp k so a beam
        # wider than the vocabulary does not make topk raise.
        k = min(score_beam_size, probs.size(0))
        top_k_probs, top_k_index = probs.topk(k)
        # Drop tokens with prob <= 0.05, and tokens outside the keyword set.
        filter_probs = []
        filter_index = []
        for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
            if prob > 0.05 and (keywords_tokenset is None
                                or idx in keywords_tokenset):
                filter_probs.append(prob)
                filter_index.append(idx)
        if len(filter_index) == 0:
            # Nothing survives: skip this frame, hypotheses unchanged.
            continue
        # The probability of s is already known from topk; no .item() needed.
        for s, ps in zip(filter_index, filter_probs):
            for prefix, (pb, pnb, cur_nodes) in cur_hyps:
                last = prefix[-1] if len(prefix) > 0 else None
                if s == 0:  # blank
                    n_pb, n_pnb, nodes = next_hyps[prefix]
                    n_pb = n_pb + pb * ps + pnb * ps
                    nodes = cur_nodes.copy()
                    next_hyps[prefix] = (n_pb, n_pnb, nodes)
                elif s == last:
                    if not math.isclose(pnb, 0.0, abs_tol=0.000001):
                        # Update *ss -> *s;
                        n_pb, n_pnb, nodes = next_hyps[prefix]
                        n_pnb = n_pnb + pnb * ps
                        # NOTE(review): list.copy() is shallow, so nodes[-1]
                        # is the same dict held by cur_nodes and by every
                        # hypothesis copied from it. The in-place update
                        # below (and the one in the last branch) therefore
                        # also changes those hypotheses' recorded
                        # frame/prob. The blank branch above relies on that
                        # sharing to see the update, so fixing it needs an
                        # explicit merge rule -- TODO.
                        nodes = cur_nodes.copy()
                        if ps > nodes[-1]['prob']:  # update frame and prob
                            nodes[-1]['prob'] = ps
                            nodes[-1]['frame'] = t
                        next_hyps[prefix] = (n_pb, n_pnb, nodes)
                    if not math.isclose(pb, 0.0, abs_tol=0.000001):
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb, nodes = next_hyps[n_prefix]
                        n_pnb = n_pnb + pb * ps
                        nodes = cur_nodes.copy()
                        nodes.append(dict(token=s, frame=t,
                                          prob=ps))  # to record token prob
                        next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
                else:
                    n_prefix = prefix + (s, )
                    n_pb, n_pnb, nodes = next_hyps[n_prefix]
                    if nodes:
                        if ps > nodes[-1]['prob']:  # update frame and prob
                            nodes[-1]['prob'] = ps
                            nodes[-1]['frame'] = t
                    else:
                        nodes = cur_nodes.copy()
                        nodes.append(dict(token=s, frame=t,
                                          prob=ps))  # to record token prob
                    n_pnb = n_pnb + pb * ps + pnb * ps
                    next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
        # Second beam prune: keep the best path_beam_size prefixes.
        next_hyps = sorted(
            next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)
        cur_hyps = next_hyps[:path_beam_size]
    hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
    return hyps
class Calculator:
    """Token-level edit-distance scorer with per-token statistics.

    ``calculate`` aligns one reference sequence with one hypothesis using a
    Levenshtein DP (costs in ``self.cost``) and returns the alignment plus
    its counts. Per-token counts accumulate in ``self.data`` across calls;
    ``overall`` and ``cluster`` sum them.
    """

    def __init__(self):
        # token -> {'all', 'cor', 'sub', 'ins', 'del'} counts, accumulated
        # over every call to ``calculate``.
        self.data = {}
        # DP grid of {'dist', 'error'} cells. It is reused between calls,
        # only ever grows, and is reset at the start of each ``calculate``.
        self.space = []
        # Per-operation edit costs.
        self.cost = {'cor': 0, 'sub': 1, 'del': 1, 'ins': 1}

    def calculate(self, lab, rec):
        """Align reference ``lab`` with hypothesis ``rec``.

        Args:
            lab: sequence of reference tokens (strings); not modified.
            rec: sequence of hypothesis tokens (strings); not modified.

        Returns:
            dict with 'lab' / 'rec' (aligned token lists, '' marks a gap) and
            the counts 'all' (reference tokens), 'cor', 'sub', 'ins', 'del'.
            Empty-string tokens are aligned but not counted.

        Raises:
            RuntimeError: if the traceback meets an unknown cell state
                (should be unreachable).
        """
        # Work on copies with a leading '' standing for the empty prefix.
        # The previous version inserted it into the caller's lists.
        lab = [''] + list(lab)
        rec = [''] + list(rec)

        # Grow the grid to at least len(lab) x len(rec) and reset every cell.
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for cell in row:
                cell['dist'] = 0
                cell['error'] = 'non'
            while len(row) < len(rec):
                row.append({'dist': 0, 'error': 'non'})
        # Boundary: (i, 0) is reached by deleting i reference tokens and
        # (0, j) by inserting j hypothesis tokens. (0, 0) is the origin.
        for i in range(len(lab)):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(len(rec)):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        self.space[0][0]['error'] = 'non'

        # Register every non-empty token (reference first) for statistics.
        for token in lab + rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }

        # Edit-distance DP. Candidates are tried as del, ins, cor/sub, and
        # only a strictly smaller cost replaces the best so far, so ties
        # prefer deletion, then insertion.
        for i in range(1, len(lab)):
            for j in range(1, len(rec)):
                best_dist = self.space[i - 1][j]['dist'] + self.cost['del']
                best_error = 'del'
                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                if dist < best_dist:
                    best_dist, best_error = dist, 'ins'
                diag = 'cor' if lab[i] == rec[j] else 'sub'
                dist = self.space[i - 1][j - 1]['dist'] + self.cost[diag]
                if dist < best_dist:
                    best_dist, best_error = dist, diag
                self.space[i][j]['dist'] = best_dist
                self.space[i][j]['error'] = best_error

        # Trace back from the bottom-right corner. Build the alignment in
        # reverse with append and flip it once at the end.
        result = {
            'lab': [],
            'rec': [],
            'all': 0,
            'cor': 0,
            'sub': 0,
            'ins': 0,
            'del': 0
        }
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            error = self.space[i][j]['error']
            if error == 'non':  # starting point
                break
            if error in ('cor', 'sub', 'del'):
                # Consumes a reference token; also a hypothesis token unless
                # this is a deletion.
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] += 1
                    self.data[lab[i]][error] += 1
                    result['all'] += 1
                    result[error] += 1
                result['lab'].append(lab[i])
                if error == 'del':
                    result['rec'].append('')
                else:
                    result['rec'].append(rec[j])
                    j -= 1
                i -= 1
            elif error == 'ins':
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] += 1
                    result['ins'] += 1
                result['lab'].append('')
                result['rec'].append(rec[j])
                j -= 1
            else:
                # Previously this printed and looped forever.
                raise RuntimeError(
                    'this should not happen , i = {i} , j = {j} , '
                    'error = {error}'.format(i=i, j=j, error=error))
        result['lab'].reverse()
        result['rec'].reverse()
        return result

    def overall(self):
        """Return the counts summed over every token seen so far."""
        return self.cluster(self.data)

    def cluster(self, data):
        """Return the counts summed over the tokens in ``data``.

        Tokens that were never seen are ignored.
        """
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in data:
            if token in self.data:
                for key in result:
                    result[key] += self.data[token][key]
        return result

    def keys(self):
        """Return the seen tokens in first-seen order."""
        return list(self.data.keys())