Add streaming inference to the tcn and ds_tcn models

This commit is contained in:
lvanchao1 2022-04-01 18:12:29 +08:00
parent 7d142b9528
commit 1628f8c221
4 changed files with 62 additions and 9 deletions

View File

@ -101,7 +101,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--batch_size 256 \ --batch_size 256 \
--checkpoint $score_checkpoint \ --checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \ --score_file $result_dir/score.txt \
--num_workers 8 --num_workers 8 \
--streaming_chunk 20
for keyword in 0 1; do for keyword in 0 1; do
python kws/bin/compute_det.py \ python kws/bin/compute_det.py \
@ -136,7 +137,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--jit_model \ --jit_model \
--checkpoint $dir/$quantize_score_checkpoint \ --checkpoint $dir/$quantize_score_checkpoint \
--score_file $result_dir/score.txt \ --score_file $result_dir/score.txt \
--num_workers 8 --num_workers 8 \
--streaming_chunk 20
for keyword in 0 1; do for keyword in 0 1; do
python kws/bin/compute_det.py \ python kws/bin/compute_det.py \
--keyword $keyword \ --keyword $keyword \

View File

@ -55,6 +55,10 @@ def get_args():
default=100, default=100,
type=int, type=int,
help='prefetch number') help='prefetch number')
parser.add_argument('--streaming_chunk',
default=-1,
type=int,
help='streaming chunk size for inference')
parser.add_argument('--score_file', parser.add_argument('--score_file',
required=True, required=True,
help='output score file') help='output score file')
@ -84,6 +88,8 @@ def main():
test_conf['feature_extraction_conf']['dither'] = 0.0 test_conf['feature_extraction_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_size'] = args.batch_size test_conf['batch_conf']['batch_size'] = args.batch_size
configs['model']['streaming_chunk'] = args.streaming_chunk
test_dataset = Dataset(args.test_data, test_conf) test_dataset = Dataset(args.test_data, test_conf)
test_data_loader = DataLoader(test_dataset, test_data_loader = DataLoader(test_dataset,
batch_size=None, batch_size=None,

View File

@ -46,6 +46,7 @@ class KWSModel(nn.Module):
backbone: nn.Module, backbone: nn.Module,
classifier: nn.Module, classifier: nn.Module,
activation: nn.Module, activation: nn.Module,
streaming_chunk: int,
): ):
super().__init__() super().__init__()
self.idim = idim self.idim = idim
@ -56,15 +57,33 @@ class KWSModel(nn.Module):
self.backbone = backbone self.backbone = backbone
self.classifier = classifier self.classifier = classifier
self.activation = activation self.activation = activation
self.streaming_chunk = streaming_chunk
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.global_cmvn is not None: if self.global_cmvn is not None:
x = self.global_cmvn(x) x = self.global_cmvn(x)
x = self.preprocessing(x) if self.streaming_chunk <= 0:
x, _ = self.backbone(x) x = self.preprocessing(x)
x = self.classifier(x) x, _ = self.backbone(x)
x = self.activation(x) x = self.classifier(x)
return x x = self.activation(x)
return x
else:
stride = self.streaming_chunk
num_frames = x.size(1) #(B, T, D)
cache: Optional[torch.Tensor] = None
out = []
for cur in range(0, num_frames, stride):
end = min(cur + stride, num_frames)
x_s = x[:, cur:end, :]
x_s = self.preprocessing(x_s)
x_s, cache = self.backbone.forward_chunk(x_s, cache)
x_s = self.classifier(x_s)
x_s = self.activation(x_s)
out.append(x_s)
out_ = torch.cat(out, dim=1)
# wait all num_frames out, then return
return out_
def fuse_modules(self): def fuse_modules(self):
self.preprocessing.fuse_modules() self.preprocessing.fuse_modules()
@ -87,6 +106,10 @@ def init_model(configs):
output_dim = configs['output_dim'] output_dim = configs['output_dim']
hidden_dim = configs['hidden_dim'] hidden_dim = configs['hidden_dim']
streaming_chunk = -1
if ('streaming_chunk' in configs and configs['streaming_chunk'] > 0):
streaming_chunk = configs['streaming_chunk']
prep_type = configs['preprocessing']['type'] prep_type = configs['preprocessing']['type']
if prep_type == 'linear': if prep_type == 'linear':
preprocessing = LinearSubsampling1(input_dim, hidden_dim) preprocessing = LinearSubsampling1(input_dim, hidden_dim)
@ -157,5 +180,5 @@ def init_model(configs):
activation = nn.Sigmoid() activation = nn.Sigmoid()
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn, kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone, classifier, activation) preprocessing, backbone, classifier, activation, streaming_chunk)
return kws_model return kws_model

View File

@ -145,6 +145,28 @@ class TCN(nn.Module):
new_cache = torch.cat(out_caches, dim=2) new_cache = torch.cat(out_caches, dim=2)
return x, new_cache return x, new_cache
def forward_chunk(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
    """Run one chunk through every block of the network, carrying a cache.

    Each block in ``self.network`` receives the slice of ``cache`` that
    corresponds to its own ``padding`` and returns an updated slice; the
    updated slices are concatenated (in block order) into the new cache,
    which the caller passes back in with the next chunk.

    Args:
        x (torch.Tensor): Input tensor (B, T, D)
        cache (Optional[torch.Tensor]): Cache returned by the previous
            call, laid out along the last dimension as one slice per
            block. When ``None`` (first chunk), a zero cache of shape
            (B, D, self.padding) is created.
            NOTE(review): this assumes ``self.padding`` equals the sum of
            the blocks' ``padding`` values -- confirm against ``__init__``.

    Returns:
        torch.Tensor(B, T, D)
        torch.Tensor(B, D, C): C is the accumulated cache size
    """
    x = x.transpose(1, 2)  # (B, D, T)
    if cache is None:
        # Allocate on the same device and with the same dtype as the input;
        # a bare torch.zeros() would always give a CPU float32 tensor and
        # break the first chunk on GPU or with non-float32 features.
        cache = x.new_zeros(x.size(0), x.size(1), self.padding)
    out_caches = []
    offset = 0  # start of the current block's slice inside the cache
    for block in self.network:
        end = offset + block.padding
        x, c = block(x, cache[:, :, offset:end])
        out_caches.append(c)  # c (B, D, T)
        offset = end
    x = x.transpose(1, 2)  # (B, T, D)
    new_cache = torch.cat(out_caches, dim=2)
    return x, new_cache
def fuse_modules(self): def fuse_modules(self):
for m in self.network: for m in self.network:
m.fuse_modules() m.fuse_modules()