From c26ec9e4d00a02ce33d028bfd507527091d097d9 Mon Sep 17 00:00:00 2001 From: dujing Date: Tue, 27 Jun 2023 18:07:37 +0800 Subject: [PATCH] fix repeat activation, add a interval restrict. --- wekws/bin/stream_kws_ctc.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/wekws/bin/stream_kws_ctc.py b/wekws/bin/stream_kws_ctc.py index e9fe32a..f609b72 100644 --- a/wekws/bin/stream_kws_ctc.py +++ b/wekws/bin/stream_kws_ctc.py @@ -252,6 +252,7 @@ class KeyWordSpotter(torch.nn.Module): self.activated = False self.total_frames = 0 # frame offset, for absolute time + self.last_active_pos = -1 # the last frame of being activated self.result = {} def set_keywords(self, keywords): @@ -378,16 +379,25 @@ class KeyWordSpotter(torch.nn.Module): duration = end - start if hit_keyword is not None: - if self.hit_score >= self.threshold and self.min_frames <= duration <= self.max_frames: + if self.hit_score >= self.threshold and \ + self.min_frames <= duration <= self.max_frames \ + and (self.last_active_pos==-1 or end-self.last_active_pos >= self.interval_frames): self.activated = True + self.last_active_pos = end logging.info( f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " - f"duration {duration}, score {self.hit_score} Activated.") + f"duration {duration}, score {self.hit_score}, Activated.") + + elif self.last_active_pos>0 and end-self.last_active_pos < self.interval_frames: + logging.info( + f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " + f"but interval {end-self.last_active_pos} is lower than {self.interval_frames}, Deactivated. ") elif self.hit_score < self.threshold: logging.info( f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ") + elif self.min_frames > duration or duration > self.max_frames: logging.info( f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " @@ -404,7 +414,7 @@ class KeyWordSpotter(torch.nn.Module): def forward(self, wave_chunk): feature = self.accept_wave(wave_chunk) if feature is None or feature.size(0) < 1: - return self.result + return {} # # the feature is not enough to get result. feature = feature.unsqueeze(0) # add a batch dimension logits, self.in_cache = self.model(feature, self.in_cache) probs = logits.softmax(2) # (batch_size, maxlen, vocab_size) @@ -434,6 +444,7 @@ class KeyWordSpotter(torch.nn.Module): self.feats_ctx_offset = 0 # after downsample, offset exist. self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float) self.total_frames = 0 # frame offset, for absolute time + self.last_active_pos = -1 # the last frame of being activated self.result = {} def demo():