[runtime/onnxruntime] add onnxruntime support (#79)

* [runtime/onnxruntime] add onnxruntime support

* add cpplint and clang-format

* fix lint
This commit is contained in:
Binbin Zhang 2022-08-28 13:35:21 +08:00 committed by GitHub
parent 5037d51ed9
commit 53d7b8f807
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1348 additions and 0 deletions

93
.clang-format Normal file
View File

@ -0,0 +1,93 @@
---
Language: Cpp
# BasedOnStyle: Google
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: true
BinPackParameters: true
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: true
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^<.*\.h>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IncludeIsMainRegex: '([-_](test|unittest))?$'
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Auto
TabWidth: 8
UseTab: Never
...

2
CPPLINT.cfg Normal file
View File

@ -0,0 +1,2 @@
root=runtime/core
filter=-build/c++11

View File

@ -0,0 +1,64 @@
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include "frontend/feature_pipeline.h"
#include "frontend/wav.h"
#include "kws/keyword_spotting.h"
#include "utils/log.h"
int main(int argc, char* argv[]) {
if (argc != 5) {
LOG(FATAL) << "Usage: kws_main fbank_dim(int) batch_size(int) "
<< "kws_model_path test_wav_path";
}
const int num_bins = std::stoi(argv[1]); // Fbank feature dim
const int batch_size = std::stoi(argv[2]);
const std::string model_path = argv[3];
const std::string wav_path = argv[4];
wenet::WavReader wav_reader(wav_path);
int num_samples = wav_reader.num_samples();
wenet::FeaturePipelineConfig feature_config(num_bins, 16000);
wenet::FeaturePipeline feature_pipeline(feature_config);
std::vector<float> wav(wav_reader.data(), wav_reader.data() + num_samples);
feature_pipeline.AcceptWaveform(wav);
feature_pipeline.set_input_finished();
wekws::KeywordSpotting spotter(model_path);
// Simulate streaming, detect batch by batch
int offset = 0;
while (true) {
std::vector<std::vector<float>> feats;
bool ok = feature_pipeline.Read(batch_size, &feats);
std::vector<std::vector<float>> prob;
spotter.Forward(feats, &prob);
for (int i = 0; i < prob.size(); i++) {
std::cout << "frame " << offset + i << " prob";
for (int j = 0; j < prob[i].size(); j++) {
std::cout << " " << prob[i][0];
}
std::cout << std::endl;
}
// Reach the end of feature pipeline
if (!ok) break;
offset += prob.size();
}
return 0;
}

View File

@ -0,0 +1,222 @@
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_FBANK_H_
#define FRONTEND_FBANK_H_
#include <cstring>
#include <limits>
#include <random>
#include <utility>
#include <vector>
#include "frontend/fft.h"
#include "utils/log.h"
namespace wenet {
// This code is based on kaldi Fbank implentation, please see
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
class Fbank {
public:
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
: num_bins_(num_bins),
sample_rate_(sample_rate),
frame_length_(frame_length),
frame_shift_(frame_shift),
use_log_(true),
remove_dc_offset_(true),
generator_(0),
distribution_(0, 1.0),
dither_(0.0) {
fft_points_ = UpperPowerOfTwo(frame_length_);
// generate bit reversal table and trigonometric function table
const int fft_points_4 = fft_points_ / 4;
bitrev_.resize(fft_points_);
sintbl_.resize(fft_points_ + fft_points_4);
make_sintbl(fft_points_, sintbl_.data());
make_bitrev(fft_points_, bitrev_.data());
int num_fft_bins = fft_points_ / 2;
float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
int low_freq = 20, high_freq = sample_rate_ / 2;
float mel_low_freq = MelScale(low_freq);
float mel_high_freq = MelScale(high_freq);
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
bins_.resize(num_bins_);
center_freqs_.resize(num_bins_);
for (int bin = 0; bin < num_bins; ++bin) {
float left_mel = mel_low_freq + bin * mel_freq_delta,
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
center_freqs_[bin] = InverseMelScale(center_mel);
std::vector<float> this_bin(num_fft_bins);
int first_index = -1, last_index = -1;
for (int i = 0; i < num_fft_bins; ++i) {
float freq = (fft_bin_width * i); // Center frequency of this fft
// bin.
float mel = MelScale(freq);
if (mel > left_mel && mel < right_mel) {
float weight;
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else
weight = (right_mel - mel) / (right_mel - center_mel);
this_bin[i] = weight;
if (first_index == -1) first_index = i;
last_index = i;
}
}
CHECK(first_index != -1 && last_index >= first_index);
bins_[bin].first = first_index;
int size = last_index + 1 - first_index;
bins_[bin].second.resize(size);
for (int i = 0; i < size; ++i) {
bins_[bin].second[i] = this_bin[first_index + i];
}
}
// NOTE(cdliang): add hamming window
hamming_window_.resize(frame_length_);
double a = M_2PI / (frame_length - 1);
for (int i = 0; i < frame_length; i++) {
double i_fl = static_cast<double>(i);
hamming_window_[i] = 0.54 - 0.46 * cos(a * i_fl);
}
}
void set_use_log(bool use_log) { use_log_ = use_log; }
void set_remove_dc_offset(bool remove_dc_offset) {
remove_dc_offset_ = remove_dc_offset;
}
void set_dither(float dither) { dither_ = dither; }
int num_bins() const { return num_bins_; }
static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
}
static inline float MelScale(float freq) {
return 1127.0f * logf(1.0f + freq / 700.0f);
}
static int UpperPowerOfTwo(int n) {
return static_cast<int>(pow(2, ceil(log(n) / log(2))));
}
// preemphasis
void PreEmphasis(float coeff, std::vector<float>* data) const {
if (coeff == 0.0) return;
for (int i = data->size() - 1; i > 0; i--)
(*data)[i] -= coeff * (*data)[i - 1];
(*data)[0] -= coeff * (*data)[0];
}
// add hamming window
void Hamming(std::vector<float>* data) const {
CHECK(data->size() >= hamming_window_.size());
for (size_t i = 0; i < hamming_window_.size(); ++i) {
(*data)[i] *= hamming_window_[i];
}
}
// Compute fbank feat, return num frames
int Compute(const std::vector<float>& wave,
std::vector<std::vector<float>>* feat) {
int num_samples = wave.size();
if (num_samples < frame_length_) return 0;
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
feat->resize(num_frames);
std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
std::vector<float> power(fft_points_ / 2);
for (int i = 0; i < num_frames; ++i) {
std::vector<float> data(wave.data() + i * frame_shift_,
wave.data() + i * frame_shift_ + frame_length_);
// optional add noise
if (dither_ != 0.0) {
for (size_t j = 0; j < data.size(); ++j)
data[j] += dither_ * distribution_(generator_);
}
// optinal remove dc offset
if (remove_dc_offset_) {
float mean = 0.0;
for (size_t j = 0; j < data.size(); ++j) mean += data[j];
mean /= data.size();
for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
}
PreEmphasis(0.97, &data);
// Povey(&data);
Hamming(&data);
// copy data to fft_real
memset(fft_img.data(), 0, sizeof(float) * fft_points_);
memset(fft_real.data() + frame_length_, 0,
sizeof(float) * (fft_points_ - frame_length_));
memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_);
fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(),
fft_points_);
// power
for (int j = 0; j < fft_points_ / 2; ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
}
(*feat)[i].resize(num_bins_);
// cepstral coefficients, triangle filter array
for (int j = 0; j < num_bins_; ++j) {
float mel_energy = 0.0;
int s = bins_[j].first;
for (size_t k = 0; k < bins_[j].second.size(); ++k) {
mel_energy += bins_[j].second[k] * power[s + k];
}
// optional use log
if (use_log_) {
if (mel_energy < std::numeric_limits<float>::epsilon())
mel_energy = std::numeric_limits<float>::epsilon();
mel_energy = logf(mel_energy);
}
(*feat)[i][j] = mel_energy;
// printf("%f ", mel_energy);
}
// printf("\n");
}
return num_frames;
}
private:
int num_bins_;
int sample_rate_;
int frame_length_, frame_shift_;
int fft_points_;
bool use_log_;
bool remove_dc_offset_;
std::vector<float> center_freqs_;
std::vector<std::pair<int, std::vector<float>>> bins_;
std::vector<float> hamming_window_;
std::default_random_engine generator_;
std::normal_distribution<float> distribution_;
float dither_;
// bit reversal table
std::vector<int> bitrev_;
// trigonometric function table
std::vector<float> sintbl_;
};
} // namespace wenet
#endif // FRONTEND_FBANK_H_

