fix repeat activation, add a interval restrict.

This commit is contained in:
dujing 2023-06-27 18:07:37 +08:00
parent ceb59e87a6
commit c26ec9e4d0

View File

@ -252,6 +252,7 @@ class KeyWordSpotter(torch.nn.Module):
self.activated = False self.activated = False
self.total_frames = 0 # frame offset, for absolute time self.total_frames = 0 # frame offset, for absolute time
self.last_active_pos = -1 # the last frame of being activated
self.result = {} self.result = {}
def set_keywords(self, keywords): def set_keywords(self, keywords):
@ -378,16 +379,25 @@ class KeyWordSpotter(torch.nn.Module):
duration = end - start duration = end - start
if hit_keyword is not None: if hit_keyword is not None:
if self.hit_score >= self.threshold and self.min_frames <= duration <= self.max_frames: if self.hit_score >= self.threshold and \
self.min_frames <= duration <= self.max_frames \
and (self.last_active_pos==-1 or end-self.last_active_pos >= self.interval_frames):
self.activated = True self.activated = True
self.last_active_pos = end
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
f"duration {duration}, score {self.hit_score} Activated.") f"duration {duration}, score {self.hit_score}, Activated.")
elif self.last_active_pos>0 and end-self.last_active_pos < self.interval_frames:
logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
f"but interval {end-self.last_active_pos} is lower than {self.interval_frames}, Deactivated. ")
elif self.hit_score < self.threshold: elif self.hit_score < self.threshold:
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ") f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ")
elif self.min_frames > duration or duration > self.max_frames: elif self.min_frames > duration or duration > self.max_frames:
logging.info( logging.info(
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. " f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
@ -404,7 +414,7 @@ class KeyWordSpotter(torch.nn.Module):
def forward(self, wave_chunk): def forward(self, wave_chunk):
feature = self.accept_wave(wave_chunk) feature = self.accept_wave(wave_chunk)
if feature is None or feature.size(0) < 1: if feature is None or feature.size(0) < 1:
return self.result return {} # # the feature is not enough to get result.
feature = feature.unsqueeze(0) # add a batch dimension feature = feature.unsqueeze(0) # add a batch dimension
logits, self.in_cache = self.model(feature, self.in_cache) logits, self.in_cache = self.model(feature, self.in_cache)
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size) probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
@ -434,6 +444,7 @@ class KeyWordSpotter(torch.nn.Module):
self.feats_ctx_offset = 0 # after downsample, offset exist. self.feats_ctx_offset = 0 # after downsample, offset exist.
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float) self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
self.total_frames = 0 # frame offset, for absolute time self.total_frames = 0 # frame offset, for absolute time
self.last_active_pos = -1 # the last frame of being activated
self.result = {} self.result = {}
def demo(): def demo():