From 9b20c840faf65d1ced4b275af4d1cd9ab375c9d3 Mon Sep 17 00:00:00 2001 From: dujing Date: Mon, 24 Jul 2023 17:08:15 +0800 Subject: [PATCH] fix quickcheck and flake8 --- examples/hi_xiaowen/s0/README.md | 4 ++-- wekws/bin/compute_det_ctc.py | 6 +++--- wekws/bin/stream_kws_ctc.py | 36 ++++++++++++++++---------------- wekws/model/fsmn.py | 17 ++++++++------- 4 files changed, 32 insertions(+), 31 deletions(-) diff --git a/examples/hi_xiaowen/s0/README.md b/examples/hi_xiaowen/s0/README.md index 4f2f7d0..964b71e 100644 --- a/examples/hi_xiaowen/s0/README.md +++ b/examples/hi_xiaowen/s0/README.md @@ -39,7 +39,7 @@ FRRs with FAR fixed at once per 12 hours: Now, the DSTCN model with CTC loss may not get the best performance, because the pretraining phase is not sufficiently converged. We recommend you use pretrained -FSMN model as initial checkpoint to train your own model. +FSMN model as initial checkpoint to train your own model. Comparison Between stream_score_ctc and score_ctc. FRRs with FAR fixed at once per 12 hours: @@ -57,6 +57,6 @@ Actually the probability will increase through the time, so we record a lower va which result in a higher False Rejection Rate in Detection Error Tradeoff result. The actual FRR will be lower than the DET curve gives in a given threshold. -On some small data KWS tasks, we believe the FSMN-CTC model is more robust +On some small data KWS tasks, we believe the FSMN-CTC model is more robust compared with the classification model using CE/Max-pooling loss. For more infomation and results of FSMN-CTC KWS model, you can click [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary). \ No newline at end of file diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py index 2ff459f..9a8b5e7 100644 --- a/wekws/bin/compute_det_ctc.py +++ b/wekws/bin/compute_det_ctc.py @@ -53,7 +53,7 @@ def load_label_and_score(keywords_list, label_file, score_file, true_keywords): key = arr[0] is_detected = arr[1] if is_detected == 'detected': - keyword=true_keywords[arr[2]] + keyword = true_keywords[arr[2]] if key not in score_table: score_table.update({ key: { @@ -247,8 +247,8 @@ if __name__ == '__main__': num_false_alarm = 0 # transverse the all filler_table - for key, confi in keyword_filler_table[keyword][ - 'filler_table'].items(): + for key, confi in keyword_filler_table[ + keyword]['filler_table'].items(): if confi >= threshold: num_false_alarm += 1 # print(f'false alarm: {keyword}, {key}, {confi}') diff --git a/wekws/bin/stream_kws_ctc.py b/wekws/bin/stream_kws_ctc.py index b20a13e..07e01e4 100644 --- a/wekws/bin/stream_kws_ctc.py +++ b/wekws/bin/stream_kws_ctc.py @@ -16,7 +16,7 @@ from __future__ import print_function import argparse import struct -#import wave +# import wave import librosa import logging import os @@ -328,12 +328,12 @@ class KeyWordSpotter(torch.nn.Module): wave_tensor = torch.from_numpy(wave).float().to(self.device) wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension feats = kaldi.fbank(wave_tensor, - num_mel_bins=self.num_mel_bins, - frame_length=self.frame_length, - frame_shift=self.frame_shift, - dither=0, - energy_floor=0.0, - sample_frequency=self.sample_rate) + num_mel_bins=self.num_mel_bins, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + dither=0, + energy_floor=0.0, + sample_frequency=self.sample_rate) # update wave remained feat_len = len(feats) frame_shift = int(self.frame_shift / 1000 * self.sample_rate) @@ -351,8 +351,8 @@ class KeyWordSpotter(torch.nn.Module): else: feats_pad = torch.cat((self.feature_remained, feats)) - ctx_frm = feats_pad.shape[0] - \ - (self.right_context+self.right_context) + ctx_frm = feats_pad.shape[0] - ( + self.right_context + self.right_context) ctx_win = (self.left_context + self.right_context + 1) ctx_dim = feats.shape[1] * ctx_win feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) @@ -362,15 +362,15 @@ class KeyWordSpotter(torch.nn.Module): # update feature remained, and feats self.feature_remained = \ - feats[-(self.left_context+self.right_context):] + feats[-(self.left_context + self.right_context):] feats = feats_ctx.to(self.device) if self.downsampling > 1: - last_remainder = 0 if self.feats_ctx_offset==0 \ - else self.downsampling-self.feats_ctx_offset - remainder = (feats.size(0)+last_remainder) % self.downsampling + last_remainder = 0 if self.feats_ctx_offset == 0 \ + else self.downsampling - self.feats_ctx_offset + remainder = (feats.size(0) + last_remainder) % self.downsampling feats = feats[self.feats_ctx_offset::self.downsampling, :] self.feats_ctx_offset = remainder \ - if remainder == 0 else self.downsampling-remainder + if remainder == 0 else self.downsampling - remainder return feats def decode_keywords(self, t, probs): @@ -419,8 +419,8 @@ class KeyWordSpotter(torch.nn.Module): if hit_keyword is not None: if self.hit_score >= self.threshold and \ self.min_frames <= duration <= self.max_frames \ - and (self.last_active_pos==-1 or - end-self.last_active_pos >= self.interval_frames): + and (self.last_active_pos == -1 or + end - self.last_active_pos >= self.interval_frames): self.activated = True self.last_active_pos = end logging.info( @@ -428,8 +428,8 @@ class KeyWordSpotter(torch.nn.Module): f"from {start} to {end} frame. " f"duration {duration}, score {self.hit_score}, Activated.") - elif self.last_active_pos>0 and \ - end-self.last_active_pos < self.interval_frames: + elif self.last_active_pos > 0 and \ + end - self.last_active_pos < self.interval_frames: logging.info( f"Frame {absolute_time} detect {hit_keyword} " f"from {start} to {end} frame. " diff --git a/wekws/model/fsmn.py b/wekws/model/fsmn.py index 2b9b74d..430ed0f 100644 --- a/wekws/model/fsmn.py +++ b/wekws/model/fsmn.py @@ -216,7 +216,7 @@ class FSMNBlock(nn.Module): self.dequant = torch.quantization.DeQuantStub() def forward(self, - input: Tuple[torch.Tensor, torch.Tensor] ): + input: Tuple[torch.Tensor, torch.Tensor]): if isinstance(input, tuple): input, in_cache = input else : @@ -236,12 +236,13 @@ class FSMNBlock(nn.Module): y_left = self.quant(y_left) y_left = self.conv_left(y_left) y_left = self.dequant(y_left) - out = x_pad[:, :, (self.lorder - 1) * self.lstride: - -self.rorder * self.rstride, :] + y_left + out = x_pad[:, :, (self.lorder - 1) * self.lstride: -self.rorder * + self.rstride, :] + y_left if self.conv_right is not None: # y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) - y_right = x_pad[:, :, -(x_per.size(2)+self.rorder*self.rstride):, :] + y_right = x_pad[:, :, -( + x_per.size(2) + self.rorder * self.rstride):, :] y_right = y_right[:, :, self.rstride:, :] y_right = self.quant(y_right) y_right = self.conv_right(y_right) @@ -257,8 +258,8 @@ class FSMNBlock(nn.Module): re_str = '' re_str += ' %d %d\n' % (self.dim, self.dim) re_str += ' %d %d %d ' \ - ' %d %d 0\n' % ( - 1, self.lorder, self.rorder, self.lstride, self.rstride) + ' %d %d 0\n' % ( + 1, self.lorder, self.rorder, self.lstride, self.rstride) # print(self.conv_left.weight,self.conv_right.weight) lfiters = self.state_dict()['conv_left.weight'] @@ -445,8 +446,8 @@ class FSMN(nn.Module): self.output_affine_dim = output_affine_dim self.output_dim = output_dim - self.padding = (self.lorder-1) * self.lstride \ - + self.rorder * self.rstride + self.padding = (self.lorder - 1) * self.lstride \ + + self.rorder * self.rstride self.in_linear1 = AffineTransform(input_dim, input_affine_dim) self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)