View File

@ -0,0 +1,113 @@
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/feature_pipeline.h"
#include <algorithm>
#include <utility>
namespace wenet {
FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
: config_(config),
feature_dim_(config.num_bins),
fbank_(config.num_bins, config.sample_rate, config.frame_length,
config.frame_shift),
num_frames_(0),
input_finished_(false) {}
void FeaturePipeline::AcceptWaveform(const std::vector<float>& wav) {
std::vector<std::vector<float>> feats;
std::vector<float> waves;
waves.insert(waves.end(), remained_wav_.begin(), remained_wav_.end());
waves.insert(waves.end(), wav.begin(), wav.end());
int num_frames = fbank_.Compute(waves, &feats);
for (size_t i = 0; i < feats.size(); ++i) {
feature_queue_.Push(std::move(feats[i]));
}
num_frames_ += num_frames;
int left_samples = waves.size() - config_.frame_shift * num_frames;
remained_wav_.resize(left_samples);
std::copy(waves.begin() + config_.frame_shift * num_frames, waves.end(),
remained_wav_.begin());
// We are still adding wave, notify input is not finished
finish_condition_.notify_one();
}
void FeaturePipeline::AcceptWaveform(const std::vector<int16_t>& wav) {
std::vector<float> float_wav(wav.size());
for (size_t i = 0; i < wav.size(); i++) {
float_wav[i] = static_cast<float>(wav[i]);
}
this->AcceptWaveform(float_wav);
}
void FeaturePipeline::set_input_finished() {
CHECK(!input_finished_);
{
std::lock_guard<std::mutex> lock(mutex_);
input_finished_ = true;
}
finish_condition_.notify_one();
}
bool FeaturePipeline::ReadOne(std::vector<float>* feat) {
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
} else {
std::unique_lock<std::mutex> lock(mutex_);
while (!input_finished_) {
// This will release the lock and wait for notify_one()
// from AcceptWaveform() or set_input_finished()
finish_condition_.wait(lock);
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
}
}
CHECK(input_finished_);
// Double check queue.empty, see issue#893 for detailed discussions.
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
} else {
return false;
}
}
}
bool FeaturePipeline::Read(int num_frames,
std::vector<std::vector<float>>* feats) {
feats->clear();
std::vector<float> feat;
while (feats->size() < num_frames) {
if (ReadOne(&feat)) {
feats->push_back(std::move(feat));
} else {
return false;
}
}
return true;
}
void FeaturePipeline::Reset() {
input_finished_ = false;
num_frames_ = 0;
remained_wav_.clear();
feature_queue_.Clear();
}
} // namespace wenet

