diff --git a/examples/hi_xiaowen/s0/run_ctc.sh b/examples/hi_xiaowen/s0/run_ctc.sh index 8f2d694..7cc1daf 100644 --- a/examples/hi_xiaowen/s0/run_ctc.sh +++ b/examples/hi_xiaowen/s0/run_ctc.sh @@ -203,7 +203,9 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --test_data data/test/data.list \ --window_shift $window_shift \ --step 0.001 \ - --score_file $result_dir/score.txt + --score_file $result_dir/score.txt \ + --token_file data/tokens.txt \ + --lexicon_file data/lexicon.txt fi diff --git a/examples/hi_xiaowen/s0/run_fsmn_ctc.sh b/examples/hi_xiaowen/s0/run_fsmn_ctc.sh index 79774a9..743e349 100644 --- a/examples/hi_xiaowen/s0/run_fsmn_ctc.sh +++ b/examples/hi_xiaowen/s0/run_fsmn_ctc.sh @@ -154,17 +154,21 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then --test_data data/test/data.list \ --window_shift $window_shift \ --step 0.001 \ - --score_file $result_dir/score.txt + --score_file $result_dir/score.txt \ + --token_file data/tokens.txt \ + --lexicon_file data/lexicon.txt fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') +# NOTE: FSMN now is not support export to jit, beacuse of nn.Sequential with tuple input +# This issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450 +# jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') +# python wekws/bin/export_jit.py \ +# --config $dir/config.yaml \ +# --checkpoint $score_checkpoint \ +# --jit_model $dir/$jit_model onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') - python wekws/bin/export_jit.py \ - --config $dir/config.yaml \ - --checkpoint $score_checkpoint \ - --jit_model $dir/$jit_model python wekws/bin/export_onnx.py \ --config $dir/config.yaml \ --checkpoint $score_checkpoint \ diff --git a/wekws/bin/compute_det_ctc.py b/wekws/bin/compute_det_ctc.py index 14f3d86..2c31c7f 100644 --- a/wekws/bin/compute_det_ctc.py +++ b/wekws/bin/compute_det_ctc.py @@ -23,6 +23,7 @@ import os import numpy as np import matplotlib.pyplot as plt import pypinyin # for Chinese Character +from tools.make_list import query_token_set, read_lexicon, read_token def split_mixed_label(input_str): tokens = [] @@ -43,7 +44,7 @@ def space_mixed_label(input_str): space_str = ''.join(f'{sub} ' for sub in splits) return space_str.strip() -def load_label_and_score(keywords_list, label_file, score_file): +def load_label_and_score(keywords_list, label_file, score_file, true_keywords): score_table = {} with open(score_file, 'r', encoding='utf8') as fin: # read score file and store in table @@ -52,10 +53,11 @@ def load_label_and_score(keywords_list, label_file, score_file): key = arr[0] is_detected = arr[1] if is_detected == 'detected': + keyword=true_keywords[arr[2]] if key not in score_table: score_table.update({ key: { - 'kw': space_mixed_label(arr[2]), + 'kw': space_mixed_label(keyword), 'confi': float(arr[3]) } }) @@ -72,6 +74,7 @@ def load_label_and_score(keywords_list, label_file, score_file): # build empty structure for keyword-filler infos keyword_filler_table = {} for keyword in keywords_list: + keyword = true_keywords[keyword] keyword = space_mixed_label(keyword) keyword_filler_table[keyword] = {} keyword_filler_table[keyword]['keyword_table'] = {} @@ -93,6 +96,7 @@ def load_label_and_score(keywords_list, label_file, score_file): assert key in score_table for keyword in keywords_list: + keyword = true_keywords[keyword] keyword = space_mixed_label(keyword) keyword_regstr_lrblk = ' ' + keyword + ' ' if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1: @@ -154,6 +158,8 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description='compute det curve') parser.add_argument('--test_data', required=True, help='label file') parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)') + parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt') + parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt') parser.add_argument('--score_file', required=True, help='score file') parser.add_argument('--step', type=float, default=0.01, help='threshold step') @@ -183,9 +189,18 @@ if __name__ == '__main__': args = parser.parse_args() window_shift = args.window_shift keywords_list = args.keywords.strip().split(',') - keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file) + + token_table = read_token(args.token_file) + lexicon_table = read_lexicon(args.lexicon_file) + true_keywords = {} + for keyword in keywords_list: + strs, indexes = query_token_set(keyword, token_table, lexicon_table) + true_keywords[keyword] = ''.join(strs) + + keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file, true_keywords) for keyword in keywords_list: + keyword = true_keywords[keyword] keyword = space_mixed_label(keyword) keyword_dur = keyword_filler_table[keyword]['keyword_duration'] keyword_num = len(keyword_filler_table[keyword]['keyword_table']) diff --git a/wekws/bin/stream_kws_ctc.py b/wekws/bin/stream_kws_ctc.py index 869b880..e9fe32a 100644 --- a/wekws/bin/stream_kws_ctc.py +++ b/wekws/bin/stream_kws_ctc.py @@ -35,7 +35,10 @@ from tools.make_list import query_token_set, read_lexicon, read_token def get_args(): parser = argparse.ArgumentParser(description='detect keywords online.') parser.add_argument('--config', required=True, help='config file') - parser.add_argument('--wav_path', required=True, help='test wave path.') + parser.add_argument('--wav_path', required=False, default=None, help='test wave path.') + parser.add_argument('--wav_scp', required=False, default=None, help='test wave scp.') + parser.add_argument('--result_file', required=False, default=None, help='test result.') + parser.add_argument('--gpu', type=int, default=-1, @@ -209,7 +212,7 @@ class KeyWordSpotter(torch.nn.Module): self.left_context = dataset_conf['context_expansion_conf']['left'] self.right_context = dataset_conf['context_expansion_conf']['right'] self.feature_remained = None - self.feature_context_offset = 0 # after downsample, offset exist. + self.feats_ctx_offset = 0 # after downsample, offset exist. # model related @@ -289,6 +292,9 @@ class KeyWordSpotter(torch.nn.Module): wave = np.array(data) wave = np.append(self.wave_remained, wave) + if wave.size < (self.frame_length * self.sample_rate / 1000) * self.right_context : + self.wave_remained = wave + return None wave_tensor = torch.from_numpy(wave).float().to(self.device) wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension feats = kaldi.fbank(wave_tensor, @@ -312,7 +318,7 @@ class KeyWordSpotter(torch.nn.Module): else: feats_pad = torch.cat((self.feature_remained, feats)) - ctx_frm = feats.shape[0] - self.right_context + ctx_frm = feats_pad.shape[0] - (self.right_context+self.right_context) ctx_win = (self.left_context + self.right_context + 1) ctx_dim = feats.shape[1] * ctx_win feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) @@ -320,12 +326,13 @@ class KeyWordSpotter(torch.nn.Module): feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0) # update feature remained, and feats - self.feature_remained = feats[-self.left_context:] + self.feature_remained = feats[-(self.left_context+self.right_context):] feats = feats_ctx.to(self.device) if self.downsampling > 1: - feats = feats[self.feature_context_offset::self.downsampling, :] - complement = feats.size(1) % self.downsampling - self.feature_context_offset = complement if complement == 0 else self.downsampling-complement + last_remainder = 0 if self.feats_ctx_offset==0 else self.downsampling-self.feats_ctx_offset + remainder = (feats.size(0)+last_remainder) % self.downsampling + feats = feats[self.feats_ctx_offset::self.downsampling, :] + self.feats_ctx_offset = remainder if remainder == 0 else self.downsampling-remainder return feats def decode_keywords(self, t, probs): @@ -396,11 +403,14 @@ class KeyWordSpotter(torch.nn.Module): def forward(self, wave_chunk): feature = self.accept_wave(wave_chunk) + if feature is None or feature.size(0) < 1: + return self.result feature = feature.unsqueeze(0) # add a batch dimension logits, self.in_cache = self.model(feature, self.in_cache) probs = logits.softmax(2) # (batch_size, maxlen, vocab_size) probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search for (t, prob) in enumerate(probs): + t *= self.downsampling self.decode_keywords(t, prob) self.execute_detection(t) @@ -409,7 +419,7 @@ class KeyWordSpotter(torch.nn.Module): # since a chunk include about 30 frames, once activated, we can jump the latter frames. # TODO: there should give another method to update result, avoiding self.result being cleared. break - self.total_frames += len(probs) # update frame offset + self.total_frames += len(probs) * self.downsampling # update frame offset return self.result def reset(self): @@ -417,6 +427,15 @@ class KeyWordSpotter(torch.nn.Module): self.activated = False self.hit_score = 1.0 + def reset_all(self): + self.reset() + self.wave_remained = np.array([]) + self.feature_remained = None + self.feats_ctx_offset = 0 # after downsample, offset exist. + self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float) + self.total_frames = 0 # frame offset, for absolute time + self.result = {} + def demo(): args = get_args() logging.basicConfig(level=logging.DEBUG, @@ -435,22 +454,60 @@ def demo(): args.gpu, args.jit_model) - # actually this could be done in __init__ method, we pull it outside for changing keywords. + # actually this could be done in __init__ method, we pull it outside for changing keywords more freely. kws.set_keywords(args.keywords) - # Caution: input WAV should be standard 16k, 16 bits, 1 channel - # In demo we read wave in non-streaming fashion. - with wave.open(args.wav_path, 'rb') as fin: - assert fin.getnchannels() == 1 - wav = fin.readframes(fin.getnframes()) + if args.wav_path: + # Caution: input WAV should be standard 16k, 16 bits, 1 channel + # In demo we read wave in non-streaming fashion. + with wave.open(args.wav_path, 'rb') as fin: + assert fin.getnchannels() == 1 + wav = fin.readframes(fin.getnframes()) - # We inference every 0.3 seconds, in streaming fashion. - interval = int(0.3 * 16000) * 2 - for i in range(0, len(wav), interval): - chunk_wav = wav[i: min(i + interval, len(wav))] - result = kws.forward(chunk_wav) - print(result) + # We inference every 0.3 seconds, in streaming fashion. + interval = int(0.3 * 16000) * 2 + for i in range(0, len(wav), interval): + chunk_wav = wav[i: min(i + interval, len(wav))] + result = kws.forward(chunk_wav) + print(result) + fout = None + if args.result_file: + fout = open(args.result_file, 'w', encoding='utf-8') + + if args.wav_scp: + with open(args.wav_scp, 'r') as fscp: + for line in fscp: + line = line.strip().split() + assert len(line) == 2, f"The scp should be in kaldi format: \"utt_name wav_path\", but got {line}" + + utt_name, wav_path = line[0], line[1] + with wave.open(wav_path, 'rb') as fin: + assert fin.getnchannels() == 1 + wav = fin.readframes(fin.getnframes()) + + kws.reset_all() + activated = False + + # We inference every 0.3 seconds, in streaming fashion. + interval = int(0.3 * 16000) * 2 + for i in range(0, len(wav), interval): + chunk_wav = wav[i: min(i + interval, len(wav))] + result = kws.forward(chunk_wav) + if 'state' in result and result['state'] == 1: + activated = True + if fout: + hit_keyword = result['keyword'] + hit_score = result['score'] + fout.write('{} detected {} {:.3f}\n'.format(utt_name, hit_keyword, hit_score)) + + if not activated: + if fout: + fout.write('{} rejected\n'.format(utt_name)) + + + if fout: + fout.close() if __name__ == '__main__': demo() diff --git a/wekws/bin/stream_score_ctc.py b/wekws/bin/stream_score_ctc.py index 954f8eb..fab510f 100644 --- a/wekws/bin/stream_score_ctc.py +++ b/wekws/bin/stream_score_ctc.py @@ -44,11 +44,11 @@ def get_args(): help='gpu id for this rank, -1 for cpu') parser.add_argument('--checkpoint', required=True, help='checkpoint model') parser.add_argument('--batch_size', - default=16, + default=1, type=int, help='batch size for inference') parser.add_argument('--num_workers', - default=0, + default=1, type=int, help='num of subprocess workers for reading') parser.add_argument('--pin_memory', @@ -131,6 +131,8 @@ def main(): test_conf['feature_extraction_conf']['dither'] = 0.0 test_conf['batch_conf']['batch_size'] = args.batch_size + downsampling_factor = test_conf.get('frame_skip', 1) + test_dataset = Dataset(args.test_data, test_conf) test_data_loader = DataLoader(test_dataset, batch_size=None, @@ -205,6 +207,7 @@ def main(): # 2. CTC beam search step by step for t in range(0, maxlen): probs = ctc_probs[t] # (vocab_size,) + t *= downsampling_factor # the real time # key: prefix, value (pb, pnb), default value(-inf, -inf) next_hyps = defaultdict(lambda: (0.0, 0.0, [])) diff --git a/wekws/bin/train.py b/wekws/bin/train.py index 5c5c933..b81ad9f 100644 --- a/wekws/bin/train.py +++ b/wekws/bin/train.py @@ -157,8 +157,12 @@ def main(): # Try to export the model by script, if fails, we should refine # the code to satisfy the script export requirements if rank == 0: - script_model = torch.jit.script(model) - script_model.save(os.path.join(args.model_dir, 'init.zip')) + pass + # TODO: for now streaming FSMN do not support export to JITScript, + # TODO: because there is nn.Sequential with Tuple input in current FSMN modules. + # the issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450 + # script_model = torch.jit.script(model) + # script_model.save(os.path.join(args.model_dir, 'init.zip')) executor = Executor() # If specify checkpoint, load some info from checkpoint if args.checkpoint is not None: diff --git a/wekws/dataset/processor.py b/wekws/dataset/processor.py index ca8a9db..bb1dde7 100644 --- a/wekws/dataset/processor.py +++ b/wekws/dataset/processor.py @@ -377,6 +377,8 @@ def add_reverb(data, reverb_source, aug_prob): rir_io = io.BytesIO(rir_data) _, rir_audio = wavfile.read(rir_io) rir_audio = rir_audio.astype(np.float32) + if len(rir_audio.shape) > 1: + rir_audio = rir_audio[:, 0] rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2)) out_audio = signal.convolve(audio, rir_audio, mode='full')[:audio_len] @@ -405,6 +407,8 @@ def add_noise(data, noise_source, aug_prob): snr_range = [0, 15] _, noise_audio = wavfile.read(io.BytesIO(noise_data)) noise_audio = noise_audio.astype(np.float32) + if len(noise_audio.shape) > 1: + noise_audio = noise_audio[:, 0] if noise_audio.shape[0] > audio_len: start = random.randint(0, noise_audio.shape[0] - audio_len) noise_audio = noise_audio[start:start + audio_len] diff --git a/wekws/model/fsmn.py b/wekws/model/fsmn.py index 903d456..86ce41f 100644 --- a/wekws/model/fsmn.py +++ b/wekws/model/fsmn.py @@ -39,11 +39,12 @@ class LinearTransform(nn.Module): self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, input): + def forward(self, + input: Tuple[torch.Tensor, torch.Tensor]): if isinstance(input, tuple): input, in_cache = input else: - in_cache = None + in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) output = self.quant(input) output = self.linear(output) output = self.dequant(output) @@ -102,11 +103,12 @@ class AffineTransform(nn.Module): self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, input): + def forward(self, + input: Tuple[torch.Tensor, torch.Tensor]): if isinstance(input, tuple): input, in_cache = input else: - in_cache = None + in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) output = self.quant(input) output = self.linear(output) output = self.dequant(output) @@ -213,15 +215,16 @@ class FSMNBlock(nn.Module): self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() - def forward(self, input): + def forward(self, + input: Tuple[torch.Tensor, torch.Tensor] ): if isinstance(input, tuple): input, in_cache = input else : - in_cache = None + in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) x = torch.unsqueeze(input, 1) x_per = x.permute(0, 3, 2, 1) - if in_cache is None or len(in_cache) == 0 or in_cache[0] is None: + if in_cache is None or len(in_cache) == 0 : x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0]) else: in_cache = in_cache.to(x_per.device) @@ -339,11 +342,12 @@ class RectifiedLinear(nn.Module): self.relu = nn.ReLU() self.dropout = nn.Dropout(0.1) - def forward(self, input): + def forward(self, + input: Tuple[torch.Tensor, torch.Tensor]): if isinstance(input, tuple): input, in_cache = input else : - in_cache = None + in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float) out = self.relu(input) # out = self.dropout(out) return (out, in_cache) @@ -456,7 +460,7 @@ class FSMN(nn.Module): def forward( self, input: torch.Tensor, - in_cache #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) + in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -464,8 +468,8 @@ class FSMN(nn.Module): in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size """ - if in_cache is None or len(in_cache) == 0 or in_cache[0] == None: - in_cache = [None for _ in range(len(self.fsmn))] + if in_cache is None or len(in_cache) == 0 : + in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) for _ in range(len(self.fsmn))] input = (input, in_cache) x1 = self.in_linear1(input) x2 = self.in_linear2(x1)