fix quickcheck and flake8

This commit is contained in:
dujing 2023-07-24 17:08:15 +08:00
parent ea6a0f5cda
commit 9b20c840fa
4 changed files with 32 additions and 31 deletions

View File

@ -39,7 +39,7 @@ FRRs with FAR fixed at once per 12 hours:
Now, the DSTCN model with CTC loss may not get the best performance, because the Now, the DSTCN model with CTC loss may not get the best performance, because the
pretraining phase is not sufficiently converged. We recommend you use pretrained pretraining phase is not sufficiently converged. We recommend you use pretrained
FSMN model as initial checkpoint to train your own model. FSMN model as initial checkpoint to train your own model.
Comparison Between stream_score_ctc and score_ctc. Comparison Between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours: FRRs with FAR fixed at once per 12 hours:
@ -57,6 +57,6 @@ Actually the probability will increase through the time, so we record a lower va
which result in a higher False Rejection Rate in Detection Error Tradeoff result. which result in a higher False Rejection Rate in Detection Error Tradeoff result.
The actual FRR will be lower than the DET curve gives in a given threshold. The actual FRR will be lower than the DET curve gives in a given threshold.
On some small data KWS tasks, we believe the FSMN-CTC model is more robust On some small data KWS tasks, we believe the FSMN-CTC model is more robust
compared with the classification model using CE/Max-pooling loss. compared with the classification model using CE/Max-pooling loss.
For more infomation and results of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary). For more infomation and results of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).

View File