View File

@ -0,0 +1,118 @@
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_FEATURE_PIPELINE_H_
#define FRONTEND_FEATURE_PIPELINE_H_
#include <mutex>
#include <queue>
#include <string>
#include <vector>
#include "frontend/fbank.h"
#include "utils/log.h"
#include "utils/blocking_queue.h"
namespace wenet {
struct FeaturePipelineConfig {
int num_bins;
int sample_rate;
int frame_length;
int frame_shift;
FeaturePipelineConfig(int num_bins, int sample_rate)
: num_bins(num_bins), // 80 dim fbank
sample_rate(sample_rate) { // 16k sample rate
frame_length = sample_rate / 1000 * 25; // frame length 25ms
frame_shift = sample_rate / 1000 * 10; // frame shift 10ms
}
void Info() const {
LOG(INFO) << "feature pipeline config"
<< " num_bins " << num_bins << " frame_length " << frame_length
<< " frame_shift " << frame_shift;
}
};
// Typically, FeaturePipeline is used in two threads: one thread A calls
// AcceptWaveform() to add raw wav data and set_input_finished() to notice
// the end of input wav, another thread B (decoder thread) calls Read() to
// consume features.So a BlockingQueue is used to make this class thread safe.
// The Read() is designed as a blocking method when there is no feature
// in feature_queue_ and the input is not finished.
class FeaturePipeline {
public:
explicit FeaturePipeline(const FeaturePipelineConfig& config);
// The feature extraction is done in AcceptWaveform().
void AcceptWaveform(const std::vector<float>& wav);
void AcceptWaveform(const std::vector<int16_t>& wav);
// Current extracted frames number.
int num_frames() const { return num_frames_; }
int feature_dim() const { return feature_dim_; }
const FeaturePipelineConfig& config() const { return config_; }
// The caller should call this method when speech input is end.
// Never call AcceptWaveform() after calling set_input_finished() !
void set_input_finished();
bool input_finished() const { return input_finished_; }
// Return False if input is finished and no feature could be read.
// Return True if a feature is read.
// This function is a blocking method. It will block the thread when
// there is no feature in feature_queue_ and the input is not finished.
bool ReadOne(std::vector<float>* feat);
// Read #num_frames frame features.
// Return False if less then #num_frames features are read and the
// input is finished.
// Return True if #num_frames features are read.
// This function is a blocking method when there is no feature
// in feature_queue_ and the input is not finished.
bool Read(int num_frames, std::vector<std::vector<float>>* feats);
void Reset();
bool IsLastFrame(int frame) const {
return input_finished_ && (frame == num_frames_ - 1);
}
int NumQueuedFrames() const { return feature_queue_.Size(); }
private:
const FeaturePipelineConfig& config_;
int feature_dim_;
Fbank fbank_;
BlockingQueue<std::vector<float>> feature_queue_;
int num_frames_;
bool input_finished_;
// The feature extraction is done in AcceptWaveform().
// This wavefrom sample points are consumed by frame size.
// The residual wavefrom sample points after framing are
// kept to be used in next AcceptWaveform() calling.
std::vector<float> remained_wav_;
// Used to block the Read when there is no feature in feature_queue_
// and the input is not finished.
mutable std::mutex mutex_;
std::condition_variable finish_condition_;
};
} // namespace wenet
#endif // FRONTEND_FEATURE_PIPELINE_H_

