fix quickcheck and flake8

This commit is contained in:
dujing 2023-07-24 17:08:15 +08:00
parent ea6a0f5cda
commit 9b20c840fa
4 changed files with 32 additions and 31 deletions

View File

@ -53,7 +53,7 @@ def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
key = arr[0] key = arr[0]
is_detected = arr[1] is_detected = arr[1]
if is_detected == 'detected': if is_detected == 'detected':
keyword=true_keywords[arr[2]] keyword = true_keywords[arr[2]]
if key not in score_table: if key not in score_table:
score_table.update({ score_table.update({
key: { key: {
@ -247,8 +247,8 @@ if __name__ == '__main__':
num_false_alarm = 0 num_false_alarm = 0
# transverse the all filler_table # transverse the all filler_table
for key, confi in keyword_filler_table[keyword][ for key, confi in keyword_filler_table[
'filler_table'].items(): keyword]['filler_table'].items():
if confi >= threshold: if confi >= threshold:
num_false_alarm += 1 num_false_alarm += 1
# print(f'false alarm: {keyword}, {key}, {confi}') # print(f'false alarm: {keyword}, {key}, {confi}')

View File

@ -16,7 +16,7 @@ from __future__ import print_function
import argparse import argparse
import struct import struct
#import wave # import wave
import librosa import librosa
import logging import logging
import os import os
@ -328,12 +328,12 @@ class KeyWordSpotter(torch.nn.Module):
wave_tensor = torch.from_numpy(wave).float().to(self.device) wave_tensor = torch.from_numpy(wave).float().to(self.device)
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
feats = kaldi.fbank(wave_tensor, feats = kaldi.fbank(wave_tensor,
num_mel_bins=self.num_mel_bins, num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length, frame_length=self.frame_length,
frame_shift=self.frame_shift, frame_shift=self.frame_shift,
dither=0, dither=0,
energy_floor=0.0, energy_floor=0.0,
sample_frequency=self.sample_rate) sample_frequency=self.sample_rate)
# update wave remained # update wave remained
feat_len = len(feats) feat_len = len(feats)
frame_shift = int(self.frame_shift / 1000 * self.sample_rate) frame_shift = int(self.frame_shift / 1000 * self.sample_rate)
@ -351,8 +351,8 @@ class KeyWordSpotter(torch.nn.Module):
else: else:
feats_pad = torch.cat((self.feature_remained, feats)) feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats_pad.shape[0] - \ ctx_frm = feats_pad.shape[0] - (
(self.right_context+self.right_context) self.right_context + self.right_context)
ctx_win = (self.left_context + self.right_context + 1) ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
@ -362,15 +362,15 @@ class KeyWordSpotter(torch.nn.Module):
# update feature remained, and feats # update feature remained, and feats
self.feature_remained = \ self.feature_remained = \
feats[-(self.left_context+self.right_context):] feats[-(self.left_context + self.right_context):]
feats = feats_ctx.to(self.device) feats = feats_ctx.to(self.device)
if self.downsampling > 1: if self.downsampling > 1:
last_remainder = 0 if self.feats_ctx_offset==0 \ last_remainder = 0 if self.feats_ctx_offset == 0 \
else self.downsampling-self.feats_ctx_offset else self.downsampling - self.feats_ctx_offset
remainder = (feats.size(0)+last_remainder) % self.downsampling remainder = (feats.size(0) + last_remainder) % self.downsampling
feats = feats[self.feats_ctx_offset::self.downsampling, :] feats = feats[self.feats_ctx_offset::self.downsampling, :]
self.feats_ctx_offset = remainder \ self.feats_ctx_offset = remainder \
if remainder == 0 else self.downsampling-remainder if remainder == 0 else self.downsampling - remainder
return feats return feats
def decode_keywords(self, t, probs): def decode_keywords(self, t, probs):
@ -419,8 +419,8 @@ class KeyWordSpotter(torch.nn.Module):
if hit_keyword is not None: if hit_keyword is not None:
if self.hit_score >= self.threshold and \ if self.hit_score >= self.threshold and \
self.min_frames <= duration <= self.max_frames \ self.min_frames <= duration <= self.max_frames \
and (self.last_active_pos==-1 or and (self.last_active_pos == -1 or
end-self.last_active_pos >= self.interval_frames): end - self.last_active_pos >= self.interval_frames):
self.activated = True self.activated = True
self.last_active_pos = end self.last_active_pos = end
logging.info( logging.info(
@ -428,8 +428,8 @@ class KeyWordSpotter(torch.nn.Module):
f"from {start} to {end} frame. " f"from {start} to {end} frame. "
f"duration {duration}, score {self.hit_score}, Activated.") f"duration {duration}, score {self.hit_score}, Activated.")
elif self.last_active_pos>0 and \ elif self.last_active_pos > 0 and \
end-self.last_active_pos < self.interval_frames: end - self.last_active_pos < self.interval_frames:
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} " f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. " f"from {start} to {end} frame. "

View File

@ -216,7 +216,7 @@ class FSMNBlock(nn.Module):
self.dequant = torch.quantization.DeQuantStub() self.dequant = torch.quantization.DeQuantStub()
def forward(self, def forward(self,
input: Tuple[torch.Tensor, torch.Tensor] ): input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple): if isinstance(input, tuple):
input, in_cache = input input, in_cache = input
else : else :
@ -236,12 +236,13 @@ class FSMNBlock(nn.Module):
y_left = self.quant(y_left) y_left = self.quant(y_left)
y_left = self.conv_left(y_left) y_left = self.conv_left(y_left)
y_left = self.dequant(y_left) y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride: out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder *
-self.rorder * self.rstride, :] + y_left self.rstride, :] + y_left
if self.conv_right is not None: if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :] y_right = x_pad[:, :, -(
x_per.size(2) + self.rorder * self.rstride):, :]
y_right = y_right[:, :, self.rstride:, :] y_right = y_right[:, :, self.rstride:, :]
y_right = self.quant(y_right) y_right = self.quant(y_right)
y_right = self.conv_right(y_right) y_right = self.conv_right(y_right)
@ -257,8 +258,8 @@ class FSMNBlock(nn.Module):
re_str = '' re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim) re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \ re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % ( '<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride) 1, self.lorder, self.rorder, self.lstride, self.rstride)
# print(self.conv_left.weight,self.conv_right.weight) # print(self.conv_left.weight,self.conv_right.weight)
lfiters = self.state_dict()['conv_left.weight'] lfiters = self.state_dict()['conv_left.weight']
@ -445,8 +446,8 @@ class FSMN(nn.Module):
self.output_affine_dim = output_affine_dim self.output_affine_dim = output_affine_dim
self.output_dim = output_dim self.output_dim = output_dim
self.padding = (self.lorder-1) * self.lstride \ self.padding = (self.lorder - 1) * self.lstride \
+ self.rorder * self.rstride + self.rorder * self.rstride
self.in_linear1 = AffineTransform(input_dim, input_affine_dim) self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)