fix quickcheck and flake8

This commit is contained in:
dujing 2023-07-24 17:08:15 +08:00
parent ea6a0f5cda
commit 9b20c840fa
4 changed files with 32 additions and 31 deletions

View File

@ -39,7 +39,7 @@ FRRs with FAR fixed at once per 12 hours:
Now, the DSTCN model with CTC loss may not get the best performance, because the
pretraining phase is not sufficiently converged. We recommend you use pretrained
FSMN model as initial checkpoint to train your own model.
FSMN model as initial checkpoint to train your own model.
Comparison Between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours:
@ -57,6 +57,6 @@ Actually the probability will increase through the time, so we record a lower va
which result in a higher False Rejection Rate in Detection Error Tradeoff result.
The actual FRR will be lower than the DET curve gives in a given threshold.
On some small data KWS tasks, we believe the FSMN-CTC model is more robust
On some small data KWS tasks, we believe the FSMN-CTC model is more robust
compared with the classification model using CE/Max-pooling loss.
For more infomation and results of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).

View File

@ -53,7 +53,7 @@ def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
key = arr[0]
is_detected = arr[1]
if is_detected == 'detected':
keyword=true_keywords[arr[2]]
keyword = true_keywords[arr[2]]
if key not in score_table:
score_table.update({
key: {
@ -247,8 +247,8 @@ if __name__ == '__main__':
num_false_alarm = 0
# transverse the all filler_table
for key, confi in keyword_filler_table[keyword][
'filler_table'].items():
for key, confi in keyword_filler_table[
keyword]['filler_table'].items():
if confi >= threshold:
num_false_alarm += 1
# print(f'false alarm: {keyword}, {key}, {confi}')

View File

@ -16,7 +16,7 @@ from __future__ import print_function
import argparse
import struct
#import wave
# import wave
import librosa
import logging
import os
@ -328,12 +328,12 @@ class KeyWordSpotter(torch.nn.Module):
wave_tensor = torch.from_numpy(wave).float().to(self.device)
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
feats = kaldi.fbank(wave_tensor,
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=0,
energy_floor=0.0,
sample_frequency=self.sample_rate)
num_mel_bins=self.num_mel_bins,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=0,
energy_floor=0.0,
sample_frequency=self.sample_rate)
# update wave remained
feat_len = len(feats)
frame_shift = int(self.frame_shift / 1000 * self.sample_rate)
@ -351,8 +351,8 @@ class KeyWordSpotter(torch.nn.Module):
else:
feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats_pad.shape[0] - \
(self.right_context+self.right_context)
ctx_frm = feats_pad.shape[0] - (
self.right_context + self.right_context)
ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
@ -362,15 +362,15 @@ class KeyWordSpotter(torch.nn.Module):
# update feature remained, and feats
self.feature_remained = \
feats[-(self.left_context+self.right_context):]
feats[-(self.left_context + self.right_context):]
feats = feats_ctx.to(self.device)
if self.downsampling > 1:
last_remainder = 0 if self.feats_ctx_offset==0 \
else self.downsampling-self.feats_ctx_offset
remainder = (feats.size(0)+last_remainder) % self.downsampling
last_remainder = 0 if self.feats_ctx_offset == 0 \
else self.downsampling - self.feats_ctx_offset
remainder = (feats.size(0) + last_remainder) % self.downsampling
feats = feats[self.feats_ctx_offset::self.downsampling, :]
self.feats_ctx_offset = remainder \
if remainder == 0 else self.downsampling-remainder
if remainder == 0 else self.downsampling - remainder
return feats
def decode_keywords(self, t, probs):
@ -419,8 +419,8 @@ class KeyWordSpotter(torch.nn.Module):
if hit_keyword is not None:
if self.hit_score >= self.threshold and \
self.min_frames <= duration <= self.max_frames \
and (self.last_active_pos==-1 or
end-self.last_active_pos >= self.interval_frames):
and (self.last_active_pos == -1 or
end - self.last_active_pos >= self.interval_frames):
self.activated = True
self.last_active_pos = end
logging.info(
@ -428,8 +428,8 @@ class KeyWordSpotter(torch.nn.Module):
f"from {start} to {end} frame. "
f"duration {duration}, score {self.hit_score}, Activated.")
elif self.last_active_pos>0 and \
end-self.last_active_pos < self.interval_frames:
elif self.last_active_pos > 0 and \
end - self.last_active_pos < self.interval_frames:
logging.info(
f"Frame {absolute_time} detect {hit_keyword} "
f"from {start} to {end} frame. "

View File

@ -216,7 +216,7 @@ class FSMNBlock(nn.Module):
self.dequant = torch.quantization.DeQuantStub()
def forward(self,
input: Tuple[torch.Tensor, torch.Tensor] ):
input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple):
input, in_cache = input
else :
@ -236,12 +236,13 @@ class FSMNBlock(nn.Module):
y_left = self.quant(y_left)
y_left = self.conv_left(y_left)
y_left = self.dequant(y_left)
out = x_pad[:, :, (self.lorder - 1) * self.lstride:
-self.rorder * self.rstride, :] + y_left
out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder *
self.rstride, :] + y_left
if self.conv_right is not None:
# y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :]
y_right = x_pad[:, :, -(
x_per.size(2) + self.rorder * self.rstride):, :]
y_right = y_right[:, :, self.rstride:, :]
y_right = self.quant(y_right)
y_right = self.conv_right(y_right)
@ -257,8 +258,8 @@ class FSMNBlock(nn.Module):
re_str = ''
re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d ' \
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride)
'<LStride> %d <RStride> %d <MaxNorm> 0\n' % (
1, self.lorder, self.rorder, self.lstride, self.rstride)
# print(self.conv_left.weight,self.conv_right.weight)
lfiters = self.state_dict()['conv_left.weight']
@ -445,8 +446,8 @@ class FSMN(nn.Module):
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
self.padding = (self.lorder-1) * self.lstride \
+ self.rorder * self.rstride
self.padding = (self.lorder - 1) * self.lstride \
+ self.rorder * self.rstride
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)