[wekws] add cache support (#78)

Binbin Zhang authored on 2022-08-27 16:44:22 +08:00, committed by GitHub
parent 8aa68ad750
commit 5037d51ed9
5 changed files with 71 additions and 28 deletions

File 1 of 5: the ONNX export script

@@ -16,6 +16,7 @@ import argparse
 
 import torch
 import yaml
+import onnx
 import onnxruntime as ort
 
 from wekws.model.kws_model import init_model
@@ -25,9 +26,6 @@ from wekws.utils.checkpoint import load_checkpoint
 def get_args():
     parser = argparse.ArgumentParser(description='export to onnx model')
     parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--jit_model',
-                        required=True,
-                        help='pytorch jit script model')
     parser.add_argument('--onnx_model',
                         required=True,
                         help='output onnx model')
@@ -48,21 +46,45 @@ def main():
     model.eval()
     # dummy_input: (batch, time, feature_dim)
     dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
-    torch.onnx.export(model,
-                      dummy_input,
+    cache = torch.zeros(1,
+                        model.hdim,
+                        model.backbone.padding,
+                        dtype=torch.float)
+    torch.onnx.export(model, (dummy_input, cache),
                       args.onnx_model,
-                      input_names=['input'],
-                      output_names=['output'],
-                      dynamic_axes={'input': {
-                          1: 'T'
-                      }},
-                      opset_version=10)
-    torch_output = model(dummy_input)
+                      input_names=['input', 'cache'],
+                      output_names=['output', 'r_cache'],
+                      dynamic_axes={
+                          'input': {
+                              1: 'T'
+                          },
+                          'output': {
+                              1: 'T'
+                          }},
+                      opset_version=13,
+                      verbose=False,
+                      do_constant_folding=True)
+
+    # Add hidden dim and cache size
+    onnx_model = onnx.load(args.onnx_model)
+    meta = onnx_model.metadata_props.add()
+    meta.key, meta.value = 'cache_dim', str(model.hdim)
+    meta = onnx_model.metadata_props.add()
+    meta.key, meta.value = 'cache_len', str(model.backbone.padding)
+    onnx.save(onnx_model, args.onnx_model)
+
+    # Verify onnx precision
+    torch_output = model(dummy_input, cache)
     ort_sess = ort.InferenceSession(args.onnx_model)
-    onnx_input = dummy_input.numpy()
-    onnx_output = ort_sess.run(None, {'input': onnx_input})
-    if torch.allclose(torch_output, torch.tensor(onnx_output[0])):
+    onnx_output = ort_sess.run(None, {
+        'input': dummy_input.numpy(),
+        'cache': cache.numpy()
+    })
+    if torch.allclose(torch_output[0],
+                      torch.tensor(onnx_output[0]), atol=1e-6) and \
+       torch.allclose(torch_output[1],
+                      torch.tensor(onnx_output[1]), atol=1e-6):
         print('Export to onnx succeed!')
     else:
         print('''Export to onnx succeed, but pytorch/onnx have different
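
Note: with these input/output names and metadata keys, the exported graph can be driven chunk by chunk, feeding each returned `r_cache` back in as the next `cache`. A minimal sketch, assuming a hypothetical exported file `kws.onnx` and an 80-dim feature front end (both placeholders); everything else follows the names and keys set above:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('kws.onnx')          # placeholder path
    meta = sess.get_modelmeta().custom_metadata_map  # set by the export script
    cache_dim = int(meta['cache_dim'])               # model.hdim
    cache_len = int(meta['cache_len'])               # model.backbone.padding

    # An all-zero cache is equivalent to the zero left-padding used
    # when there is no history yet.
    cache = np.zeros((1, cache_dim, cache_len), dtype=np.float32)
    feats = np.random.randn(1, 100, 80).astype(np.float32)  # stand-in features
    for chunk in np.split(feats, 10, axis=1):
        output, cache = sess.run(None, {'input': chunk, 'cache': cache})

Because `r_cache` always has the fixed shape (1, cache_dim, cache_len), a runtime can allocate it once up front; that is what the two metadata entries are for.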

File 2 of 5: the scoring script

@@ -109,7 +109,7 @@ def main():
         keys, feats, target, lengths = batch
         feats = feats.to(device)
         lengths = lengths.to(device)
-        logits = model(feats)
+        logits, _ = model(feats)
         num_keywords = logits.shape[2]
         logits = logits.cpu()
         for i in range(len(keys)):

File 3 of 5: the KWS model (KWSModel)

@@ -59,14 +59,18 @@ class KWSModel(nn.Module):
         self.classifier = classifier
         self.activation = activation
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.global_cmvn is not None:
             x = self.global_cmvn(x)
         x = self.preprocessing(x)
-        x, _ = self.backbone(x)
+        x, out_cache = self.backbone(x, in_cache)
         x = self.classifier(x)
         x = self.activation(x)
-        return x
+        return x, out_cache
 
     def fuse_modules(self):
         self.preprocessing.fuse_modules()
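
Note: the zero-sized default tensor stands in for `None`, since ONNX and TorchScript export do not handle an `Optional` tensor input well; offline callers can simply ignore the cache. A minimal usage sketch, assuming a constructed `model` and 80-dim features (placeholders):

    import torch

    feats = torch.randn(1, 100, 80)  # (B, T, D)

    # Offline: the default zero-sized cache means "no history yet".
    logits, _ = model(feats)

    # Streaming: carry the returned cache across chunks.
    cache = torch.zeros(1, model.hdim, model.backbone.padding)
    for chunk in feats.split(10, dim=1):
        logits, cache = model(chunk, cache)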

File 4 of 5: the TCN backbone

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -31,15 +31,21 @@ class Block(nn.Module):
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             x(torch.Tensor): Input tensor (B, D, T)
+            cache(torch.Tensor): Input cache (B, D, self.padding)
         Returns:
-            torch.Tensor(B, D, T)
+            torch.Tensor(B, D, T): output
+            torch.Tensor(B, D, self.padding): new cache
         """
         # The CNN used here is causal convolution
-        if cache is None:
+        if cache.size(0) == 0:
             y = F.pad(x, (self.padding, 0), value=0.0)
         else:
             y = torch.cat((cache, x), dim=2)
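
Note: the cache makes chunked processing exact because, for a causal convolution, the last `self.padding` input frames of one chunk are precisely the left context the next chunk needs. A toy check with a plain `Conv1d` (not the `Block` class itself), kernel size 3 and dilation 1, so padding = 2:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    conv = torch.nn.Conv1d(4, 4, kernel_size=3)  # used causally: left padding only
    padding = 2                                  # (kernel_size - 1) * dilation

    x = torch.randn(1, 4, 10)
    full = conv(F.pad(x, (padding, 0)))          # whole utterance, zero history

    a, b = x[:, :, :5], x[:, :, 5:]              # split into two chunks
    ya = conv(F.pad(a, (padding, 0)))            # first chunk: zero history
    cache = a[:, :, -padding:]                   # what a block hands to the next call
    yb = conv(torch.cat((cache, b), dim=2))      # second chunk: real history
    assert torch.allclose(full, torch.cat((ya, yb), dim=2), atol=1e-6)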
@@ -127,10 +133,15 @@ class TCN(nn.Module):
             self.network.append(
                 block_class(channel, kernel_size, dilation, dropout))
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             x (torch.Tensor): Input tensor (B, T, D)
+            in_cache (torch.Tensor): (B, D, C), C is the accumulated cache size
         Returns:
             torch.Tensor(B, T, D)
@@ -138,9 +149,15 @@ class TCN(nn.Module):
         """
         x = x.transpose(1, 2)  # (B, D, T)
         out_caches = []
+        offset = 0
         for block in self.network:
-            x, c = block(x)
-            out_caches.append(c)
+            if in_cache.size(0) > 0:
+                c_in = in_cache[:, :, offset:offset + block.padding]
+            else:
+                c_in = torch.zeros(0, 0, 0)
+            x, c_out = block(x, c_in)
+            out_caches.append(c_out)
+            offset += block.padding
         x = x.transpose(1, 2)  # (B, T, D)
         new_cache = torch.cat(out_caches, dim=2)
         return x, new_cache
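
Note: the per-block caches are packed along the time axis in network order, which is why `offset` advances by `block.padding` and why the export script sizes the cache with `model.backbone.padding` (the accumulated padding over all blocks). A hedged sketch of the resulting contract, for an already constructed `tcn` put in eval mode (the blocks contain dropout):

    import torch

    def stream_equals_full(tcn, channel, T=40, chunk=10):
        """Chunked forward with cache should reproduce the full forward."""
        tcn.eval()
        x = torch.randn(1, T, channel)
        full, _ = tcn(x)                              # zero-history pass
        cache = torch.zeros(1, channel, tcn.padding)  # all-zero accumulated cache
        outs = []
        for piece in x.split(chunk, dim=1):
            y, cache = tcn(piece, cache)
            outs.append(y)
        return torch.allclose(full, torch.cat(outs, dim=1), atol=1e-6)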

File 5 of 5: the training executor (Executor)

@@ -43,7 +43,7 @@ class Executor:
             num_utts = feats_lengths.size(0)
             if num_utts == 0:
                 continue
-            logits = model(feats)
+            logits, _ = model(feats)
             loss_type = args.get('criterion', 'max_pooling')
             loss, acc = criterion(loss_type, logits, target, feats_lengths,
                                   min_duration)
@@ -76,7 +76,7 @@ class Executor:
             num_utts = feats_lengths.size(0)
             if num_utts == 0:
                 continue
-            logits = model(feats)
+            logits, _ = model(feats)
             loss, acc = criterion(args.get('criterion', 'max_pooling'),
                                   logits, target, feats_lengths)
             if torch.isfinite(loss):