[wekws] add cache support (#78)
This commit is contained in:
parent
8aa68ad750
commit
5037d51ed9
@ -16,6 +16,7 @@ import argparse
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
from wekws.model.kws_model import init_model
|
||||
@ -25,9 +26,6 @@ from wekws.utils.checkpoint import load_checkpoint
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='export to onnx model')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--jit_model',
|
||||
required=True,
|
||||
help='pytorch jit script model')
|
||||
parser.add_argument('--onnx_model',
|
||||
required=True,
|
||||
help='output onnx model')
|
||||
@ -48,21 +46,45 @@ def main():
|
||||
model.eval()
|
||||
# dummy_input: (batch, time, feature_dim)
|
||||
dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
|
||||
torch.onnx.export(model,
|
||||
dummy_input,
|
||||
cache = torch.zeros(1,
|
||||
model.hdim,
|
||||
model.backbone.padding,
|
||||
dtype=torch.float)
|
||||
torch.onnx.export(model, (dummy_input, cache),
|
||||
args.onnx_model,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
dynamic_axes={'input': {
|
||||
1: 'T'
|
||||
}},
|
||||
opset_version=10)
|
||||
input_names=['input', 'cache'],
|
||||
output_names=['output', 'r_cache'],
|
||||
dynamic_axes={
|
||||
'input': {
|
||||
1: 'T'
|
||||
},
|
||||
'output': {
|
||||
1: 'T'
|
||||
}},
|
||||
opset_version=13,
|
||||
verbose=False,
|
||||
do_constant_folding=True)
|
||||
|
||||
torch_output = model(dummy_input)
|
||||
# Add hidden dim and cache size
|
||||
onnx_model = onnx.load(args.onnx_model)
|
||||
meta = onnx_model.metadata_props.add()
|
||||
meta.key, meta.value = 'cache_dim', str(model.hdim)
|
||||
meta = onnx_model.metadata_props.add()
|
||||
meta.key, meta.value = 'cache_len', str(model.backbone.padding)
|
||||
onnx.save(onnx_model, args.onnx_model)
|
||||
|
||||
# Verify onnx precision
|
||||
torch_output = model(dummy_input, cache)
|
||||
ort_sess = ort.InferenceSession(args.onnx_model)
|
||||
onnx_input = dummy_input.numpy()
|
||||
onnx_output = ort_sess.run(None, {'input': onnx_input})
|
||||
if torch.allclose(torch_output, torch.tensor(onnx_output[0])):
|
||||
onnx_output = ort_sess.run(None, {
|
||||
'input': dummy_input.numpy(),
|
||||
'cache': cache.numpy()
|
||||
})
|
||||
|
||||
if torch.allclose(torch_output[0],
|
||||
torch.tensor(onnx_output[0]), atol=1e-6) and \
|
||||
torch.allclose(torch_output[1],
|
||||
torch.tensor(onnx_output[1]), atol=1e-6):
|
||||
print('Export to onnx succeed!')
|
||||
else:
|
||||
print('''Export to onnx succeed, but pytorch/onnx have different
|
||||
|
||||
@ -109,7 +109,7 @@ def main():
|
||||
keys, feats, target, lengths = batch
|
||||
feats = feats.to(device)
|
||||
lengths = lengths.to(device)
|
||||
logits = model(feats)
|
||||
logits, _ = model(feats)
|
||||
num_keywords = logits.shape[2]
|
||||
logits = logits.cpu()
|
||||
for i in range(len(keys)):
|
||||
|
||||
@ -59,14 +59,18 @@ class KWSModel(nn.Module):
|
||||
self.classifier = classifier
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.global_cmvn is not None:
|
||||
x = self.global_cmvn(x)
|
||||
x = self.preprocessing(x)
|
||||
x, _ = self.backbone(x)
|
||||
x, out_cache = self.backbone(x, in_cache)
|
||||
x = self.classifier(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
return x, out_cache
|
||||
|
||||
def fuse_modules(self):
|
||||
self.preprocessing.fuse_modules()
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -31,15 +31,21 @@ class Block(nn.Module):
|
||||
self.quant = torch.quantization.QuantStub()
|
||||
self.dequant = torch.quantization.DeQuantStub()
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x(torch.Tensor): Input tensor (B, D, T)
|
||||
cache(torch.Tensor): Input cache(B, D, self.padding)
|
||||
Returns:
|
||||
torch.Tensor(B, D, T)
|
||||
torch.Tensor(B, D, T): output
|
||||
torch.Tensor(B, D, self.padding): new cache
|
||||
"""
|
||||
# The CNN used here is causal convolution
|
||||
if cache is None:
|
||||
if cache.size(0) == 0:
|
||||
y = F.pad(x, (self.padding, 0), value=0.0)
|
||||
else:
|
||||
y = torch.cat((cache, x), dim=2)
|
||||
@ -127,10 +133,15 @@ class TCN(nn.Module):
|
||||
self.network.append(
|
||||
block_class(channel, kernel_size, dilation, dropout))
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (B, T, D)
|
||||
in_cache(torhc.Tensor): (B, D, C), C is the accumulated cache size
|
||||
|
||||
Returns:
|
||||
torch.Tensor(B, T, D)
|
||||
@ -138,9 +149,15 @@ class TCN(nn.Module):
|
||||
"""
|
||||
x = x.transpose(1, 2) # (B, D, T)
|
||||
out_caches = []
|
||||
offset = 0
|
||||
for block in self.network:
|
||||
x, c = block(x)
|
||||
out_caches.append(c)
|
||||
if in_cache.size(0) > 0:
|
||||
c_in = in_cache[:, :, offset:offset + block.padding]
|
||||
else:
|
||||
c_in = torch.zeros(0, 0, 0)
|
||||
x, c_out = block(x, c_in)
|
||||
out_caches.append(c_out)
|
||||
offset += block.padding
|
||||
x = x.transpose(1, 2) # (B, T, D)
|
||||
new_cache = torch.cat(out_caches, dim=2)
|
||||
return x, new_cache
|
||||
|
||||
@ -43,7 +43,7 @@ class Executor:
|
||||
num_utts = feats_lengths.size(0)
|
||||
if num_utts == 0:
|
||||
continue
|
||||
logits = model(feats)
|
||||
logits, _ = model(feats)
|
||||
loss_type = args.get('criterion', 'max_pooling')
|
||||
loss, acc = criterion(loss_type, logits, target, feats_lengths,
|
||||
min_duration)
|
||||
@ -76,7 +76,7 @@ class Executor:
|
||||
num_utts = feats_lengths.size(0)
|
||||
if num_utts == 0:
|
||||
continue
|
||||
logits = model(feats)
|
||||
logits, _ = model(feats)
|
||||
loss, acc = criterion(args.get('criterion', 'max_pooling'),
|
||||
logits, target, feats_lengths)
|
||||
if torch.isfinite(loss):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user