View File

@ -0,0 +1,121 @@
// Copyright (c) 2016 HR
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "frontend/fft.h"
namespace wenet {
void make_sintbl(int n, float* sintbl) {
int i, n2, n4, n8;
float c, s, dc, ds, t;
n2 = n / 2;
n4 = n / 4;
n8 = n / 8;
t = sin(M_PI / n);
dc = 2 * t * t;
ds = sqrt(dc * (2 - dc));
t = 2 * dc;
c = sintbl[n4] = 1;
s = sintbl[0] = 0;
for (i = 1; i < n8; ++i) {
c -= dc;
dc += t * c;
s += ds;
ds -= t * s;
sintbl[i] = s;
sintbl[n4 - i] = c;
}
if (n8 != 0) sintbl[n8] = sqrt(0.5);
for (i = 0; i < n4; ++i) sintbl[n2 - i] = sintbl[i];
for (i = 0; i < n2 + n4; ++i) sintbl[i + n2] = -sintbl[i];
}
void make_bitrev(int n, int* bitrev) {
int i, j, k, n2;
n2 = n / 2;
i = j = 0;
for (;;) {
bitrev[i] = j;
if (++i >= n) break;
k = n2;
while (k <= j) {
j -= k;
k /= 2;
}
j += k;
}
}
// bitrev: bit reversal table
// sintbl: trigonometric function table
// x:real part
// y:image part
// n: fft length
int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n) {
int i, j, k, ik, h, d, k2, n4, inverse;
float t, s, c, dx, dy;
/* preparation */
if (n < 0) {
n = -n;
inverse = 1; /* inverse transform */
} else {
inverse = 0;
}
n4 = n / 4;
if (n == 0) {
return 0;
}
/* bit reversal */
for (i = 0; i < n; ++i) {
j = bitrev[i];
if (i < j) {
t = x[i];
x[i] = x[j];
x[j] = t;
t = y[i];
y[i] = y[j];
y[j] = t;
}
}
/* transformation */
for (k = 1; k < n; k = k2) {
h = 0;
k2 = k + k;
d = n / k2;
for (j = 0; j < k; ++j) {
c = sintbl[h + n4];
if (inverse)
s = -sintbl[h];
else
s = sintbl[h];
for (i = j; i < n; i += k2) {
ik = i + k;
dx = s * y[ik] + c * x[ik];
dy = c * y[ik] - s * x[ik];
x[ik] = x[i] - dx;
x[i] += dx;
y[ik] = y[i] - dy;
y[i] += dy;
}
h += d;
}
}
if (inverse) {
/* divide by n in case of the inverse transformation */
for (i = 0; i < n; ++i) {
x[i] /= n;
y[i] /= n;
}
}
return 0; /* finished successfully */
}
} // namespace wenet

