Fix repeated activation; add an interval restriction.
This commit is contained in:
parent
ceb59e87a6
commit
c26ec9e4d0
@ -252,6 +252,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.activated = False
|
self.activated = False
|
||||||
|
|
||||||
self.total_frames = 0 # frame offset, for absolute time
|
self.total_frames = 0 # frame offset, for absolute time
|
||||||
|
self.last_active_pos = -1 # the last frame of being activated
|
||||||
self.result = {}
|
self.result = {}
|
||||||
|
|
||||||
def set_keywords(self, keywords):
|
def set_keywords(self, keywords):
|
||||||
@ -378,16 +379,25 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
|
|
||||||
duration = end - start
|
duration = end - start
|
||||||
if hit_keyword is not None:
|
if hit_keyword is not None:
|
||||||
if self.hit_score >= self.threshold and self.min_frames <= duration <= self.max_frames:
|
if self.hit_score >= self.threshold and \
|
||||||
|
self.min_frames <= duration <= self.max_frames \
|
||||||
|
and (self.last_active_pos==-1 or end-self.last_active_pos >= self.interval_frames):
|
||||||
self.activated = True
|
self.activated = True
|
||||||
|
self.last_active_pos = end
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
||||||
f"duration {duration}, score {self.hit_score} Activated.")
|
f"duration {duration}, score {self.hit_score}, Activated.")
|
||||||
|
|
||||||
|
elif self.last_active_pos>0 and end-self.last_active_pos < self.interval_frames:
|
||||||
|
logging.info(
|
||||||
|
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
||||||
|
f"but interval {end-self.last_active_pos} is lower than {self.interval_frames}, Deactivated. ")
|
||||||
|
|
||||||
elif self.hit_score < self.threshold:
|
elif self.hit_score < self.threshold:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
||||||
f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ")
|
f"but {self.hit_score} is lower than {self.threshold}, Deactivated. ")
|
||||||
|
|
||||||
elif self.min_frames > duration or duration > self.max_frames:
|
elif self.min_frames > duration or duration > self.max_frames:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
f"Frame {absolute_time} detect {hit_keyword} from {start} to {end} frame. "
|
||||||
@ -404,7 +414,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
def forward(self, wave_chunk):
|
def forward(self, wave_chunk):
|
||||||
feature = self.accept_wave(wave_chunk)
|
feature = self.accept_wave(wave_chunk)
|
||||||
if feature is None or feature.size(0) < 1:
|
if feature is None or feature.size(0) < 1:
|
||||||
return self.result
|
return {} # # the feature is not enough to get result.
|
||||||
feature = feature.unsqueeze(0) # add a batch dimension
|
feature = feature.unsqueeze(0) # add a batch dimension
|
||||||
logits, self.in_cache = self.model(feature, self.in_cache)
|
logits, self.in_cache = self.model(feature, self.in_cache)
|
||||||
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
|
||||||
@ -434,6 +444,7 @@ class KeyWordSpotter(torch.nn.Module):
|
|||||||
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
self.feats_ctx_offset = 0 # after downsample, offset exist.
|
||||||
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||||
self.total_frames = 0 # frame offset, for absolute time
|
self.total_frames = 0 # frame offset, for absolute time
|
||||||
|
self.last_active_pos = -1 # the last frame of being activated
|
||||||
self.result = {}
|
self.result = {}
|
||||||
|
|
||||||
def demo():
|
def demo():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user