fix quickcheck and flake8

This commit is contained in:
dujing 2023-07-24 17:08:15 +08:00
parent ea6a0f5cda
commit 9b20c840fa
4 changed files with 32 additions and 31 deletions

View File

@ -247,8 +247,8 @@ if __name__ == '__main__':
num_false_alarm = 0 num_false_alarm = 0
# transverse the all filler_table # transverse the all filler_table
for key, confi in keyword_filler_table[keyword][ for key, confi in keyword_filler_table[
'filler_table'].items(): keyword]['filler_table'].items():
if confi >= threshold: if confi >= threshold:
num_false_alarm += 1 num_false_alarm += 1
# print(f'false alarm: {keyword}, {key}, {confi}') # print(f'false alarm: {keyword}, {key}, {confi}')

View File

@ -351,8 +351,8 @@ class KeyWordSpotter(torch.nn.Module):
else: else:
feats_pad = torch.cat((self.feature_remained, feats)) feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats_pad.shape[0] - \ ctx_frm = feats_pad.shape[0] - (
(self.right_context+self.right_context) self.right_context + self.right_context)
ctx_win = (self.left_context + self.right_context + 1) ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)

View File

@ -236,12 +236,13 @@ class FSMNBlock(nn.Module):
y_left = self.quant(y_left) y_left = self.quant(y_left)
y_left = self.conv_left(y_left) y_left = self.conv_left(y_left)
y_left = self.dequant(y_left) y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride: out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder *
-self.rorder * self.rstride, :] + y_left self.rstride, :] + y_left
if self.conv_right is not None: if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :] y_right = x_pad[:, :, -(
x_per.size(2) + self.rorder * self.rstride):, :]
y_right = y_right[:, :, self.rstride:, :] y_right = y_right[:, :, self.rstride:, :]
y_right = self.quant(y_right) y_right = self.quant(y_right)
y_right = self.conv_right(y_right) y_right = self.conv_right(y_right)