This commit is contained in:
lvanchao1 2022-04-01 19:28:02 +08:00
parent 1628f8c221
commit 515222c83e
5 changed files with 10 additions and 7 deletions

View File

@ -35,6 +35,7 @@ model:
stack_size: 4
kernel_size: 5
hidden_dim: 64
causal: true
optim: adam
optim_conf:

View File

@ -30,6 +30,7 @@ model:
preprocessing:
type: none
backbone:
causal: true
type: mdtc
num_stack: 3
stack_size: 4

View File

@ -41,6 +41,6 @@ optim_conf:
training_config:
grad_clip: 5
max_epoch: 80
max_epoch: 2
log_interval: 10

View File

@ -70,7 +70,7 @@ class KWSModel(nn.Module):
return x
else:
stride = self.streaming_chunk
num_frames = x.size(1) #(B, T, D)
num_frames = x.size(1)
cache: Optional[torch.Tensor] = None
out = []
for cur in range(0, num_frames, stride):
@ -180,5 +180,6 @@ def init_model(configs):
activation = nn.Sigmoid()
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone, classifier, activation, streaming_chunk)
preprocessing, backbone, classifier, activation,
streaming_chunk)
return kws_model

View File

@ -143,7 +143,7 @@ class TCN(nn.Module):
out_caches.append(c)
x = x.transpose(1, 2) # (B, T, D)
new_cache = torch.cat(out_caches, dim=2)
return x, new_cache
return x, new_cache
def forward_chunk(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
"""
@ -160,11 +160,11 @@ class TCN(nn.Module):
if cache is None:
cache = torch.zeros(x.size(0), x.size(1), self.padding)
for block in self.network:
cur_padding += block.padding
cur_padding += block.padding
x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding])
out_caches.append(c) # c (B, D, T)
out_caches.append(c)
x = x.transpose(1, 2) # (B, T, D)
new_cache = torch.cat(out_caches, dim=2)
new_cache = torch.cat(out_caches, dim=2)
return x, new_cache
def fuse_modules(self):