Compare commits
2 Commits
diwu-wekws
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
059fd87a8f | ||
|
|
58859d580d |
@ -481,6 +481,15 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
|
|
||||||
# update frame offset
|
# update frame offset
|
||||||
self.total_frames += len(probs) * self.downsampling
|
self.total_frames += len(probs) * self.downsampling
|
||||||
|
|
||||||
|
# For streaming kws, the cur_hyps should be reset if the time of
|
||||||
|
# a possible keyword last over the max_frames value you set.
|
||||||
|
# see this issue:https://github.com/duj12/kws_demo/issues/2
|
||||||
|
if len(self.cur_hyps) > 0 and len(self.cur_hyps[0][0]) > 0:
|
||||||
|
keyword_may_start = int(self.cur_hyps[0][1][2][0]['frame'])
|
||||||
|
if (self.total_frames - keyword_may_start) > self.max_frames:
|
||||||
|
self.reset()
|
||||||
|
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
|||||||
@ -141,14 +141,13 @@ def main():
|
|||||||
configs['model']['cmvn'] = {}
|
configs['model']['cmvn'] = {}
|
||||||
configs['model']['cmvn']['norm_var'] = args.norm_var
|
configs['model']['cmvn']['norm_var'] = args.norm_var
|
||||||
configs['model']['cmvn']['cmvn_file'] = args.cmvn_file
|
configs['model']['cmvn']['cmvn_file'] = args.cmvn_file
|
||||||
|
# Init asr model from configs
|
||||||
|
model = init_model(configs['model'])
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
saved_config_path = os.path.join(args.model_dir, 'config.yaml')
|
saved_config_path = os.path.join(args.model_dir, 'config.yaml')
|
||||||
with open(saved_config_path, 'w') as fout:
|
with open(saved_config_path, 'w') as fout:
|
||||||
data = yaml.dump(configs)
|
data = yaml.dump(configs)
|
||||||
fout.write(data)
|
fout.write(data)
|
||||||
|
|
||||||
# Init asr model from configs
|
|
||||||
model = init_model(configs['model'])
|
|
||||||
print(model)
|
print(model)
|
||||||
num_params = count_parameters(model)
|
num_params = count_parameters(model)
|
||||||
print('the number of model params: {}'.format(num_params))
|
print('the number of model params: {}'.format(num_params))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user