[wekws] add cache support (#78)

parent 8aa68ad750
commit 5037d51ed9

@@ -16,6 +16,7 @@ import argparse
 
 import torch
 import yaml
+import onnx
 import onnxruntime as ort
 
 from wekws.model.kws_model import init_model

@@ -25,9 +26,6 @@ from wekws.utils.checkpoint import load_checkpoint
 def get_args():
     parser = argparse.ArgumentParser(description='export to onnx model')
     parser.add_argument('--config', required=True, help='config file')
-    parser.add_argument('--jit_model',
-                        required=True,
-                        help='pytorch jit script model')
     parser.add_argument('--onnx_model',
                         required=True,
                         help='output onnx model')

@@ -48,21 +46,45 @@ def main():
     model.eval()
     # dummy_input: (batch, time, feature_dim)
     dummy_input = torch.randn(1, 100, feature_dim, dtype=torch.float)
-    torch.onnx.export(model,
-                      dummy_input,
+    cache = torch.zeros(1,
+                        model.hdim,
+                        model.backbone.padding,
+                        dtype=torch.float)
+    torch.onnx.export(model, (dummy_input, cache),
                       args.onnx_model,
-                      input_names=['input'],
-                      output_names=['output'],
-                      dynamic_axes={'input': {
-                          1: 'T'
-                      }},
-                      opset_version=10)
+                      input_names=['input', 'cache'],
+                      output_names=['output', 'r_cache'],
+                      dynamic_axes={
+                          'input': {
+                              1: 'T'
+                          },
+                          'output': {
+                              1: 'T'
+                          }},
+                      opset_version=13,
+                      verbose=False,
+                      do_constant_folding=True)
 
-    torch_output = model(dummy_input)
+    # Add hidden dim and cache size
+    onnx_model = onnx.load(args.onnx_model)
+    meta = onnx_model.metadata_props.add()
+    meta.key, meta.value = 'cache_dim', str(model.hdim)
+    meta = onnx_model.metadata_props.add()
+    meta.key, meta.value = 'cache_len', str(model.backbone.padding)
+    onnx.save(onnx_model, args.onnx_model)
+
+    # Verify onnx precision
+    torch_output = model(dummy_input, cache)
     ort_sess = ort.InferenceSession(args.onnx_model)
-    onnx_input = dummy_input.numpy()
-    onnx_output = ort_sess.run(None, {'input': onnx_input})
-    if torch.allclose(torch_output, torch.tensor(onnx_output[0])):
+    onnx_output = ort_sess.run(None, {
+        'input': dummy_input.numpy(),
+        'cache': cache.numpy()
+    })
+
+    if torch.allclose(torch_output[0],
+                      torch.tensor(onnx_output[0]), atol=1e-6) and \
+       torch.allclose(torch_output[1],
+                      torch.tensor(onnx_output[1]), atol=1e-6):
         print('Export to onnx succeed!')
     else:
         print('''Export to onnx succeed, but pytorch/onnx have different

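Note: the exported graph now takes a second input 'cache' and returns a second
output 'r_cache', and the metadata props let a runtime size the initial cache
without access to the training config. A minimal consumer sketch; the file
name, feature size, and chunking below are illustrative, not part of this
commit:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('kws.onnx')   # path is a placeholder
    meta = sess.get_modelmeta().custom_metadata_map
    cache_dim = int(meta['cache_dim'])        # model.hdim from the export above
    cache_len = int(meta['cache_len'])        # model.backbone.padding

    feature_dim = 80                          # assumed feature size
    feature_chunks = [np.random.randn(1, 10, feature_dim).astype(np.float32)
                      for _ in range(5)]      # stand-in for streamed features

    # Start from an all-zero cache and feed r_cache back in on every chunk.
    cache = np.zeros((1, cache_dim, cache_len), dtype=np.float32)
    for chunk in feature_chunks:
        output, cache = sess.run(None, {'input': chunk, 'cache': cache})
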
@@ -109,7 +109,7 @@ def main():
             keys, feats, target, lengths = batch
             feats = feats.to(device)
             lengths = lengths.to(device)
-            logits = model(feats)
+            logits, _ = model(feats)
             num_keywords = logits.shape[2]
             logits = logits.cpu()
             for i in range(len(keys)):

@@ -59,14 +59,18 @@ class KWSModel(nn.Module):
         self.classifier = classifier
         self.activation = activation
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.global_cmvn is not None:
             x = self.global_cmvn(x)
         x = self.preprocessing(x)
-        x, _ = self.backbone(x)
+        x, out_cache = self.backbone(x, in_cache)
         x = self.classifier(x)
         x = self.activation(x)
-        return x
+        return x, out_cache
 
     def fuse_modules(self):
         self.preprocessing.fuse_modules()

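Note: the zero-size (0, 0, 0) default keeps the signature TorchScript-friendly,
so offline call sites only change to tuple unpacking (logits, _ = model(feats)),
while streaming callers thread the returned cache into the next call. A hedged
usage sketch; the helper below is not part of this commit:

    import torch
    from wekws.model.kws_model import KWSModel

    def stream(model: KWSModel, chunks):
        """Run feature chunks through the model, threading the conv cache."""
        model.eval()
        cache = torch.zeros(0, 0, 0, dtype=torch.float)  # empty = no history
        with torch.no_grad():
            for chunk in chunks:       # each chunk: (1, T_chunk, feature_dim)
                logits, cache = model(chunk, cache)
                yield logits
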
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Tuple
 
 import torch
 import torch.nn as nn

@@ -31,15 +31,21 @@ class Block(nn.Module):
         self.quant = torch.quantization.QuantStub()
         self.dequant = torch.quantization.DeQuantStub()
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             x(torch.Tensor): Input tensor (B, D, T)
+            cache(torch.Tensor): Input cache (B, D, self.padding)
         Returns:
-            torch.Tensor(B, D, T)
+            torch.Tensor(B, D, T): output
+            torch.Tensor(B, D, self.padding): new cache
         """
         # The CNN used here is causal convolution
-        if cache is None:
+        if cache.size(0) == 0:
             y = F.pad(x, (self.padding, 0), value=0.0)
         else:
             y = torch.cat((cache, x), dim=2)

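Note: with an empty cache the block left-pads with self.padding zeros, which is
exactly what concatenating a zero-filled cache of the same width would do, so
the first streaming chunk sees the same convolution input as the offline path.
A small self-check of that equivalence (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 64, 10)        # (B, D, T)
    padding = 4

    # Branch taken when cache.size(0) == 0: implicit zero history.
    y_pad = F.pad(x, (padding, 0), value=0.0)

    # Equivalent explicit form: a zero cache of width `padding`.
    cache = torch.zeros(1, 64, padding)
    y_cat = torch.cat((cache, x), dim=2)

    assert torch.equal(y_pad, y_cat)  # both are (1, 64, 14)
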
@@ -127,10 +133,15 @@ class TCN(nn.Module):
             self.network.append(
                 block_class(channel, kernel_size, dilation, dropout))
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             x (torch.Tensor): Input tensor (B, T, D)
+            in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
 
         Returns:
             torch.Tensor(B, T, D)

@@ -138,9 +149,15 @@ class TCN(nn.Module):
         """
         x = x.transpose(1, 2)  # (B, D, T)
         out_caches = []
+        offset = 0
         for block in self.network:
-            x, c = block(x)
-            out_caches.append(c)
+            if in_cache.size(0) > 0:
+                c_in = in_cache[:, :, offset:offset + block.padding]
+            else:
+                c_in = torch.zeros(0, 0, 0)
+            x, c_out = block(x, c_in)
+            out_caches.append(c_out)
+            offset += block.padding
         x = x.transpose(1, 2)  # (B, T, D)
         new_cache = torch.cat(out_caches, dim=2)
         return x, new_cache

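Note: the per-block caches are packed side by side along the last axis, so each
block reads its slice at the running offset, and the width of the concatenated
new_cache is the sum of the block paddings, which should match the cache_len
metadata written during export. A layout sketch with made-up paddings:

    # Hypothetical 3-block TCN: kernel_size 8 with dilations 1, 2, 4 gives
    # per-block padding (kernel_size - 1) * dilation = 7, 14, 28.
    paddings = [7, 14, 28]

    # in_cache / new_cache are (B, D, 7 + 14 + 28) = (B, D, 49) here.
    offset = 0
    for i, p in enumerate(paddings):
        print(f'block {i}: cache slice [:, :, {offset}:{offset + p}]')
        offset += p
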
@@ -43,7 +43,7 @@ class Executor:
             num_utts = feats_lengths.size(0)
             if num_utts == 0:
                 continue
-            logits = model(feats)
+            logits, _ = model(feats)
             loss_type = args.get('criterion', 'max_pooling')
             loss, acc = criterion(loss_type, logits, target, feats_lengths,
                                   min_duration)

@@ -76,7 +76,7 @@ class Executor:
             num_utts = feats_lengths.size(0)
             if num_utts == 0:
                 continue
-            logits = model(feats)
+            logits, _ = model(feats)
             loss, acc = criterion(args.get('criterion', 'max_pooling'),
                                   logits, target, feats_lengths)
             if torch.isfinite(loss):