Compare commits

...

2 Commits

Author SHA1 Message Date
Jean Du
059fd87a8f
[ctc]fix false rejection result from long time silence. (#149)
* [ctc]fix false rejection result from long time silence.

* fix list index out of range.

---------

Co-authored-by: dujing <dujing@xmov.ai>
2023-10-09 17:34:36 +08:00
Di Wu
58859d580d
[wekws] fix log (only one process print model) (#150)
Co-authored-by: di.wu <di.wu@diwudeMacBook-Pro.local>
2023-10-09 17:33:52 +08:00
2 changed files with 12 additions and 4 deletions

View File

@ -481,6 +481,15 @@ class KeyWordSpotter(torch.nn.Module):
# update frame offset
self.total_frames += len(probs) * self.downsampling
# For streaming kws, the cur_hyps should be reset if the time of
# a possible keyword last over the max_frames value you set.
# see this issue:https://github.com/duj12/kws_demo/issues/2
if len(self.cur_hyps) > 0 and len(self.cur_hyps[0][0]) > 0:
keyword_may_start = int(self.cur_hyps[0][1][2][0]['frame'])
if (self.total_frames - keyword_may_start) > self.max_frames:
self.reset()
return self.result
def reset(self):

View File

@ -141,15 +141,14 @@ def main():
configs['model']['cmvn'] = {}
configs['model']['cmvn']['norm_var'] = args.norm_var
configs['model']['cmvn']['cmvn_file'] = args.cmvn_file
# Init asr model from configs
model = init_model(configs['model'])
if rank == 0:
saved_config_path = os.path.join(args.model_dir, 'config.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs)
fout.write(data)
# Init asr model from configs
model = init_model(configs['model'])
print(model)
print(model)
num_params = count_parameters(model)
print('the number of model params: {}'.format(num_params))