View File

@ -0,0 +1,25 @@
// Copyright (c) 2016 HR
#ifndef FRONTEND_FFT_H_
#define FRONTEND_FFT_H_
#ifndef M_PI
#define M_PI 3.1415926535897932384626433832795
#endif
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
namespace wenet {
// Fast Fourier Transform
void make_sintbl(int n, float* sintbl);
void make_bitrev(int n, int* bitrev);
int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n);
} // namespace wenet
#endif // FRONTEND_FFT_H_

203
runtime/core/frontend/wav.h Normal file
View File

@ -0,0 +1,203 @@
// Copyright (c) 2016 Personal (Binbin Zhang)
// Created on 2016-08-15
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include "utils/log.h"
namespace wenet {
struct WavHeader {
char riff[4]; // "riff"
unsigned int size;
char wav[4]; // "WAVE"
char fmt[4]; // "fmt "
unsigned int fmt_size;
uint16_t format;
uint16_t channels;
unsigned int sample_rate;
unsigned int bytes_per_second;
uint16_t block_size;
uint16_t bit;
char data[4]; // "data"
unsigned int data_size;
};
class WavReader {
public:
WavReader() : data_(nullptr) {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb");
if (NULL == fp) {
LOG(WARNING) << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
fprintf(stderr,
"WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "riff" "WAVE" "fmt " "data"
// Skip any subchunks between "fmt" and "data". Usually there will
// be a single "fact" subchunk, but on Windows there can also be a
// "list" subchunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next subchunk
fread(header.data, 8, sizeof(char), fp);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_ = new float[num_data];
num_samples_ = num_data / num_channel_;
for (int i = 0; i < num_data; ++i) {
switch (bits_per_sample_) {
case 8: {
char sample;
fread(&sample, 1, sizeof(char), fp);
data_[i] = static_cast<float>(sample);
break;
}
case 16: {
int16_t sample;
fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample);
break;
}
case 32: {
int sample;
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample);
break;
}
default:
fprintf(stderr, "unsupported quantization bits");
exit(1);
}
}
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_samples() const { return num_samples_; }
~WavReader() {
if (data_ != NULL) delete[] data_;
}
const float* data() const { return data_; }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_samples_; // sample points per channel
float* data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_samples, int num_channel,
int sample_rate, int bits_per_sample)
: data_(data),
num_samples_(num_samples),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "w");
// init char 'riff' 'WAVE' 'fmt ' 'data'
WavHeader header;
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
memcpy(&header, wav_header, sizeof(header));
header.channels = num_channel_;
header.bit = bits_per_sample_;
header.sample_rate = sample_rate_;
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
header.size = sizeof(header) - 8 + header.data_size;
header.bytes_per_second =
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
header.block_size = num_channel_ * (bits_per_sample_ / 8);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_samples_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_samples_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
} // namespace wenet
#endif // FRONTEND_WAV_H_

View File

