fix lint
This commit is contained in:
parent
1628f8c221
commit
515222c83e
@ -35,6 +35,7 @@ model:
|
||||
stack_size: 4
|
||||
kernel_size: 5
|
||||
hidden_dim: 64
|
||||
causal: true
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
|
||||
@ -30,6 +30,7 @@ model:
|
||||
preprocessing:
|
||||
type: none
|
||||
backbone:
|
||||
causal: true
|
||||
type: mdtc
|
||||
num_stack: 3
|
||||
stack_size: 4
|
||||
|
||||
@ -41,6 +41,6 @@ optim_conf:
|
||||
|
||||
training_config:
|
||||
grad_clip: 5
|
||||
max_epoch: 80
|
||||
max_epoch: 2
|
||||
log_interval: 10
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ class KWSModel(nn.Module):
|
||||
return x
|
||||
else:
|
||||
stride = self.streaming_chunk
|
||||
num_frames = x.size(1) #(B, T, D)
|
||||
num_frames = x.size(1)
|
||||
cache: Optional[torch.Tensor] = None
|
||||
out = []
|
||||
for cur in range(0, num_frames, stride):
|
||||
@ -180,5 +180,6 @@ def init_model(configs):
|
||||
activation = nn.Sigmoid()
|
||||
|
||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||
preprocessing, backbone, classifier, activation, streaming_chunk)
|
||||
preprocessing, backbone, classifier, activation,
|
||||
streaming_chunk)
|
||||
return kws_model
|
||||
|
||||
@ -143,7 +143,7 @@ class TCN(nn.Module):
|
||||
out_caches.append(c)
|
||||
x = x.transpose(1, 2) # (B, T, D)
|
||||
new_cache = torch.cat(out_caches, dim=2)
|
||||
return x, new_cache
|
||||
return x, new_cache
|
||||
|
||||
def forward_chunk(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
@ -160,11 +160,11 @@ class TCN(nn.Module):
|
||||
if cache is None:
|
||||
cache = torch.zeros(x.size(0), x.size(1), self.padding)
|
||||
for block in self.network:
|
||||
cur_padding += block.padding
|
||||
cur_padding += block.padding
|
||||
x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding])
|
||||
out_caches.append(c) # c (B, D, T)
|
||||
out_caches.append(c)
|
||||
x = x.transpose(1, 2) # (B, T, D)
|
||||
new_cache = torch.cat(out_caches, dim=2)
|
||||
new_cache = torch.cat(out_caches, dim=2)
|
||||
return x, new_cache
|
||||
|
||||
def fuse_modules(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user