add streaming inference at tcn and ds_tcn model
This commit is contained in:
parent
7d142b9528
commit
1628f8c221
@ -101,7 +101,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||
--batch_size 256 \
|
||||
--checkpoint $score_checkpoint \
|
||||
--score_file $result_dir/score.txt \
|
||||
--num_workers 8
|
||||
--num_workers 8 \
|
||||
--streaming_chunk 20
|
||||
|
||||
for keyword in 0 1; do
|
||||
python kws/bin/compute_det.py \
|
||||
@ -136,7 +137,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||
--jit_model \
|
||||
--checkpoint $dir/$quantize_score_checkpoint \
|
||||
--score_file $result_dir/score.txt \
|
||||
--num_workers 8
|
||||
--num_workers 8 \
|
||||
--streaming_chunk 20
|
||||
for keyword in 0 1; do
|
||||
python kws/bin/compute_det.py \
|
||||
--keyword $keyword \
|
||||
@ -158,4 +160,4 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||
--config $dir/config.yaml \
|
||||
--jit_model $dir/$jit_model \
|
||||
--onnx_model $dir/$onnx_model
|
||||
fi
|
||||
fi
|
||||
|
||||
@ -55,6 +55,10 @@ def get_args():
|
||||
default=100,
|
||||
type=int,
|
||||
help='prefetch number')
|
||||
parser.add_argument('--streaming_chunk',
|
||||
default=-1,
|
||||
type=int,
|
||||
help='streaming chunk size for inference')
|
||||
parser.add_argument('--score_file',
|
||||
required=True,
|
||||
help='output score file')
|
||||
@ -84,6 +88,8 @@ def main():
|
||||
test_conf['feature_extraction_conf']['dither'] = 0.0
|
||||
test_conf['batch_conf']['batch_size'] = args.batch_size
|
||||
|
||||
configs['model']['streaming_chunk'] = args.streaming_chunk
|
||||
|
||||
test_dataset = Dataset(args.test_data, test_conf)
|
||||
test_data_loader = DataLoader(test_dataset,
|
||||
batch_size=None,
|
||||
|
||||
@ -46,6 +46,7 @@ class KWSModel(nn.Module):
|
||||
backbone: nn.Module,
|
||||
classifier: nn.Module,
|
||||
activation: nn.Module,
|
||||
streaming_chunk: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.idim = idim
|
||||
@ -56,15 +57,33 @@ class KWSModel(nn.Module):
|
||||
self.backbone = backbone
|
||||
self.classifier = classifier
|
||||
self.activation = activation
|
||||
self.streaming_chunk = streaming_chunk
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.global_cmvn is not None:
|
||||
x = self.global_cmvn(x)
|
||||
x = self.preprocessing(x)
|
||||
x, _ = self.backbone(x)
|
||||
x = self.classifier(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
if self.streaming_chunk <= 0:
|
||||
x = self.preprocessing(x)
|
||||
x, _ = self.backbone(x)
|
||||
x = self.classifier(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
else:
|
||||
stride = self.streaming_chunk
|
||||
num_frames = x.size(1) #(B, T, D)
|
||||
cache: Optional[torch.Tensor] = None
|
||||
out = []
|
||||
for cur in range(0, num_frames, stride):
|
||||
end = min(cur + stride, num_frames)
|
||||
x_s = x[:, cur:end, :]
|
||||
x_s = self.preprocessing(x_s)
|
||||
x_s, cache = self.backbone.forward_chunk(x_s, cache)
|
||||
x_s = self.classifier(x_s)
|
||||
x_s = self.activation(x_s)
|
||||
out.append(x_s)
|
||||
out_ = torch.cat(out, dim=1)
|
||||
# wait all num_frames out, then return
|
||||
return out_
|
||||
|
||||
def fuse_modules(self):
|
||||
self.preprocessing.fuse_modules()
|
||||
@ -87,6 +106,10 @@ def init_model(configs):
|
||||
output_dim = configs['output_dim']
|
||||
hidden_dim = configs['hidden_dim']
|
||||
|
||||
streaming_chunk = -1
|
||||
if ('streaming_chunk' in configs and configs['streaming_chunk'] > 0):
|
||||
streaming_chunk = configs['streaming_chunk']
|
||||
|
||||
prep_type = configs['preprocessing']['type']
|
||||
if prep_type == 'linear':
|
||||
preprocessing = LinearSubsampling1(input_dim, hidden_dim)
|
||||
@ -157,5 +180,5 @@ def init_model(configs):
|
||||
activation = nn.Sigmoid()
|
||||
|
||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||
preprocessing, backbone, classifier, activation)
|
||||
preprocessing, backbone, classifier, activation, streaming_chunk)
|
||||
return kws_model
|
||||
|
||||
@ -143,6 +143,28 @@ class TCN(nn.Module):
|
||||
out_caches.append(c)
|
||||
x = x.transpose(1, 2) # (B, T, D)
|
||||
new_cache = torch.cat(out_caches, dim=2)
|
||||
return x, new_cache
|
||||
|
||||
def forward_chunk(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (B, T, D)
|
||||
|
||||
Returns:
|
||||
torch.Tensor(B, T, D)
|
||||
torch.Tensor(B, D, C): C is the accumulated cache size
|
||||
"""
|
||||
x = x.transpose(1, 2) # (B, D, T)
|
||||
out_caches = []
|
||||
cur_padding = 0
|
||||
if cache is None:
|
||||
cache = torch.zeros(x.size(0), x.size(1), self.padding)
|
||||
for block in self.network:
|
||||
cur_padding += block.padding
|
||||
x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding])
|
||||
out_caches.append(c) # c (B, D, T)
|
||||
x = x.transpose(1, 2) # (B, T, D)
|
||||
new_cache = torch.cat(out_caches, dim=2)
|
||||
return x, new_cache
|
||||
|
||||
def fuse_modules(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user