@ -0,0 +1,101 @@
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "kws/keyword_spotting.h"
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace wekws {
Ort::Env KeywordSpotting::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "");
Ort::SessionOptions KeywordSpotting::session_options_ = Ort::SessionOptions();
KeywordSpotting::KeywordSpotting(const std::string& model_path) {
// 1. Load sessions
session_ = std::make_shared<Ort::Session>(env_, model_path.c_str(),
session_options_);
// 2. Model info
in_names_ = {"input", "cache"};
out_names_ = {"output", "r_cache"};
auto metadata = session_->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator;
cache_dim_ = std::stoi(metadata.LookupCustomMetadataMap("cache_dim",
allocator));
cache_len_ = std::stoi(metadata.LookupCustomMetadataMap("cache_len",
allocator));
std::cout << "Kws Model Info:" << std::endl
<< "\tcache_dim: " << cache_dim_ << std::endl
<< "\tcache_len: " << cache_len_ << std::endl;
Reset();
}
void KeywordSpotting::Reset() {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
cache_.resize(cache_dim_ * cache_len_, 0.0);
const int64_t cache_shape[] = {1, cache_dim_, cache_len_};
cache_ort_ = Ort::Value::CreateTensor<float>(
memory_info, cache_.data(), cache_.size(), cache_shape, 3);
}
void KeywordSpotting::Forward(
const std::vector<std::vector<float>>& feats,
std::vector<std::vector<float>>* prob) {
prob->clear();
if (feats.size() == 0) return;
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
// 1. Prepare input
int num_frames = feats.size();
int feature_dim = feats[0].size();
std::vector<float> slice_feats;
for (int i = 0; i < feats.size(); i++) {
slice_feats.insert(slice_feats.end(), feats[i].begin(), feats[i].end());
}
const int64_t feats_shape[3] = {1, num_frames, feature_dim};
Ort::Value feats_ort = Ort::Value::CreateTensor<float>(
memory_info, slice_feats.data(), slice_feats.size(), feats_shape, 3);
// 2. Ort forward
std::vector<Ort::Value> inputs;
inputs.emplace_back(std::move(feats_ort));
inputs.emplace_back(std::move(cache_ort_));
// ort_outputs.size() == 2
std::vector<Ort::Value> ort_outputs = session_->Run(
Ort::RunOptions{nullptr}, in_names_.data(), inputs.data(),
inputs.size(), out_names_.data(), out_names_.size());
// 3. Update cache
cache_ort_ = std::move(ort_outputs[1]);
// 4. Get keyword prob
float* data = ort_outputs[0].GetTensorMutableData<float>();
auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
int num_outputs = type_info.GetShape()[1];
int output_dim = type_info.GetShape()[2];
prob->resize(num_outputs);
for (int i = 0; i < num_outputs; i++) {
(*prob)[i].resize(output_dim);
memcpy((*prob)[i].data(), data + i * output_dim,
sizeof(float) * output_dim);
}
}
} // namespace wekws

View File

@ -0,0 +1,61 @@
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef KWS_KEYWORD_SPOTTING_H_
#define KWS_KEYWORD_SPOTTING_H_
#include <memory>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace wekws {
class KeywordSpotting {
public:
explicit KeywordSpotting(const std::string& model_path);
// Call reset if keyword is detected
void Reset();
static void InitEngineThreads(int num_threads) {
session_options_.SetIntraOpNumThreads(num_threads);
session_options_.SetInterOpNumThreads(num_threads);
}
void Forward(const std::vector<std::vector<float>>& feats,
std::vector<std::vector<float>>* prob);
private:
// session
static Ort::Env env_;
static Ort::SessionOptions session_options_;
std::shared_ptr<Ort::Session> session_ = nullptr;
// node names
std::vector<const char*> in_names_;
std::vector<const char*> out_names_;
// meta info
int cache_dim_ = 0;
int cache_len_ = 0;
// cache info
Ort::Value cache_ort_{nullptr};
std::vector<float> cache_;
};
} // namespace wekws
#endif // KWS_KEYWORD_SPOTTING_H_

View File

