fix lint
This commit is contained in:
parent
1628f8c221
commit
515222c83e
@ -35,6 +35,7 @@ model:
|
|||||||
stack_size: 4
|
stack_size: 4
|
||||||
kernel_size: 5
|
kernel_size: 5
|
||||||
hidden_dim: 64
|
hidden_dim: 64
|
||||||
|
causal: true
|
||||||
|
|
||||||
optim: adam
|
optim: adam
|
||||||
optim_conf:
|
optim_conf:
|
||||||
|
|||||||
@ -30,6 +30,7 @@ model:
|
|||||||
preprocessing:
|
preprocessing:
|
||||||
type: none
|
type: none
|
||||||
backbone:
|
backbone:
|
||||||
|
causal: true
|
||||||
type: mdtc
|
type: mdtc
|
||||||
num_stack: 3
|
num_stack: 3
|
||||||
stack_size: 4
|
stack_size: 4
|
||||||
|
|||||||
@ -41,6 +41,6 @@ optim_conf:
|
|||||||
|
|
||||||
training_config:
|
training_config:
|
||||||
grad_clip: 5
|
grad_clip: 5
|
||||||
max_epoch: 80
|
max_epoch: 2
|
||||||
log_interval: 10
|
log_interval: 10
|
||||||
|
|
||||||
|
|||||||
@ -70,7 +70,7 @@ class KWSModel(nn.Module):
|
|||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
stride = self.streaming_chunk
|
stride = self.streaming_chunk
|
||||||
num_frames = x.size(1) #(B, T, D)
|
num_frames = x.size(1)
|
||||||
cache: Optional[torch.Tensor] = None
|
cache: Optional[torch.Tensor] = None
|
||||||
out = []
|
out = []
|
||||||
for cur in range(0, num_frames, stride):
|
for cur in range(0, num_frames, stride):
|
||||||
@ -180,5 +180,6 @@ def init_model(configs):
|
|||||||
activation = nn.Sigmoid()
|
activation = nn.Sigmoid()
|
||||||
|
|
||||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||||
preprocessing, backbone, classifier, activation, streaming_chunk)
|
preprocessing, backbone, classifier, activation,
|
||||||
|
streaming_chunk)
|
||||||
return kws_model
|
return kws_model
|
||||||
|
|||||||
@ -162,7 +162,7 @@ class TCN(nn.Module):
|
|||||||
for block in self.network:
|
for block in self.network:
|
||||||
cur_padding += block.padding
|
cur_padding += block.padding
|
||||||
x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding])
|
x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding])
|
||||||
out_caches.append(c) # c (B, D, T)
|
out_caches.append(c)
|
||||||
x = x.transpose(1, 2) # (B, T, D)
|
x = x.transpose(1, 2) # (B, T, D)
|
||||||
new_cache = torch.cat(out_caches, dim=2)
|
new_cache = torch.cat(out_caches, dim=2)
|
||||||
return x, new_cache
|
return x, new_cache
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user