This commit is contained in:
lvanchao1 2022-04-01 19:28:02 +08:00
parent 1628f8c221
commit 515222c83e
5 changed files with 10 additions and 7 deletions

View File

@ -35,6 +35,7 @@ model:
stack_size: 4 stack_size: 4
kernel_size: 5 kernel_size: 5
hidden_dim: 64 hidden_dim: 64
causal: true
optim: adam optim: adam
optim_conf: optim_conf:

View File

@ -30,6 +30,7 @@ model:
preprocessing: preprocessing:
type: none type: none
backbone: backbone:
causal: true
type: mdtc type: mdtc
num_stack: 3 num_stack: 3
stack_size: 4 stack_size: 4

View File

@ -41,6 +41,6 @@ optim_conf:
training_config: training_config:
grad_clip: 5 grad_clip: 5
max_epoch: 80 max_epoch: 2
log_interval: 10 log_interval: 10

View File

@ -70,7 +70,7 @@ class KWSModel(nn.Module):
return x return x
else: else:
stride = self.streaming_chunk stride = self.streaming_chunk
num_frames = x.size(1) #(B, T, D) num_frames = x.size(1)
cache: Optional[torch.Tensor] = None cache: Optional[torch.Tensor] = None
out = [] out = []
for cur in range(0, num_frames, stride): for cur in range(0, num_frames, stride):
@ -180,5 +180,6 @@ def init_model(configs):
activation = nn.Sigmoid() activation = nn.Sigmoid()
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone, classifier, activation, streaming_chunk) preprocessing, backbone, classifier, activation,
streaming_chunk)
return kws_model return kws_model

View File

@ -162,7 +162,7 @@ class TCN(nn.Module):
for block in self.network: for block in self.network:
cur_padding += block.padding cur_padding += block.padding
x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding]) x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding])
out_caches.append(c) # c (B, D, T) out_caches.append(c)
x = x.transpose(1, 2) # (B, T, D) x = x.transpose(1, 2) # (B, T, D)
new_cache = torch.cat(out_caches, dim=2) new_cache = torch.cat(out_caches, dim=2)
return x, new_cache return x, new_cache