@ -0,0 +1,98 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UTILS_BLOCKING_QUEUE_H_
#define UTILS_BLOCKING_QUEUE_H_
#include <condition_variable>
#include <limits>
#include <mutex>
#include <queue>
#include <utility>
namespace wenet {
#define WENET_DISALLOW_COPY_AND_ASSIGN(Type) \
Type(const Type&) = delete; \
Type& operator=(const Type&) = delete;
template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity = std::numeric_limits<int>::max())
: capacity_(capacity) {}
void Push(const T& value) {
{
std::unique_lock<std::mutex> lock(mutex_);
while (queue_.size() >= capacity_) {
not_full_condition_.wait(lock);
}
queue_.push(value);
}
not_empty_condition_.notify_one();
}
void Push(T&& value) {
{
std::unique_lock<std::mutex> lock(mutex_);
while (queue_.size() >= capacity_) {
not_full_condition_.wait(lock);
}
queue_.push(std::move(value));
}
not_empty_condition_.notify_one();
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
while (queue_.empty()) {
not_empty_condition_.wait(lock);
}
T t(std::move(queue_.front()));
queue_.pop();
not_full_condition_.notify_one();
return t;
}
bool Empty() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.empty();
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
void Clear() {
while (!Empty()) {
Pop();
}
}
private:
size_t capacity_;
mutable std::mutex mutex_;
std::condition_variable not_full_condition_;
std::condition_variable not_empty_condition_;
std::queue<T> queue_;
public:
WENET_DISALLOW_COPY_AND_ASSIGN(BlockingQueue);
};
} // namespace wenet
#endif // UTILS_BLOCKING_QUEUE_H_

83
runtime/core/utils/log.h Normal file
View File

@ -0,0 +1,83 @@
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef UTILS_LOG_H_
#define UTILS_LOG_H_
#include <stdlib.h>
#include <iostream>
#include <sstream>
namespace wenet {
const int INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3;
class Logger {
public:
Logger(int severity, const char* func, const char* file, int line) {
severity_ = severity;
switch (severity) {
case INFO:
ss_ << "INFO (";
break;
case WARNING:
ss_ << "WARNING (";
break;
case ERROR:
ss_ << "ERROR (";
break;
case FATAL:
ss_ << "FATAL (";
break;
default:
severity_ = FATAL;
ss_ << "FATAL (";
}
ss_ << func << "():" << file << ':' << line << ") ";
}
~Logger() {
std::cerr << ss_.str() << std::endl << std::flush;
if (severity_ == FATAL) {
abort();
}
}
template <typename T> Logger& operator<<(const T &val) {
ss_ << val;
return *this;
}
private:
int severity_;
std::ostringstream ss_;
};
#define LOG(severity) ::wenet::Logger( \
::wenet::severity, __func__, __FILE__, __LINE__)
#define CHECK(test) \
do { \
if (!(test)) { \
std::cerr << "CHECK (" << __func__ << "():" << __FILE__ << ":" \
<< __LINE__ << ") " << #test << std::endl; \
exit(-1); \
} \
} while (0)
} // namespace wenet
#endif // UTILS_LOG_H_

View File

@ -0,0 +1,31 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(wekws VERSION 0.1)
set(CMAKE_VERBOSE_MAKEFILE on)
include(FetchContent)
include(ExternalProject)
set(FETCHCONTENT_QUIET OFF)
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_base})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -g -pthread")
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
FetchContent_Declare(onnxruntime
URL https://github.com/microsoft/onnxruntime/releases/download/v1.11.1/onnxruntime-linux-x64-1.11.1.tgz
URL_HASH SHA256=ddc03b5ae325c675ff76a6f18786ce7d310be6eb6f320087f7a0e9228115f24d
)
FetchContent_MakeAvailable(onnxruntime)
include_directories(${onnxruntime_SOURCE_DIR}/include)
link_directories(${onnxruntime_SOURCE_DIR}/lib)
add_executable(kws_main
bin/kws_main.cc
kws/keyword_spotting.cc
frontend/feature_pipeline.cc
frontend/fft.cc
)
target_link_libraries(kws_main PUBLIC onnxruntime)

View File

@ -0,0 +1,9 @@
## How to Build?
``` sh
mkdir build && cd build && cmake .. && cmake --build .
```
## How to Use?
Type `./build/kws_main --help` for usage.

1
runtime/onnxruntime/bin Symbolic link
View File

@ -0,0 +1 @@
../core/bin

View File

@ -0,0 +1 @@
../core/frontend

1
runtime/onnxruntime/kws Symbolic link
View File

@ -0,0 +1 @@
../core/kws

1
runtime/onnxruntime/utils Symbolic link
View File

@ -0,0 +1 @@
../core/utils