Add streaming inference to the tcn and ds_tcn models
This commit is contained in:
parent
7d142b9528
commit
1628f8c221
@ -101,7 +101,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
|||||||
--batch_size 256 \
|
--batch_size 256 \
|
||||||
--checkpoint $score_checkpoint \
|
--checkpoint $score_checkpoint \
|
||||||
--score_file $result_dir/score.txt \
|
--score_file $result_dir/score.txt \
|
||||||
--num_workers 8
|
--num_workers 8 \
|
||||||
|
--streaming_chunk 20
|
||||||
|
|
||||||
for keyword in 0 1; do
|
for keyword in 0 1; do
|
||||||
python kws/bin/compute_det.py \
|
python kws/bin/compute_det.py \
|
||||||
@ -136,7 +137,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
|||||||
--jit_model \
|
--jit_model \
|
||||||
--checkpoint $dir/$quantize_score_checkpoint \
|
--checkpoint $dir/$quantize_score_checkpoint \
|
||||||
--score_file $result_dir/score.txt \
|
--score_file $result_dir/score.txt \
|
||||||
--num_workers 8
|
--num_workers 8 \
|
||||||
|
--streaming_chunk 20
|
||||||
for keyword in 0 1; do
|
for keyword in 0 1; do
|
||||||
python kws/bin/compute_det.py \
|
python kws/bin/compute_det.py \
|
||||||
--keyword $keyword \
|
--keyword $keyword \
|
||||||
@ -158,4 +160,4 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|||||||
--config $dir/config.yaml \
|
--config $dir/config.yaml \
|
||||||
--jit_model $dir/$jit_model \
|
--jit_model $dir/$jit_model \
|
||||||
--onnx_model $dir/$onnx_model
|
--onnx_model $dir/$onnx_model
|
||||||
fi
|
fi
|
||||||
|
|||||||
@ -55,6 +55,10 @@ def get_args():
|
|||||||
default=100,
|
default=100,
|
||||||
type=int,
|
type=int,
|
||||||
help='prefetch number')
|
help='prefetch number')
|
||||||
|
parser.add_argument('--streaming_chunk',
|
||||||
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help='streaming chunk size for inference')
|
||||||
parser.add_argument('--score_file',
|
parser.add_argument('--score_file',
|
||||||
required=True,
|
required=True,
|
||||||
help='output score file')
|
help='output score file')
|
||||||
@ -84,6 +88,8 @@ def main():
|
|||||||
test_conf['feature_extraction_conf']['dither'] = 0.0
|
test_conf['feature_extraction_conf']['dither'] = 0.0
|
||||||
test_conf['batch_conf']['batch_size'] = args.batch_size
|
test_conf['batch_conf']['batch_size'] = args.batch_size
|
||||||
|
|
||||||
|
configs['model']['streaming_chunk'] = args.streaming_chunk
|
||||||
|
|
||||||
test_dataset = Dataset(args.test_data, test_conf)
|
test_dataset = Dataset(args.test_data, test_conf)
|
||||||
test_data_loader = DataLoader(test_dataset,
|
test_data_loader = DataLoader(test_dataset,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
|
|||||||
@ -46,6 +46,7 @@ class KWSModel(nn.Module):
|
|||||||
backbone: nn.Module,
|
backbone: nn.Module,
|
||||||
classifier: nn.Module,
|
classifier: nn.Module,
|
||||||
activation: nn.Module,
|
activation: nn.Module,
|
||||||
|
streaming_chunk: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.idim = idim
|
self.idim = idim
|
||||||
@ -56,15 +57,33 @@ class KWSModel(nn.Module):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
self.streaming_chunk = streaming_chunk
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
x = self.global_cmvn(x)
|
x = self.global_cmvn(x)
|
||||||
x = self.preprocessing(x)
|
if self.streaming_chunk <= 0:
|
||||||
x, _ = self.backbone(x)
|
x = self.preprocessing(x)
|
||||||
x = self.classifier(x)
|
x, _ = self.backbone(x)
|
||||||
x = self.activation(x)
|
x = self.classifier(x)
|
||||||
return x
|
x = self.activation(x)
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
stride = self.streaming_chunk
|
||||||
|
num_frames = x.size(1) #(B, T, D)
|
||||||
|
cache: Optional[torch.Tensor] = None
|
||||||
|
out = []
|
||||||
|
for cur in range(0, num_frames, stride):
|
||||||
|
end = min(cur + stride, num_frames)
|
||||||
|
x_s = x[:, cur:end, :]
|
||||||
|
x_s = self.preprocessing(x_s)
|
||||||
|
x_s, cache = self.backbone.forward_chunk(x_s, cache)
|
||||||
|
x_s = self.classifier(x_s)
|
||||||
|
x_s = self.activation(x_s)
|
||||||
|
out.append(x_s)
|
||||||
|
out_ = torch.cat(out, dim=1)
|
||||||
|
# wait all num_frames out, then return
|
||||||
|
return out_
|
||||||
|
|
||||||
def fuse_modules(self):
|
def fuse_modules(self):
|
||||||
self.preprocessing.fuse_modules()
|
self.preprocessing.fuse_modules()
|
||||||
@ -87,6 +106,10 @@ def init_model(configs):
|
|||||||
output_dim = configs['output_dim']
|
output_dim = configs['output_dim']
|
||||||
hidden_dim = configs['hidden_dim']
|
hidden_dim = configs['hidden_dim']
|
||||||
|
|
||||||
|
streaming_chunk = -1
|
||||||
|
if ('streaming_chunk' in configs and configs['streaming_chunk'] > 0):
|
||||||
|
streaming_chunk = configs['streaming_chunk']
|
||||||
|
|
||||||
prep_type = configs['preprocessing']['type']
|
prep_type = configs['preprocessing']['type']
|
||||||
if prep_type == 'linear':
|
if prep_type == 'linear':
|
||||||
preprocessing = LinearSubsampling1(input_dim, hidden_dim)
|
preprocessing = LinearSubsampling1(input_dim, hidden_dim)
|
||||||
@ -157,5 +180,5 @@ def init_model(configs):
|
|||||||
activation = nn.Sigmoid()
|
activation = nn.Sigmoid()
|
||||||
|
|
||||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||||
preprocessing, backbone, classifier, activation)
|
preprocessing, backbone, classifier, activation, streaming_chunk)
|
||||||
return kws_model
|
return kws_model
|
||||||
|
|||||||
@ -143,6 +143,28 @@ class TCN(nn.Module):
|
|||||||
out_caches.append(c)
|
out_caches.append(c)
|
||||||
x = x.transpose(1, 2) # (B, T, D)
|
x = x.transpose(1, 2) # (B, T, D)
|
||||||
new_cache = torch.cat(out_caches, dim=2)
|
new_cache = torch.cat(out_caches, dim=2)
|
||||||
|
return x, new_cache
|
||||||
|
|
||||||
|
def forward_chunk(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor (B, T, D)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor(B, T, D)
|
||||||
|
torch.Tensor(B, D, C): C is the accumulated cache size
|
||||||
|
"""
|
||||||
|
x = x.transpose(1, 2) # (B, D, T)
|
||||||
|
out_caches = []
|
||||||
|
cur_padding = 0
|
||||||
|
if cache is None:
|
||||||
|
cache = torch.zeros(x.size(0), x.size(1), self.padding)
|
||||||
|
for block in self.network:
|
||||||
|
cur_padding += block.padding
|
||||||
|
x, c = block(x, cache[:, :, cur_padding - block.padding : cur_padding])
|
||||||
|
out_caches.append(c) # c (B, D, T)
|
||||||
|
x = x.transpose(1, 2) # (B, T, D)
|
||||||
|
new_cache = torch.cat(out_caches, dim=2)
|
||||||
return x, new_cache
|
return x, new_cache
|
||||||
|
|
||||||
def fuse_modules(self):
|
def fuse_modules(self):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user