@ -53,7 +53,7 @@ def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
key = arr[0] key = arr[0]
is_detected = arr[1] is_detected = arr[1]
if is_detected == 'detected': if is_detected == 'detected':
keyword=true_keywords[arr[2]] keyword = true_keywords[arr[2]]
if key not in score_table: if key not in score_table:
score_table.update({ score_table.update({
key: { key: {
@ -247,8 +247,8 @@ if __name__ == '__main__':
num_false_alarm = 0 num_false_alarm = 0
# transverse the all filler_table # transverse the all filler_table
for key, confi in keyword_filler_table[keyword][ for key, confi in keyword_filler_table[
'filler_table'].items(): keyword]['filler_table'].items():
if confi >= threshold: if confi >= threshold:
num_false_alarm += 1 num_false_alarm += 1
# print(f'false alarm: {keyword}, {key}, {confi}') # print(f'false alarm: {keyword}, {key}, {confi}')

View File

@ -16,7 +16,7 @@ from __future__ import print_function
import argparse import argparse
import struct import struct
#import wave # import wave
import librosa import librosa
import logging import logging
import os import os
@ -328,12 +328,12 @@ class KeyWordSpotter(torch.nn.Module):
wave_tensor = torch.from_numpy(wave).float().to(self.device) wave_tensor = torch.from_numpy(wave).float().to(self.device)
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
feats = kaldi.fbank(wave_tensor, feats = kaldi.fbank(wave_tensor,
num_mel_bins=self.num_mel_bins, num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length, frame_length=self.frame_length,
frame_shift=self.frame_shift, frame_shift=self.frame_shift,
dither=0, dither=0,
energy_floor=0.0, energy_floor=0.0,
sample_frequency=self.sample_rate) sample_frequency=self.sample_rate)
# update wave remained # update wave remained
feat_len = len(feats) feat_len = len(feats)
frame_shift = int(self.frame_shift / 1000 * self.sample_rate) frame_shift = int(self.frame_shift / 1000 * self.sample_rate)
@ -351,8 +351,8 @@ class KeyWordSpotter(torch.nn.Module):
else: else:
feats_pad = torch.cat((self.feature_remained, feats)) feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats_pad.shape[0] - \ ctx_frm = feats_pad.shape[0] - (
(self.right_context+self.right_context) self.right_context + self.right_context)
ctx_win = (self.left_context + self.right_context + 1) ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
@ -362,15 +362,15 @@ class KeyWordSpotter(torch.nn.Module):
# update feature remained, and feats # update feature remained, and feats
self.feature_remained = \ self.feature_remained = \
feats[-(self.left_context+self.right_context):] feats[-(self.left_context + self.right_context):]
feats = feats_ctx.to(self.device) feats = feats_ctx.to(self.device)
if self.downsampling > 1: if self.downsampling > 1:
last_remainder = 0 if self.feats_ctx_offset==0 \ last_remainder = 0 if self.feats_ctx_offset == 0 \
else self.downsampling-self.feats_ctx_offset else self.downsampling - self.feats_ctx_offset
remainder = (feats.size(0)+last_remainder) % self.downsampling remainder = (feats.size(0) + last_remainder) % self.downsampling
feats = feats[self.feats_ctx_offset::self.downsampling, :] feats = feats[self.feats_ctx_offset::self.downsampling, :]
self.feats_ctx_offset = remainder \ self.feats_ctx_offset = remainder \
if remainder == 0 else self.downsampling-remainder if remainder == 0 else self.downsampling - remainder
return feats return feats
def decode_keywords(self, t, probs): def decode_keywords(self, t, probs):
@ -419,8 +419,8 @@ class KeyWordSpotter(torch.nn.Module):
if hit_keyword is not None: if hit_keyword is not None:
if self.hit_score >= self.threshold and \ if self.hit_score >= self.threshold and \
self.min_frames <= duration <= self.max_frames \ self.min_frames <= duration <= self.max_frames \
and (self.last_active_pos==-1 or and (self.last_active_pos == -1 or
end-self.last_active_pos >= self.interval_frames): end - self.last_active_pos >= self.interval_frames):
self.activated = True self.activated = True
self.last_active_pos = end self.last_active_pos = end
logging.info( logging.info(
@ -428,8 +428,8 @@ class KeyWordSpotter(torch.nn.Module):
f"from {start} to {end} frame. " f"from {start} to {end} frame. "
f"duration {duration}, score {self.hit_score}, Activated.") f"duration {duration}, score {self.hit_score}, Activated.")
elif self.last_active_pos>0 and \ elif self.last_active_pos > 0 and \
end-self.last_active_pos < self.interval_frames: end - self.last_active_pos < self.interval_frames:
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} " f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. " f"from {start} to {end} frame. "

View File

@ -216,7 +216,7 @@ class FSMNBlock(nn.Module):
self.dequant = torch.quantization.DeQuantStub() self.dequant = torch.quantization.DeQuantStub()
def forward(self, def forward(self,
input: Tuple[torch.Tensor, torch.Tensor] ): input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple): if isinstance(input, tuple):
input, in_cache = input input, in_cache = input
else : else :
@ -236,12 +236,13 @@ class FSMNBlock(nn.Module):
y_left = self.quant(y_left) y_left = self.quant(y_left)
y_left = self.conv_left(y_left) y_left = self.conv_left(y_left)
y_left = self.dequant(y_left) y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride: out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder *
-self.rorder * self.rstride, :] + y_left self.rstride, :] + y_left
if self.conv_right is not None: if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :] y_right = x_pad[:, :, -(
x_per.size(2) + self.rorder * self.rstride):, :]
y_right = y_right[:, :, self.rstride:, :] y_right = y_right[:, :, self.rstride:, :]
y_right = self.quant(y_right) y_right = self.quant(y_right)
y_right = self.conv_right(y_right) y_right = self.conv_right(y_right)
@ -257,8 +258,8 @@ class FSMNBlock(nn.Module):
re_str = '' re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim) re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \ re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % ( '<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride) 1, self.lorder, self.rorder, self.lstride, self.rstride)
# print(self.conv_left.weight,self.conv_right.weight) # print(self.conv_left.weight,self.conv_right.weight)
lfiters = self.state_dict()['conv_left.weight'] lfiters = self.state_dict()['conv_left.weight']
@ -445,8 +446,8 @@ class FSMN(nn.Module):
self.output_affine_dim = output_affine_dim self.output_affine_dim = output_affine_dim
self.output_dim = output_dim self.output_dim = output_dim
self.padding = (self.lorder-1) * self.lstride \ self.padding = (self.lorder - 1) * self.lstride \
+ self.rorder * self.rstride + self.rorder * self.rstride
self.in_linear1 = AffineTransform(input_dim, input_affine_dim) self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)