[runtime/onnxruntime] add onnxruntime support (#79)
* [runtime/onnxruntime] add onnxruntime support * add cpplint and clang-format * fix lint
This commit is contained in:
parent
5037d51ed9
commit
53d7b8f807
93
.clang-format
Normal file
93
.clang-format
Normal file
@ -0,0 +1,93 @@
|
||||
---
|
||||
Language: Cpp
|
||||
# BasedOnStyle: Google
|
||||
AccessModifierOffset: -1
|
||||
AlignAfterOpenBracket: Align
|
||||
AlignConsecutiveAssignments: false
|
||||
AlignConsecutiveDeclarations: false
|
||||
AlignEscapedNewlinesLeft: true
|
||||
AlignOperands: true
|
||||
AlignTrailingComments: true
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: false
|
||||
AllowShortCaseLabelsOnASingleLine: false
|
||||
AllowShortFunctionsOnASingleLine: All
|
||||
AllowShortIfStatementsOnASingleLine: true
|
||||
AllowShortLoopsOnASingleLine: true
|
||||
AlwaysBreakAfterDefinitionReturnType: None
|
||||
AlwaysBreakAfterReturnType: None
|
||||
AlwaysBreakBeforeMultilineStrings: true
|
||||
AlwaysBreakTemplateDeclarations: true
|
||||
BinPackArguments: true
|
||||
BinPackParameters: true
|
||||
BraceWrapping:
|
||||
AfterClass: false
|
||||
AfterControlStatement: false
|
||||
AfterEnum: false
|
||||
AfterFunction: false
|
||||
AfterNamespace: false
|
||||
AfterObjCDeclaration: false
|
||||
AfterStruct: false
|
||||
AfterUnion: false
|
||||
BeforeCatch: false
|
||||
BeforeElse: false
|
||||
IndentBraces: false
|
||||
BreakBeforeBinaryOperators: None
|
||||
BreakBeforeBraces: Attach
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializersBeforeComma: false
|
||||
BreakAfterJavaFieldAnnotations: false
|
||||
BreakStringLiterals: true
|
||||
ColumnLimit: 80
|
||||
CommentPragmas: '^ IWYU pragma:'
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
||||
ConstructorInitializerIndentWidth: 4
|
||||
ContinuationIndentWidth: 4
|
||||
Cpp11BracedListStyle: true
|
||||
DisableFormat: false
|
||||
ExperimentalAutoDetectBinPacking: false
|
||||
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
|
||||
IncludeCategories:
|
||||
- Regex: '^<.*\.h>'
|
||||
Priority: 1
|
||||
- Regex: '^<.*'
|
||||
Priority: 2
|
||||
- Regex: '.*'
|
||||
Priority: 3
|
||||
IncludeIsMainRegex: '([-_](test|unittest))?$'
|
||||
IndentCaseLabels: true
|
||||
IndentWidth: 2
|
||||
IndentWrappedFunctionNames: false
|
||||
JavaScriptQuotes: Leave
|
||||
JavaScriptWrapImports: true
|
||||
KeepEmptyLinesAtTheStartOfBlocks: false
|
||||
MacroBlockBegin: ''
|
||||
MacroBlockEnd: ''
|
||||
MaxEmptyLinesToKeep: 1
|
||||
NamespaceIndentation: None
|
||||
ObjCBlockIndentWidth: 2
|
||||
ObjCSpaceAfterProperty: false
|
||||
ObjCSpaceBeforeProtocolList: false
|
||||
PenaltyBreakBeforeFirstCallParameter: 1
|
||||
PenaltyBreakComment: 300
|
||||
PenaltyBreakFirstLessLess: 120
|
||||
PenaltyBreakString: 1000
|
||||
PenaltyExcessCharacter: 1000000
|
||||
PenaltyReturnTypeOnItsOwnLine: 200
|
||||
PointerAlignment: Left
|
||||
ReflowComments: true
|
||||
SortIncludes: true
|
||||
SpaceAfterCStyleCast: false
|
||||
SpaceBeforeAssignmentOperators: true
|
||||
SpaceBeforeParens: ControlStatements
|
||||
SpaceInEmptyParentheses: false
|
||||
SpacesBeforeTrailingComments: 2
|
||||
SpacesInAngles: false
|
||||
SpacesInContainerLiterals: true
|
||||
SpacesInCStyleCastParentheses: false
|
||||
SpacesInParentheses: false
|
||||
SpacesInSquareBrackets: false
|
||||
Standard: Auto
|
||||
TabWidth: 8
|
||||
UseTab: Never
|
||||
...
|
||||
2
CPPLINT.cfg
Normal file
2
CPPLINT.cfg
Normal file
@ -0,0 +1,2 @@
|
||||
root=runtime/core
|
||||
filter=-build/c++11
|
||||
64
runtime/core/bin/kws_main.cc
Normal file
64
runtime/core/bin/kws_main.cc
Normal file
@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "frontend/feature_pipeline.h"
|
||||
#include "frontend/wav.h"
|
||||
#include "kws/keyword_spotting.h"
|
||||
#include "utils/log.h"
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc != 5) {
|
||||
LOG(FATAL) << "Usage: kws_main fbank_dim(int) batch_size(int) "
|
||||
<< "kws_model_path test_wav_path";
|
||||
}
|
||||
|
||||
const int num_bins = std::stoi(argv[1]); // Fbank feature dim
|
||||
const int batch_size = std::stoi(argv[2]);
|
||||
const std::string model_path = argv[3];
|
||||
const std::string wav_path = argv[4];
|
||||
|
||||
wenet::WavReader wav_reader(wav_path);
|
||||
int num_samples = wav_reader.num_samples();
|
||||
wenet::FeaturePipelineConfig feature_config(num_bins, 16000);
|
||||
wenet::FeaturePipeline feature_pipeline(feature_config);
|
||||
std::vector<float> wav(wav_reader.data(), wav_reader.data() + num_samples);
|
||||
feature_pipeline.AcceptWaveform(wav);
|
||||
feature_pipeline.set_input_finished();
|
||||
|
||||
wekws::KeywordSpotting spotter(model_path);
|
||||
|
||||
// Simulate streaming, detect batch by batch
|
||||
int offset = 0;
|
||||
while (true) {
|
||||
std::vector<std::vector<float>> feats;
|
||||
bool ok = feature_pipeline.Read(batch_size, &feats);
|
||||
std::vector<std::vector<float>> prob;
|
||||
spotter.Forward(feats, &prob);
|
||||
for (int i = 0; i < prob.size(); i++) {
|
||||
std::cout << "frame " << offset + i << " prob";
|
||||
for (int j = 0; j < prob[i].size(); j++) {
|
||||
std::cout << " " << prob[i][0];
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
// Reach the end of feature pipeline
|
||||
if (!ok) break;
|
||||
offset += prob.size();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
222
runtime/core/frontend/fbank.h
Normal file
222
runtime/core/frontend/fbank.h
Normal file
@ -0,0 +1,222 @@
|
||||
// Copyright (c) 2017 Personal (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef FRONTEND_FBANK_H_
|
||||
#define FRONTEND_FBANK_H_
|
||||
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/fft.h"
|
||||
#include "utils/log.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
// This code is based on kaldi Fbank implementation, please see
|
||||
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
|
||||
class Fbank {
|
||||
public:
|
||||
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
|
||||
: num_bins_(num_bins),
|
||||
sample_rate_(sample_rate),
|
||||
frame_length_(frame_length),
|
||||
frame_shift_(frame_shift),
|
||||
use_log_(true),
|
||||
remove_dc_offset_(true),
|
||||
generator_(0),
|
||||
distribution_(0, 1.0),
|
||||
dither_(0.0) {
|
||||
fft_points_ = UpperPowerOfTwo(frame_length_);
|
||||
// generate bit reversal table and trigonometric function table
|
||||
const int fft_points_4 = fft_points_ / 4;
|
||||
bitrev_.resize(fft_points_);
|
||||
sintbl_.resize(fft_points_ + fft_points_4);
|
||||
make_sintbl(fft_points_, sintbl_.data());
|
||||
make_bitrev(fft_points_, bitrev_.data());
|
||||
|
||||
int num_fft_bins = fft_points_ / 2;
|
||||
float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
|
||||
int low_freq = 20, high_freq = sample_rate_ / 2;
|
||||
float mel_low_freq = MelScale(low_freq);
|
||||
float mel_high_freq = MelScale(high_freq);
|
||||
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
|
||||
bins_.resize(num_bins_);
|
||||
center_freqs_.resize(num_bins_);
|
||||
for (int bin = 0; bin < num_bins; ++bin) {
|
||||
float left_mel = mel_low_freq + bin * mel_freq_delta,
|
||||
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
|
||||
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
|
||||
center_freqs_[bin] = InverseMelScale(center_mel);
|
||||
std::vector<float> this_bin(num_fft_bins);
|
||||
int first_index = -1, last_index = -1;
|
||||
for (int i = 0; i < num_fft_bins; ++i) {
|
||||
float freq = (fft_bin_width * i); // Center frequency of this fft
|
||||
// bin.
|
||||
float mel = MelScale(freq);
|
||||
if (mel > left_mel && mel < right_mel) {
|
||||
float weight;
|
||||
if (mel <= center_mel)
|
||||
weight = (mel - left_mel) / (center_mel - left_mel);
|
||||
else
|
||||
weight = (right_mel - mel) / (right_mel - center_mel);
|
||||
this_bin[i] = weight;
|
||||
if (first_index == -1) first_index = i;
|
||||
last_index = i;
|
||||
}
|
||||
}
|
||||
CHECK(first_index != -1 && last_index >= first_index);
|
||||
bins_[bin].first = first_index;
|
||||
int size = last_index + 1 - first_index;
|
||||
bins_[bin].second.resize(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
bins_[bin].second[i] = this_bin[first_index + i];
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(cdliang): add hamming window
|
||||
hamming_window_.resize(frame_length_);
|
||||
double a = M_2PI / (frame_length - 1);
|
||||
for (int i = 0; i < frame_length; i++) {
|
||||
double i_fl = static_cast<double>(i);
|
||||
hamming_window_[i] = 0.54 - 0.46 * cos(a * i_fl);
|
||||
}
|
||||
}
|
||||
|
||||
void set_use_log(bool use_log) { use_log_ = use_log; }
|
||||
|
||||
void set_remove_dc_offset(bool remove_dc_offset) {
|
||||
remove_dc_offset_ = remove_dc_offset;
|
||||
}
|
||||
|
||||
void set_dither(float dither) { dither_ = dither; }
|
||||
|
||||
int num_bins() const { return num_bins_; }
|
||||
|
||||
static inline float InverseMelScale(float mel_freq) {
|
||||
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
|
||||
}
|
||||
|
||||
static inline float MelScale(float freq) {
|
||||
return 1127.0f * logf(1.0f + freq / 700.0f);
|
||||
}
|
||||
|
||||
static int UpperPowerOfTwo(int n) {
|
||||
return static_cast<int>(pow(2, ceil(log(n) / log(2))));
|
||||
}
|
||||
|
||||
// preemphasis
|
||||
void PreEmphasis(float coeff, std::vector<float>* data) const {
|
||||
if (coeff == 0.0) return;
|
||||
for (int i = data->size() - 1; i > 0; i--)
|
||||
(*data)[i] -= coeff * (*data)[i - 1];
|
||||
(*data)[0] -= coeff * (*data)[0];
|
||||
}
|
||||
|
||||
// add hamming window
|
||||
void Hamming(std::vector<float>* data) const {
|
||||
CHECK(data->size() >= hamming_window_.size());
|
||||
for (size_t i = 0; i < hamming_window_.size(); ++i) {
|
||||
(*data)[i] *= hamming_window_[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Compute fbank feat, return num frames
|
||||
int Compute(const std::vector<float>& wave,
|
||||
std::vector<std::vector<float>>* feat) {
|
||||
int num_samples = wave.size();
|
||||
if (num_samples < frame_length_) return 0;
|
||||
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
|
||||
feat->resize(num_frames);
|
||||
std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
|
||||
std::vector<float> power(fft_points_ / 2);
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
std::vector<float> data(wave.data() + i * frame_shift_,
|
||||
wave.data() + i * frame_shift_ + frame_length_);
|
||||
// optional add noise
|
||||
if (dither_ != 0.0) {
|
||||
for (size_t j = 0; j < data.size(); ++j)
|
||||
data[j] += dither_ * distribution_(generator_);
|
||||
}
|
||||
// optinal remove dc offset
|
||||
if (remove_dc_offset_) {
|
||||
float mean = 0.0;
|
||||
for (size_t j = 0; j < data.size(); ++j) mean += data[j];
|
||||
mean /= data.size();
|
||||
for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
|
||||
}
|
||||
|
||||
PreEmphasis(0.97, &data);
|
||||
// Povey(&data);
|
||||
Hamming(&data);
|
||||
// copy data to fft_real
|
||||
memset(fft_img.data(), 0, sizeof(float) * fft_points_);
|
||||
memset(fft_real.data() + frame_length_, 0,
|
||||
sizeof(float) * (fft_points_ - frame_length_));
|
||||
memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_);
|
||||
fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(),
|
||||
fft_points_);
|
||||
// power
|
||||
for (int j = 0; j < fft_points_ / 2; ++j) {
|
||||
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
|
||||
}
|
||||
|
||||
(*feat)[i].resize(num_bins_);
|
||||
// cepstral coefficients, triangle filter array
|
||||
for (int j = 0; j < num_bins_; ++j) {
|
||||
float mel_energy = 0.0;
|
||||
int s = bins_[j].first;
|
||||
for (size_t k = 0; k < bins_[j].second.size(); ++k) {
|
||||
mel_energy += bins_[j].second[k] * power[s + k];
|
||||
}
|
||||
// optional use log
|
||||
if (use_log_) {
|
||||
if (mel_energy < std::numeric_limits<float>::epsilon())
|
||||
mel_energy = std::numeric_limits<float>::epsilon();
|
||||
mel_energy = logf(mel_energy);
|
||||
}
|
||||
|
||||
(*feat)[i][j] = mel_energy;
|
||||
// printf("%f ", mel_energy);
|
||||
}
|
||||
// printf("\n");
|
||||
}
|
||||
return num_frames;
|
||||
}
|
||||
|
||||
private:
|
||||
int num_bins_;
|
||||
int sample_rate_;
|
||||
int frame_length_, frame_shift_;
|
||||
int fft_points_;
|
||||
bool use_log_;
|
||||
bool remove_dc_offset_;
|
||||
std::vector<float> center_freqs_;
|
||||
std::vector<std::pair<int, std::vector<float>>> bins_;
|
||||
std::vector<float> hamming_window_;
|
||||
std::default_random_engine generator_;
|
||||
std::normal_distribution<float> distribution_;
|
||||
float dither_;
|
||||
|
||||
// bit reversal table
|
||||
std::vector<int> bitrev_;
|
||||
// trigonometric function table
|
||||
std::vector<float> sintbl_;
|
||||
};
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // FRONTEND_FBANK_H_
|
||||
113
runtime/core/frontend/feature_pipeline.cc
Normal file
113
runtime/core/frontend/feature_pipeline.cc
Normal file
@ -0,0 +1,113 @@
|
||||
// Copyright (c) 2017 Personal (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "frontend/feature_pipeline.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
namespace wenet {
|
||||
|
||||
// Sets up the fbank extractor from `config` and starts with no frames
// extracted and the input still open.
FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
    : config_(config),
      feature_dim_(config.num_bins),
      fbank_(config.num_bins,
             config.sample_rate,
             config.frame_length,
             config.frame_shift),
      num_frames_(0),
      input_finished_(false) {
}
|
||||
|
||||
void FeaturePipeline::AcceptWaveform(const std::vector<float>& wav) {
|
||||
std::vector<std::vector<float>> feats;
|
||||
std::vector<float> waves;
|
||||
waves.insert(waves.end(), remained_wav_.begin(), remained_wav_.end());
|
||||
waves.insert(waves.end(), wav.begin(), wav.end());
|
||||
int num_frames = fbank_.Compute(waves, &feats);
|
||||
for (size_t i = 0; i < feats.size(); ++i) {
|
||||
feature_queue_.Push(std::move(feats[i]));
|
||||
}
|
||||
num_frames_ += num_frames;
|
||||
|
||||
int left_samples = waves.size() - config_.frame_shift * num_frames;
|
||||
remained_wav_.resize(left_samples);
|
||||
std::copy(waves.begin() + config_.frame_shift * num_frames, waves.end(),
|
||||
remained_wav_.begin());
|
||||
// We are still adding wave, notify input is not finished
|
||||
finish_condition_.notify_one();
|
||||
}
|
||||
|
||||
// Convenience overload: converts 16-bit samples to float (no rescaling) and
// forwards them to the float overload.
void FeaturePipeline::AcceptWaveform(const std::vector<int16_t>& wav) {
  const std::vector<float> as_float(wav.begin(), wav.end());
  this->AcceptWaveform(as_float);
}
|
||||
|
||||
void FeaturePipeline::set_input_finished() {
|
||||
CHECK(!input_finished_);
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
input_finished_ = true;
|
||||
}
|
||||
finish_condition_.notify_one();
|
||||
}
|
||||
|
||||
bool FeaturePipeline::ReadOne(std::vector<float>* feat) {
|
||||
if (!feature_queue_.Empty()) {
|
||||
*feat = std::move(feature_queue_.Pop());
|
||||
return true;
|
||||
} else {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (!input_finished_) {
|
||||
// This will release the lock and wait for notify_one()
|
||||
// from AcceptWaveform() or set_input_finished()
|
||||
finish_condition_.wait(lock);
|
||||
if (!feature_queue_.Empty()) {
|
||||
*feat = std::move(feature_queue_.Pop());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
CHECK(input_finished_);
|
||||
// Double check queue.empty, see issue#893 for detailed discussions.
|
||||
if (!feature_queue_.Empty()) {
|
||||
*feat = std::move(feature_queue_.Pop());
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool FeaturePipeline::Read(int num_frames,
|
||||
std::vector<std::vector<float>>* feats) {
|
||||
feats->clear();
|
||||
std::vector<float> feat;
|
||||
while (feats->size() < num_frames) {
|
||||
if (ReadOne(&feat)) {
|
||||
feats->push_back(std::move(feat));
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the pipeline to its initial state: input open, no frames counted,
// no leftover samples and an empty feature queue.
void FeaturePipeline::Reset() {
  input_finished_ = false;
  num_frames_ = 0;
  remained_wav_.clear();
  feature_queue_.Clear();
}
|
||||
|
||||
} // namespace wenet
|
||||
118
runtime/core/frontend/feature_pipeline.h
Normal file
118
runtime/core/frontend/feature_pipeline.h
Normal file
@ -0,0 +1,118 @@
|
||||
// Copyright (c) 2017 Personal (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef FRONTEND_FEATURE_PIPELINE_H_
|
||||
#define FRONTEND_FEATURE_PIPELINE_H_
|
||||
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/fbank.h"
|
||||
#include "utils/log.h"
|
||||
#include "utils/blocking_queue.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
struct FeaturePipelineConfig {
|
||||
int num_bins;
|
||||
int sample_rate;
|
||||
int frame_length;
|
||||
int frame_shift;
|
||||
FeaturePipelineConfig(int num_bins, int sample_rate)
|
||||
: num_bins(num_bins), // 80 dim fbank
|
||||
sample_rate(sample_rate) { // 16k sample rate
|
||||
frame_length = sample_rate / 1000 * 25; // frame length 25ms
|
||||
frame_shift = sample_rate / 1000 * 10; // frame shift 10ms
|
||||
}
|
||||
|
||||
void Info() const {
|
||||
LOG(INFO) << "feature pipeline config"
|
||||
<< " num_bins " << num_bins << " frame_length " << frame_length
|
||||
<< " frame_shift " << frame_shift;
|
||||
}
|
||||
};
|
||||
|
||||
// Typically, FeaturePipeline is used in two threads: one thread A calls
|
||||
// AcceptWaveform() to add raw wav data and set_input_finished() to notice
|
||||
// the end of input wav, another thread B (decoder thread) calls Read() to
|
||||
// consume features. So a BlockingQueue is used to make this class thread safe.
|
||||
|
||||
// The Read() is designed as a blocking method when there is no feature
|
||||
// in feature_queue_ and the input is not finished.
|
||||
|
||||
class FeaturePipeline {
|
||||
public:
|
||||
explicit FeaturePipeline(const FeaturePipelineConfig& config);
|
||||
|
||||
// The feature extraction is done in AcceptWaveform().
|
||||
void AcceptWaveform(const std::vector<float>& wav);
|
||||
void AcceptWaveform(const std::vector<int16_t>& wav);
|
||||
|
||||
// Current extracted frames number.
|
||||
int num_frames() const { return num_frames_; }
|
||||
int feature_dim() const { return feature_dim_; }
|
||||
const FeaturePipelineConfig& config() const { return config_; }
|
||||
|
||||
// The caller should call this method when speech input is end.
|
||||
// Never call AcceptWaveform() after calling set_input_finished() !
|
||||
void set_input_finished();
|
||||
bool input_finished() const { return input_finished_; }
|
||||
|
||||
// Return False if input is finished and no feature could be read.
|
||||
// Return True if a feature is read.
|
||||
// This function is a blocking method. It will block the thread when
|
||||
// there is no feature in feature_queue_ and the input is not finished.
|
||||
bool ReadOne(std::vector<float>* feat);
|
||||
|
||||
// Read #num_frames frame features.
|
||||
// Return False if less then #num_frames features are read and the
|
||||
// input is finished.
|
||||
// Return True if #num_frames features are read.
|
||||
// This function is a blocking method when there is no feature
|
||||
// in feature_queue_ and the input is not finished.
|
||||
bool Read(int num_frames, std::vector<std::vector<float>>* feats);
|
||||
|
||||
void Reset();
|
||||
bool IsLastFrame(int frame) const {
|
||||
return input_finished_ && (frame == num_frames_ - 1);
|
||||
}
|
||||
|
||||
int NumQueuedFrames() const { return feature_queue_.Size(); }
|
||||
|
||||
private:
|
||||
const FeaturePipelineConfig& config_;
|
||||
int feature_dim_;
|
||||
Fbank fbank_;
|
||||
|
||||
BlockingQueue<std::vector<float>> feature_queue_;
|
||||
int num_frames_;
|
||||
bool input_finished_;
|
||||
|
||||
// The feature extraction is done in AcceptWaveform().
|
||||
// This wavefrom sample points are consumed by frame size.
|
||||
// The residual wavefrom sample points after framing are
|
||||
// kept to be used in next AcceptWaveform() calling.
|
||||
std::vector<float> remained_wav_;
|
||||
|
||||
// Used to block the Read when there is no feature in feature_queue_
|
||||
// and the input is not finished.
|
||||
mutable std::mutex mutex_;
|
||||
std::condition_variable finish_condition_;
|
||||
};
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // FRONTEND_FEATURE_PIPELINE_H_
|
||||
121
runtime/core/frontend/fft.cc
Normal file
121
runtime/core/frontend/fft.cc
Normal file
@ -0,0 +1,121 @@
|
||||
// Copyright (c) 2016 HR
|
||||
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "frontend/fft.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
// Builds the sine lookup table used by fft() for an n-point transform.
//
// n:      fft length.
// sintbl: output table; must have room for at least n + n / 4 floats.
//
// The first eighth of a period is generated with a recurrence and the rest of
// the table is filled in by symmetry. A non-positive n leaves the table
// untouched instead of dividing by zero and writing into an empty buffer.
void make_sintbl(int n, float* sintbl) {
  if (n <= 0) return;

  const int n2 = n / 2;
  const int n4 = n / 4;
  const int n8 = n / 8;

  float t = sin(M_PI / n);
  float dc = 2 * t * t;
  float ds = sqrt(dc * (2 - dc));
  t = 2 * dc;
  float c = sintbl[n4] = 1;
  float s = sintbl[0] = 0;
  for (int i = 1; i < n8; ++i) {
    c -= dc;
    dc += t * c;
    s += ds;
    ds -= t * s;
    sintbl[i] = s;
    sintbl[n4 - i] = c;
  }
  if (n8 != 0) sintbl[n8] = sqrt(0.5);
  // Mirror the first quarter into the second, then negate the first three
  // quarters to obtain the remaining entries.
  for (int i = 0; i < n4; ++i) sintbl[n2 - i] = sintbl[i];
  for (int i = 0; i < n2 + n4; ++i) sintbl[i + n2] = -sintbl[i];
}
|
||||
|
||||
// Fills bitrev[0..n-1] with the bit-reversal permutation used by fft().
//
// n:      fft length; expected to be a power of two. For other positive
//         values the inner loop's step can reach zero and never terminate.
// bitrev: output table with room for n ints.
//
// A non-positive n leaves the table untouched; without the guard the first
// store would write past the end of an empty table.
void make_bitrev(int n, int* bitrev) {
  if (n <= 0) return;

  const int n2 = n / 2;
  int i = 0;
  int j = 0;
  for (;;) {
    bitrev[i] = j;
    if (++i >= n) break;
    // Add one to j in bit-reversed order: clear leading set bits starting at
    // the most significant position, then set the first clear one.
    int k = n2;
    while (k <= j) {
      j -= k;
      k /= 2;
    }
    j += k;
  }
}
|
||||
|
||||
// In-place radix-2 FFT.
//
// bitrev: bit reversal table (see make_bitrev)
// sintbl: trigonometric function table (see make_sintbl)
// x:      real part, overwritten with the result
// y:      imaginary part, overwritten with the result
// n:      fft length; a negative value requests the inverse transform of
//         length -n, whose output is divided by the length
//
// Always returns 0.
int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n) {
  /* preparation */
  const bool inverse = n < 0;
  if (inverse) {
    n = -n; /* inverse transform */
  }
  if (n == 0) {
    return 0;
  }
  const int quarter = n / 4;

  /* bit reversal: swap each pair once */
  for (int i = 0; i < n; ++i) {
    const int j = bitrev[i];
    if (i >= j) continue;
    const float real_tmp = x[i];
    x[i] = x[j];
    x[j] = real_tmp;
    const float imag_tmp = y[i];
    y[i] = y[j];
    y[j] = imag_tmp;
  }

  /* transformation: butterflies of doubling size */
  for (int half = 1; half < n; half *= 2) {
    const int span = half + half;
    const int step = n / span;
    int h = 0;
    for (int j = 0; j < half; ++j) {
      const float c = sintbl[h + quarter];
      const float s = inverse ? -sintbl[h] : sintbl[h];
      for (int i = j; i < n; i += span) {
        const int ik = i + half;
        const float dx = s * y[ik] + c * x[ik];
        const float dy = c * y[ik] - s * x[ik];
        x[ik] = x[i] - dx;
        x[i] += dx;
        y[ik] = y[i] - dy;
        y[i] += dy;
      }
      h += step;
    }
  }

  if (inverse) {
    /* divide by n in case of the inverse transformation */
    for (int i = 0; i < n; ++i) {
      x[i] /= n;
      y[i] /= n;
    }
  }
  return 0; /* finished successfully */
}
|
||||
|
||||
} // namespace wenet
|
||||
25
runtime/core/frontend/fft.h
Normal file
25
runtime/core/frontend/fft.h
Normal file
@ -0,0 +1,25 @@
|
||||
// Copyright (c) 2016 HR
|
||||
|
||||
#ifndef FRONTEND_FFT_H_
|
||||
#define FRONTEND_FFT_H_
|
||||
|
||||
#ifndef M_PI
|
||||
#define M_PI 3.1415926535897932384626433832795
|
||||
#endif
|
||||
#ifndef M_2PI
|
||||
#define M_2PI 6.283185307179586476925286766559005
|
||||
#endif
|
||||
|
||||
namespace wenet {
|
||||
|
||||
// Fast Fourier Transform
|
||||
|
||||
// Fills `sintbl` with the sine table that fft() reads for a transform of
// length n.
void make_sintbl(int n, float* sintbl);

// Fills bitrev[0..n-1] with the bit-reversal table that fft() reads for a
// transform of length n.
void make_bitrev(int n, int* bitrev);

// In-place FFT of the complex signal (x = real part, y = imaginary part).
// A negative n selects the inverse transform of length -n. Returns 0.
int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n);
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // FRONTEND_FFT_H_
|
||||
203
runtime/core/frontend/wav.h
Normal file
203
runtime/core/frontend/wav.h
Normal file
@ -0,0 +1,203 @@
|
||||
// Copyright (c) 2016 Personal (Binbin Zhang)
|
||||
// Created on 2016-08-15
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef FRONTEND_WAV_H_
|
||||
#define FRONTEND_WAV_H_
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "utils/log.h"
|
||||
|
||||
namespace wenet {
|
||||
|
||||
// On-disk layout of a canonical 44-byte PCM wav header. The struct is read
// from and written to files as raw bytes, so every field has a fixed width
// and the total size is checked at compile time.
struct WavHeader {
  char riff[4];  // "RIFF"
  uint32_t size;
  char wav[4];  // "WAVE"
  char fmt[4];  // "fmt "
  uint32_t fmt_size;
  uint16_t format;
  uint16_t channels;
  uint32_t sample_rate;
  uint32_t bytes_per_second;
  uint16_t block_size;
  uint16_t bit;  // bits per sample
  char data[4];  // "data"
  uint32_t data_size;
};

static_assert(sizeof(WavHeader) == 44,
              "WavHeader must match the 44-byte on-disk wav header");
|
||||
|
||||
class WavReader {
|
||||
public:
|
||||
WavReader() : data_(nullptr) {}
|
||||
explicit WavReader(const std::string& filename) { Open(filename); }
|
||||
|
||||
bool Open(const std::string& filename) {
|
||||
FILE* fp = fopen(filename.c_str(), "rb");
|
||||
if (NULL == fp) {
|
||||
LOG(WARNING) << "Error in read " << filename;
|
||||
return false;
|
||||
}
|
||||
|
||||
WavHeader header;
|
||||
fread(&header, 1, sizeof(header), fp);
|
||||
if (header.fmt_size < 16) {
|
||||
fprintf(stderr,
|
||||
"WaveData: expect PCM format data "
|
||||
"to have fmt chunk of at least size 16.\n");
|
||||
return false;
|
||||
} else if (header.fmt_size > 16) {
|
||||
int offset = 44 - 8 + header.fmt_size - 16;
|
||||
fseek(fp, offset, SEEK_SET);
|
||||
fread(header.data, 8, sizeof(char), fp);
|
||||
}
|
||||
// check "riff" "WAVE" "fmt " "data"
|
||||
|
||||
// Skip any subchunks between "fmt" and "data". Usually there will
|
||||
// be a single "fact" subchunk, but on Windows there can also be a
|
||||
// "list" subchunk.
|
||||
while (0 != strncmp(header.data, "data", 4)) {
|
||||
// We will just ignore the data in these chunks.
|
||||
fseek(fp, header.data_size, SEEK_CUR);
|
||||
// read next subchunk
|
||||
fread(header.data, 8, sizeof(char), fp);
|
||||
}
|
||||
|
||||
num_channel_ = header.channels;
|
||||
sample_rate_ = header.sample_rate;
|
||||
bits_per_sample_ = header.bit;
|
||||
int num_data = header.data_size / (bits_per_sample_ / 8);
|
||||
data_ = new float[num_data];
|
||||
num_samples_ = num_data / num_channel_;
|
||||
|
||||
for (int i = 0; i < num_data; ++i) {
|
||||
switch (bits_per_sample_) {
|
||||
case 8: {
|
||||
char sample;
|
||||
fread(&sample, 1, sizeof(char), fp);
|
||||
data_[i] = static_cast<float>(sample);
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
int16_t sample;
|
||||
fread(&sample, 1, sizeof(int16_t), fp);
|
||||
data_[i] = static_cast<float>(sample);
|
||||
break;
|
||||
}
|
||||
case 32: {
|
||||
int sample;
|
||||
fread(&sample, 1, sizeof(int), fp);
|
||||
data_[i] = static_cast<float>(sample);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
fprintf(stderr, "unsupported quantization bits");
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
fclose(fp);
|
||||
return true;
|
||||
}
|
||||
|
||||
int num_channel() const { return num_channel_; }
|
||||
int sample_rate() const { return sample_rate_; }
|
||||
int bits_per_sample() const { return bits_per_sample_; }
|
||||
int num_samples() const { return num_samples_; }
|
||||
|
||||
~WavReader() {
|
||||
if (data_ != NULL) delete[] data_;
|
||||
}
|
||||
|
||||
const float* data() const { return data_; }
|
||||
|
||||
private:
|
||||
int num_channel_;
|
||||
int sample_rate_;
|
||||
int bits_per_sample_;
|
||||
int num_samples_; // sample points per channel
|
||||
float* data_;
|
||||
};
|
||||
|
||||
class WavWriter {
|
||||
public:
|
||||
WavWriter(const float* data, int num_samples, int num_channel,
|
||||
int sample_rate, int bits_per_sample)
|
||||
: data_(data),
|
||||
num_samples_(num_samples),
|
||||
num_channel_(num_channel),
|
||||
sample_rate_(sample_rate),
|
||||
bits_per_sample_(bits_per_sample) {}
|
||||
|
||||
void Write(const std::string& filename) {
|
||||
FILE* fp = fopen(filename.c_str(), "w");
|
||||
// init char 'riff' 'WAVE' 'fmt ' 'data'
|
||||
WavHeader header;
|
||||
char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
|
||||
0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
|
||||
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
|
||||
memcpy(&header, wav_header, sizeof(header));
|
||||
header.channels = num_channel_;
|
||||
header.bit = bits_per_sample_;
|
||||
header.sample_rate = sample_rate_;
|
||||
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
|
||||
header.size = sizeof(header) - 8 + header.data_size;
|
||||
header.bytes_per_second =
|
||||
sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
|
||||
header.block_size = num_channel_ * (bits_per_sample_ / 8);
|
||||
|
||||
fwrite(&header, 1, sizeof(header), fp);
|
||||
|
||||
for (int i = 0; i < num_samples_; ++i) {
|
||||
for (int j = 0; j < num_channel_; ++j) {
|
||||
switch (bits_per_sample_) {
|
||||
case 8: {
|
||||
char sample = static_cast<char>(data_[i * num_channel_ + j]);
|
||||
fwrite(&sample, 1, sizeof(sample), fp);
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
|
||||
fwrite(&sample, 1, sizeof(sample), fp);
|
||||
break;
|
||||
}
|
||||
case 32: {
|
||||
int sample = static_cast<int>(data_[i * num_channel_ + j]);
|
||||
fwrite(&sample, 1, sizeof(sample), fp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fclose(fp);
|
||||
}
|
||||
|
||||
private:
|
||||
const float* data_;
|
||||
int num_samples_; // total float points in data_
|
||||
int num_channel_;
|
||||
int sample_rate_;
|
||||
int bits_per_sample_;
|
||||
};
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // FRONTEND_WAV_H_
|
||||
101
runtime/core/kws/keyword_spotting.cc
Normal file
101
runtime/core/kws/keyword_spotting.cc
Normal file
@ -0,0 +1,101 @@
|
||||
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include "kws/keyword_spotting.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace wekws {
|
||||
|
||||
// ONNX Runtime environment shared by every KeywordSpotting instance;
// logs at WARNING level with an empty log id.
Ort::Env KeywordSpotting::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "");
|
||||
// Session options shared by every KeywordSpotting instance; the thread counts
// are adjusted through KeywordSpotting::InitEngineThreads().
Ort::SessionOptions KeywordSpotting::session_options_ = Ort::SessionOptions();
|
||||
|
||||
// Loads the ONNX model at `model_path`, reads the cache geometry
// ("cache_dim" / "cache_len") from the model's custom metadata, prints it,
// and initializes the cache via Reset().
//
// Throws Ort::Exception if the session cannot be created or a required
// metadata key is missing; std::stoi may throw if a value is not an integer.
KeywordSpotting::KeywordSpotting(const std::string& model_path) {
  // 1. Load sessions
  session_ = std::make_shared<Ort::Session>(env_, model_path.c_str(),
                                            session_options_);
  // 2. Model info
  in_names_ = {"input", "cache"};
  out_names_ = {"output", "r_cache"};
  auto metadata = session_->GetModelMetadata();
  Ort::AllocatorWithDefaultOptions allocator;
  // LookupCustomMetadataMap returns a buffer allocated with `allocator` that
  // the caller must free, and a null pointer when the key is absent. Copy the
  // text out and release the buffer before parsing so nothing leaks even if
  // std::stoi throws.
  auto read_int_metadata = [&metadata, &allocator](const char* key) {
    char* raw = metadata.LookupCustomMetadataMap(key, allocator);
    if (raw == nullptr) {
      throw Ort::Exception(
          std::string("missing model metadata key: ") + key,
          ORT_INVALID_ARGUMENT);
    }
    const std::string text(raw);
    allocator.Free(raw);
    return std::stoi(text);
  };
  cache_dim_ = read_int_metadata("cache_dim");
  cache_len_ = read_int_metadata("cache_len");
  std::cout << "Kws Model Info:" << std::endl
            << "\tcache_dim: " << cache_dim_ << std::endl
            << "\tcache_len: " << cache_len_ << std::endl;
  Reset();
}
|
||||
|
||||
|
||||
// Rebuilds cache_ort_ as a {1, cache_dim_, cache_len_} float tensor that
// wraps the storage of cache_ (no copy is made, so cache_ must outlive the
// tensor's use). cache_ is resized to cache_dim_ * cache_len_ elements, with
// any newly added elements set to zero.
void KeywordSpotting::Reset() {
  cache_.resize(cache_dim_ * cache_len_, 0.0);
  const int64_t dims[] = {1, cache_dim_, cache_len_};
  Ort::MemoryInfo cpu_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  cache_ort_ = Ort::Value::CreateTensor<float>(cpu_info, cache_.data(),
                                               cache_.size(), dims, 3);
}
|
||||
|
||||
|
||||
void KeywordSpotting::Forward(
|
||||
const std::vector<std::vector<float>>& feats,
|
||||
std::vector<std::vector<float>>* prob) {
|
||||
prob->clear();
|
||||
if (feats.size() == 0) return;
|
||||
Ort::MemoryInfo memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
// 1. Prepare input
|
||||
int num_frames = feats.size();
|
||||
int feature_dim = feats[0].size();
|
||||
std::vector<float> slice_feats;
|
||||
for (int i = 0; i < feats.size(); i++) {
|
||||
slice_feats.insert(slice_feats.end(), feats[i].begin(), feats[i].end());
|
||||
}
|
||||
const int64_t feats_shape[3] = {1, num_frames, feature_dim};
|
||||
Ort::Value feats_ort = Ort::Value::CreateTensor<float>(
|
||||
memory_info, slice_feats.data(), slice_feats.size(), feats_shape, 3);
|
||||
// 2. Ort forward
|
||||
std::vector<Ort::Value> inputs;
|
||||
inputs.emplace_back(std::move(feats_ort));
|
||||
inputs.emplace_back(std::move(cache_ort_));
|
||||
// ort_outputs.size() == 2
|
||||
std::vector<Ort::Value> ort_outputs = session_->Run(
|
||||
Ort::RunOptions{nullptr}, in_names_.data(), inputs.data(),
|
||||
inputs.size(), out_names_.data(), out_names_.size());
|
||||
|
||||
// 3. Update cache
|
||||
cache_ort_ = std::move(ort_outputs[1]);
|
||||
|
||||
// 4. Get keyword prob
|
||||
float* data = ort_outputs[0].GetTensorMutableData<float>();
|
||||
auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
|
||||
int num_outputs = type_info.GetShape()[1];
|
||||
int output_dim = type_info.GetShape()[2];
|
||||
prob->resize(num_outputs);
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
(*prob)[i].resize(output_dim);
|
||||
memcpy((*prob)[i].data(), data + i * output_dim,
|
||||
sizeof(float) * output_dim);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace wekws
|
||||
61
runtime/core/kws/keyword_spotting.h
Normal file
61
runtime/core/kws/keyword_spotting.h
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#ifndef KWS_KEYWORD_SPOTTING_H_
|
||||
#define KWS_KEYWORD_SPOTTING_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace wekws {
|
||||
|
||||
class KeywordSpotting {
|
||||
public:
|
||||
explicit KeywordSpotting(const std::string& model_path);
|
||||
// Call reset if keyword is detected
|
||||
void Reset();
|
||||
|
||||
static void InitEngineThreads(int num_threads) {
|
||||
session_options_.SetIntraOpNumThreads(num_threads);
|
||||
session_options_.SetInterOpNumThreads(num_threads);
|
||||
}
|
||||
|
||||
void Forward(const std::vector<std::vector<float>>& feats,
|
||||
std::vector<std::vector<float>>* prob);
|
||||
|
||||
private:
|
||||
// session
|
||||
static Ort::Env env_;
|
||||
static Ort::SessionOptions session_options_;
|
||||
std::shared_ptr<Ort::Session> session_ = nullptr;
|
||||
// node names
|
||||
std::vector<const char*> in_names_;
|
||||
std::vector<const char*> out_names_;
|
||||
|
||||
// meta info
|
||||
int cache_dim_ = 0;
|
||||
int cache_len_ = 0;
|
||||
// cache info
|
||||
Ort::Value cache_ort_{nullptr};
|
||||
std::vector<float> cache_;
|
||||
};
|
||||
|
||||
|
||||
} // namespace wekws
|
||||
|
||||
#endif // KWS_KEYWORD_SPOTTING_H_
|
||||
98
runtime/core/utils/blocking_queue.h
Normal file
98
runtime/core/utils/blocking_queue.h
Normal file
@ -0,0 +1,98 @@
|
||||
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef UTILS_BLOCKING_QUEUE_H_
|
||||
#define UTILS_BLOCKING_QUEUE_H_
|
||||
|
||||
#include <condition_variable>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
|
||||
namespace wenet {
|
||||
|
||||
// Declares the copy constructor and copy assignment operator of `Type` as
// deleted. Use inside the class body; the expansion already ends with ';'.
#define WENET_DISALLOW_COPY_AND_ASSIGN(Type) \
|
||||
Type(const Type&) = delete; \
|
||||
Type& operator=(const Type&) = delete;
|
||||
|
||||
// Bounded FIFO queue guarded by a mutex and two condition variables.
//
// Push() blocks while the queue holds `capacity` elements; Pop() blocks while
// it is empty. All member functions may be called concurrently from multiple
// threads. The queue itself is neither copyable nor assignable.
template <typename T>
class BlockingQueue {
 public:
  // capacity: maximum number of queued elements before Push() blocks.
  explicit BlockingQueue(size_t capacity = std::numeric_limits<int>::max())
      : capacity_(capacity) {}

  BlockingQueue(const BlockingQueue&) = delete;
  BlockingQueue& operator=(const BlockingQueue&) = delete;

  // Appends a copy of `value`, waiting until there is room.
  void Push(const T& value) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      not_full_condition_.wait(
          lock, [this] { return queue_.size() < capacity_; });
      queue_.push(value);
    }
    // Notify after releasing the lock so the woken consumer can proceed
    // immediately.
    not_empty_condition_.notify_one();
  }

  // Appends `value` by move, waiting until there is room.
  void Push(T&& value) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      not_full_condition_.wait(
          lock, [this] { return queue_.size() < capacity_; });
      queue_.push(std::move(value));
    }
    not_empty_condition_.notify_one();
  }

  // Removes and returns the oldest element, waiting until one is available.
  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    not_empty_condition_.wait(lock, [this] { return !queue_.empty(); });
    T t(std::move(queue_.front()));
    queue_.pop();
    lock.unlock();
    not_full_condition_.notify_one();
    return t;
  }

  // Returns true if the queue held no elements at the time of the call.
  bool Empty() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.empty();
  }

  // Returns the number of queued elements at the time of the call.
  size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

  // Discards every queued element in one step under the lock. Unlike a
  // "while (!Empty()) Pop();" loop this cannot block when another consumer
  // drains the queue between the check and the pop. All producers waiting
  // for room are woken afterwards.
  void Clear() {
    std::queue<T> drained;  // destroyed after the lock is released
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.swap(drained);
    }
    not_full_condition_.notify_all();
  }

 private:
  size_t capacity_;
  mutable std::mutex mutex_;
  std::condition_variable not_full_condition_;
  std::condition_variable not_empty_condition_;
  std::queue<T> queue_;
};
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // UTILS_BLOCKING_QUEUE_H_
|
||||
83
runtime/core/utils/log.h
Normal file
83
runtime/core/utils/log.h
Normal file
@ -0,0 +1,83 @@
|
||||
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#ifndef UTILS_LOG_H_
|
||||
#define UTILS_LOG_H_
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
namespace wenet {
|
||||
|
||||
// Log severities, in increasing order of importance.
const int INFO = 0;
const int WARNING = 1;
const int ERROR = 2;
const int FATAL = 3;

// One-shot log message: text streamed into the object is buffered and written
// to std::cerr as a single line when the object is destroyed. A FATAL message
// (or one built with an unrecognized severity, which is treated as FATAL)
// calls abort() after being printed.
class Logger {
 public:
  // Starts the message with "<SEVERITY> (<func>():<file>:<line>) ".
  Logger(int severity, const char* func, const char* file, int line)
      : severity_(severity) {
    const char* label = "FATAL (";
    if (severity == INFO) {
      label = "INFO (";
    } else if (severity == WARNING) {
      label = "WARNING (";
    } else if (severity == ERROR) {
      label = "ERROR (";
    } else {
      // FATAL itself, or any value outside the known range.
      severity_ = FATAL;
    }
    ss_ << label;
    ss_ << func << "():" << file << ':' << line << ") ";
  }

  // Emits the buffered line and flushes; aborts the process for FATAL.
  ~Logger() {
    std::cerr << ss_.str() << std::endl;
    if (severity_ == FATAL) {
      abort();
    }
  }

  // Appends `val` to the pending message using its stream insertion operator.
  template <typename T>
  Logger& operator<<(const T& val) {
    ss_ << val;
    return *this;
  }

 private:
  int severity_;
  std::ostringstream ss_;
};
|
||||
|
||||
// Creates a temporary ::wenet::Logger for the given severity name (INFO,
// WARNING, ERROR or FATAL), tagged with the current function, file and line.
// Usage: LOG(INFO) << "message";
#define LOG(severity) ::wenet::Logger( \
|
||||
::wenet::severity, __func__, __FILE__, __LINE__)
|
||||
|
||||
// Evaluates `test`; if it is false, prints the function, file, line and the
// text of the failed expression to std::cerr and terminates via exit(-1).
// Wrapped in do { ... } while (0) so it behaves as a single statement.
#define CHECK(test) \
|
||||
do { \
|
||||
if (!(test)) { \
|
||||
std::cerr << "CHECK (" << __func__ << "():" << __FILE__ << ":" \
|
||||
<< __LINE__ << ") " << #test << std::endl; \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
} // namespace wenet
|
||||
|
||||
#endif // UTILS_LOG_H_
|
||||
31
runtime/onnxruntime/CMakeLists.txt
Normal file
31
runtime/onnxruntime/CMakeLists.txt
Normal file
@ -0,0 +1,31 @@
|
||||
# FetchContent_MakeAvailable() was introduced in CMake 3.14, so that is the
# lowest version this script can actually run on.
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

project(wekws VERSION 0.1)

set(CMAKE_VERBOSE_MAKEFILE on)

include(FetchContent)
include(ExternalProject)
set(FETCHCONTENT_QUIET OFF)
# Keep downloaded dependencies in <source dir>/fc_base rather than in the
# build tree.
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_base})

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -g -pthread")
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

# Prebuilt ONNX Runtime release (Linux x64 only).
FetchContent_Declare(onnxruntime
  URL      https://github.com/microsoft/onnxruntime/releases/download/v1.11.1/onnxruntime-linux-x64-1.11.1.tgz
  URL_HASH SHA256=ddc03b5ae325c675ff76a6f18786ce7d310be6eb6f320087f7a0e9228115f24d
)
FetchContent_MakeAvailable(onnxruntime)
include_directories(${onnxruntime_SOURCE_DIR}/include)
link_directories(${onnxruntime_SOURCE_DIR}/lib)

add_executable(kws_main
  bin/kws_main.cc
  kws/keyword_spotting.cc
  frontend/feature_pipeline.cc
  frontend/fft.cc
)
target_link_libraries(kws_main PUBLIC onnxruntime)
|
||||
9
runtime/onnxruntime/README.md
Normal file
9
runtime/onnxruntime/README.md
Normal file
@ -0,0 +1,9 @@
|
||||
## How to Build?
|
||||
|
||||
``` sh
|
||||
mkdir build && cd build && cmake .. && cmake --build .
|
||||
```
|
||||
|
||||
## How to Use?
|
||||
|
||||
Type `./build/kws_main --help` for usage.
|
||||
1
runtime/onnxruntime/bin
Symbolic link
1
runtime/onnxruntime/bin
Symbolic link
@ -0,0 +1 @@
|
||||
../core/bin
|
||||
1
runtime/onnxruntime/frontend
Symbolic link
1
runtime/onnxruntime/frontend
Symbolic link
@ -0,0 +1 @@
|
||||
../core/frontend
|
||||
1
runtime/onnxruntime/kws
Symbolic link
1
runtime/onnxruntime/kws
Symbolic link
@ -0,0 +1 @@
|
||||
../core/kws
|
||||
1
runtime/onnxruntime/utils
Symbolic link
1
runtime/onnxruntime/utils
Symbolic link
@ -0,0 +1 @@
|
||||
../core/utils
|
||||
Loading…
x
Reference in New Issue
Block a user