Fix repeated activation; add an interval restriction.
This commit is contained in:
parent
ceb59e87a6
commit
c26ec9e4d0
@ -252,6 +252,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
self.activated = False
|
||||
|
||||
self.total_frames = 0 # frame offset, for absolute time
|
||||
self.last_active_pos = -1 # the last frame of being activated
|
||||
self.result = {}
|
||||
|
||||
def set_keywords(self, keywords):
|
||||
@ -378,16 +379,25 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
|
||||
duration = end - start
|
||||
if hit_keyword is not None:
|
||||
if self.hit_score >= self.threshold and self.min_frames <= duration <= self.max_frames:
|
||||
if self.hit_score >= self.threshold and \
|
||||
self.min_frames <= duration <= self.max_frames \
|
||||
and (self.last_active_pos==-1 or end-self.last_active_pos >= self.interval_frames):
|
||||
self.activated = True
|
||||
self.last_active_pos = end
|
||||
logging.info(
|
||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
||||
f"duration {duration}, score {self.hit_score} Activated.")
|
||||
f"duration {duration}, score {self.hit_score}, Activated.")
|
||||
|
||||
elif self.last_active_pos>0 and end-self.last_active_pos < self.interval_frames:
|
||||
logging.info(
|
||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
||||
f"but interval {end-self.last_active_pos} is lower than {self.interval_frames}, Deactivated. ")
|
||||
|
||||
elif self.hit_score < self.threshold:
|
||||
logging.info(
|
||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
||||
f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ")
|
||||
|
||||
elif self.min_frames > duration or duration > self.max_frames:
|
||||
logging.info(
|
||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
||||
@ -404,7 +414,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
def forward(self, wave_chunk):
|
||||
feature = self.accept_wave(wave_chunk)
|
||||
if feature is None or feature.size(0) < 1:
|
||||
return self.result
|
||||
return {} # # the feature is not enough to get result.
|
||||
feature = feature.unsqueeze(0) # add a batch dimension
|
||||
logits, self.in_cache = self.model(feature, self.in_cache)
|
||||
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||
@ -434,6 +444,7 @@ class KeyWordSpotter(torch.nn.Module):
|
||||
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
||||
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
self.total_frames = 0 # frame offset, for absolute time
|
||||
self.last_active_pos = -1 # the last frame of being activated
|
||||
self.result = {}
|
||||
|
||||
def demo():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user