diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..91dcbc0 --- /dev/null +++ b/.clang-format @@ -0,0 +1,93 @@ +--- +Language: Cpp +# BasedOnStyle: Google +AccessModifierOffset: -1 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: true +AllowShortLoopsOnASingleLine: true +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +IncludeCategories: + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Auto +TabWidth: 8 +UseTab: Never +... diff --git a/CPPLINT.cfg b/CPPLINT.cfg new file mode 100644 index 0000000..d3c8984 --- /dev/null +++ b/CPPLINT.cfg @@ -0,0 +1,2 @@ +root=runtime/core +filter=-build/c++11 diff --git a/runtime/core/bin/kws_main.cc b/runtime/core/bin/kws_main.cc new file mode 100644 index 0000000..298a139 --- /dev/null +++ b/runtime/core/bin/kws_main.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include + +#include "frontend/feature_pipeline.h" +#include "frontend/wav.h" +#include "kws/keyword_spotting.h" +#include "utils/log.h" + +int main(int argc, char* argv[]) { + if (argc != 5) { + LOG(FATAL) << "Usage: kws_main fbank_dim(int) batch_size(int) " + << "kws_model_path test_wav_path"; + } + + const int num_bins = std::stoi(argv[1]); // Fbank feature dim + const int batch_size = std::stoi(argv[2]); + const std::string model_path = argv[3]; + const std::string wav_path = argv[4]; + + wenet::WavReader wav_reader(wav_path); + int num_samples = wav_reader.num_samples(); + wenet::FeaturePipelineConfig feature_config(num_bins, 16000); + wenet::FeaturePipeline feature_pipeline(feature_config); + std::vector wav(wav_reader.data(), wav_reader.data() + num_samples); + feature_pipeline.AcceptWaveform(wav); + feature_pipeline.set_input_finished(); + + wekws::KeywordSpotting spotter(model_path); + + // Simulate streaming, detect batch by batch + int offset = 0; + while (true) { + std::vector> feats; + bool ok = feature_pipeline.Read(batch_size, &feats); + std::vector> prob; + spotter.Forward(feats, &prob); + for (int i = 0; i < prob.size(); i++) { + std::cout << "frame " << offset + i << " prob"; + for (int j = 0; j < prob[i].size(); j++) { + std::cout << " " << prob[i][0]; + } + std::cout << std::endl; + } + // Reach the end of feature pipeline + if (!ok) break; + offset += prob.size(); + } + return 0; +} diff --git a/runtime/core/frontend/fbank.h b/runtime/core/frontend/fbank.h new file mode 100644 index 0000000..5aafd77 --- /dev/null +++ b/runtime/core/frontend/fbank.h @@ -0,0 +1,222 @@ +// Copyright (c) 2017 Personal (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FRONTEND_FBANK_H_ +#define FRONTEND_FBANK_H_ + +#include +#include +#include +#include +#include + +#include "frontend/fft.h" +#include "utils/log.h" + +namespace wenet { + +// This code is based on kaldi Fbank implentation, please see +// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc +class Fbank { + public: + Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift) + : num_bins_(num_bins), + sample_rate_(sample_rate), + frame_length_(frame_length), + frame_shift_(frame_shift), + use_log_(true), + remove_dc_offset_(true), + generator_(0), + distribution_(0, 1.0), + dither_(0.0) { + fft_points_ = UpperPowerOfTwo(frame_length_); + // generate bit reversal table and trigonometric function table + const int fft_points_4 = fft_points_ / 4; + bitrev_.resize(fft_points_); + sintbl_.resize(fft_points_ + fft_points_4); + make_sintbl(fft_points_, sintbl_.data()); + make_bitrev(fft_points_, bitrev_.data()); + + int num_fft_bins = fft_points_ / 2; + float fft_bin_width = static_cast(sample_rate_) / fft_points_; + int low_freq = 20, high_freq = sample_rate_ / 2; + float mel_low_freq = MelScale(low_freq); + float mel_high_freq = MelScale(high_freq); + float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1); + bins_.resize(num_bins_); + center_freqs_.resize(num_bins_); + for (int bin = 0; bin < num_bins; ++bin) { + float left_mel = mel_low_freq + bin * mel_freq_delta, + center_mel = mel_low_freq + (bin + 1) * mel_freq_delta, + right_mel = mel_low_freq + (bin + 2) * mel_freq_delta; + center_freqs_[bin] = InverseMelScale(center_mel); + std::vector this_bin(num_fft_bins); + int first_index = -1, last_index = -1; + for (int i = 0; i < num_fft_bins; ++i) { + float freq = (fft_bin_width * i); // Center frequency of this fft + // bin. + float mel = MelScale(freq); + if (mel > left_mel && mel < right_mel) { + float weight; + if (mel <= center_mel) + weight = (mel - left_mel) / (center_mel - left_mel); + else + weight = (right_mel - mel) / (right_mel - center_mel); + this_bin[i] = weight; + if (first_index == -1) first_index = i; + last_index = i; + } + } + CHECK(first_index != -1 && last_index >= first_index); + bins_[bin].first = first_index; + int size = last_index + 1 - first_index; + bins_[bin].second.resize(size); + for (int i = 0; i < size; ++i) { + bins_[bin].second[i] = this_bin[first_index + i]; + } + } + + // NOTE(cdliang): add hamming window + hamming_window_.resize(frame_length_); + double a = M_2PI / (frame_length - 1); + for (int i = 0; i < frame_length; i++) { + double i_fl = static_cast(i); + hamming_window_[i] = 0.54 - 0.46 * cos(a * i_fl); + } + } + + void set_use_log(bool use_log) { use_log_ = use_log; } + + void set_remove_dc_offset(bool remove_dc_offset) { + remove_dc_offset_ = remove_dc_offset; + } + + void set_dither(float dither) { dither_ = dither; } + + int num_bins() const { return num_bins_; } + + static inline float InverseMelScale(float mel_freq) { + return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f); + } + + static inline float MelScale(float freq) { + return 1127.0f * logf(1.0f + freq / 700.0f); + } + + static int UpperPowerOfTwo(int n) { + return static_cast(pow(2, ceil(log(n) / log(2)))); + } + + // preemphasis + void PreEmphasis(float coeff, std::vector* data) const { + if (coeff == 0.0) return; + for (int i = data->size() - 1; i > 0; i--) + (*data)[i] -= coeff * (*data)[i - 1]; + (*data)[0] -= coeff * (*data)[0]; + } + + // add hamming window + void Hamming(std::vector* data) const { + CHECK(data->size() >= hamming_window_.size()); + for (size_t i = 0; i < hamming_window_.size(); ++i) { + (*data)[i] *= hamming_window_[i]; + } + } + + // Compute fbank feat, return num frames + int Compute(const std::vector& wave, + std::vector>* feat) { + int num_samples = wave.size(); + if (num_samples < frame_length_) return 0; + int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_); + feat->resize(num_frames); + std::vector fft_real(fft_points_, 0), fft_img(fft_points_, 0); + std::vector power(fft_points_ / 2); + for (int i = 0; i < num_frames; ++i) { + std::vector data(wave.data() + i * frame_shift_, + wave.data() + i * frame_shift_ + frame_length_); + // optional add noise + if (dither_ != 0.0) { + for (size_t j = 0; j < data.size(); ++j) + data[j] += dither_ * distribution_(generator_); + } + // optinal remove dc offset + if (remove_dc_offset_) { + float mean = 0.0; + for (size_t j = 0; j < data.size(); ++j) mean += data[j]; + mean /= data.size(); + for (size_t j = 0; j < data.size(); ++j) data[j] -= mean; + } + + PreEmphasis(0.97, &data); + // Povey(&data); + Hamming(&data); + // copy data to fft_real + memset(fft_img.data(), 0, sizeof(float) * fft_points_); + memset(fft_real.data() + frame_length_, 0, + sizeof(float) * (fft_points_ - frame_length_)); + memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_); + fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(), + fft_points_); + // power + for (int j = 0; j < fft_points_ / 2; ++j) { + power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j]; + } + + (*feat)[i].resize(num_bins_); + // cepstral coefficients, triangle filter array + for (int j = 0; j < num_bins_; ++j) { + float mel_energy = 0.0; + int s = bins_[j].first; + for (size_t k = 0; k < bins_[j].second.size(); ++k) { + mel_energy += bins_[j].second[k] * power[s + k]; + } + // optional use log + if (use_log_) { + if (mel_energy < std::numeric_limits::epsilon()) + mel_energy = std::numeric_limits::epsilon(); + mel_energy = logf(mel_energy); + } + + (*feat)[i][j] = mel_energy; + // printf("%f ", mel_energy); + } + // printf("\n"); + } + return num_frames; + } + + private: + int num_bins_; + int sample_rate_; + int frame_length_, frame_shift_; + int fft_points_; + bool use_log_; + bool remove_dc_offset_; + std::vector center_freqs_; + std::vector>> bins_; + std::vector hamming_window_; + std::default_random_engine generator_; + std::normal_distribution distribution_; + float dither_; + + // bit reversal table + std::vector bitrev_; + // trigonometric function table + std::vector sintbl_; +}; + +} // namespace wenet + +#endif // FRONTEND_FBANK_H_ diff --git a/runtime/core/frontend/feature_pipeline.cc b/runtime/core/frontend/feature_pipeline.cc new file mode 100644 index 0000000..c59d6f7 --- /dev/null +++ b/runtime/core/frontend/feature_pipeline.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2017 Personal (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "frontend/feature_pipeline.h" + +#include +#include + +namespace wenet { + +FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config) + : config_(config), + feature_dim_(config.num_bins), + fbank_(config.num_bins, config.sample_rate, config.frame_length, + config.frame_shift), + num_frames_(0), + input_finished_(false) {} + +void FeaturePipeline::AcceptWaveform(const std::vector& wav) { + std::vector> feats; + std::vector waves; + waves.insert(waves.end(), remained_wav_.begin(), remained_wav_.end()); + waves.insert(waves.end(), wav.begin(), wav.end()); + int num_frames = fbank_.Compute(waves, &feats); + for (size_t i = 0; i < feats.size(); ++i) { + feature_queue_.Push(std::move(feats[i])); + } + num_frames_ += num_frames; + + int left_samples = waves.size() - config_.frame_shift * num_frames; + remained_wav_.resize(left_samples); + std::copy(waves.begin() + config_.frame_shift * num_frames, waves.end(), + remained_wav_.begin()); + // We are still adding wave, notify input is not finished + finish_condition_.notify_one(); +} + +void FeaturePipeline::AcceptWaveform(const std::vector& wav) { + std::vector float_wav(wav.size()); + for (size_t i = 0; i < wav.size(); i++) { + float_wav[i] = static_cast(wav[i]); + } + this->AcceptWaveform(float_wav); +} + +void FeaturePipeline::set_input_finished() { + CHECK(!input_finished_); + { + std::lock_guard lock(mutex_); + input_finished_ = true; + } + finish_condition_.notify_one(); +} + +bool FeaturePipeline::ReadOne(std::vector* feat) { + if (!feature_queue_.Empty()) { + *feat = std::move(feature_queue_.Pop()); + return true; + } else { + std::unique_lock lock(mutex_); + while (!input_finished_) { + // This will release the lock and wait for notify_one() + // from AcceptWaveform() or set_input_finished() + finish_condition_.wait(lock); + if (!feature_queue_.Empty()) { + *feat = std::move(feature_queue_.Pop()); + return true; + } + } + CHECK(input_finished_); + // Double check queue.empty, see issue#893 for detailed discussions. + if (!feature_queue_.Empty()) { + *feat = std::move(feature_queue_.Pop()); + return true; + } else { + return false; + } + } +} + +bool FeaturePipeline::Read(int num_frames, + std::vector>* feats) { + feats->clear(); + std::vector feat; + while (feats->size() < num_frames) { + if (ReadOne(&feat)) { + feats->push_back(std::move(feat)); + } else { + return false; + } + } + return true; +} + +void FeaturePipeline::Reset() { + input_finished_ = false; + num_frames_ = 0; + remained_wav_.clear(); + feature_queue_.Clear(); +} + +} // namespace wenet diff --git a/runtime/core/frontend/feature_pipeline.h b/runtime/core/frontend/feature_pipeline.h new file mode 100644 index 0000000..3fdafa6 --- /dev/null +++ b/runtime/core/frontend/feature_pipeline.h @@ -0,0 +1,118 @@ +// Copyright (c) 2017 Personal (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FRONTEND_FEATURE_PIPELINE_H_ +#define FRONTEND_FEATURE_PIPELINE_H_ + +#include +#include +#include +#include + +#include "frontend/fbank.h" +#include "utils/log.h" +#include "utils/blocking_queue.h" + +namespace wenet { + +struct FeaturePipelineConfig { + int num_bins; + int sample_rate; + int frame_length; + int frame_shift; + FeaturePipelineConfig(int num_bins, int sample_rate) + : num_bins(num_bins), // 80 dim fbank + sample_rate(sample_rate) { // 16k sample rate + frame_length = sample_rate / 1000 * 25; // frame length 25ms + frame_shift = sample_rate / 1000 * 10; // frame shift 10ms + } + + void Info() const { + LOG(INFO) << "feature pipeline config" + << " num_bins " << num_bins << " frame_length " << frame_length + << " frame_shift " << frame_shift; + } +}; + +// Typically, FeaturePipeline is used in two threads: one thread A calls +// AcceptWaveform() to add raw wav data and set_input_finished() to notice +// the end of input wav, another thread B (decoder thread) calls Read() to +// consume features.So a BlockingQueue is used to make this class thread safe. + +// The Read() is designed as a blocking method when there is no feature +// in feature_queue_ and the input is not finished. + +class FeaturePipeline { + public: + explicit FeaturePipeline(const FeaturePipelineConfig& config); + + // The feature extraction is done in AcceptWaveform(). + void AcceptWaveform(const std::vector& wav); + void AcceptWaveform(const std::vector& wav); + + // Current extracted frames number. + int num_frames() const { return num_frames_; } + int feature_dim() const { return feature_dim_; } + const FeaturePipelineConfig& config() const { return config_; } + + // The caller should call this method when speech input is end. + // Never call AcceptWaveform() after calling set_input_finished() ! + void set_input_finished(); + bool input_finished() const { return input_finished_; } + + // Return False if input is finished and no feature could be read. + // Return True if a feature is read. + // This function is a blocking method. It will block the thread when + // there is no feature in feature_queue_ and the input is not finished. + bool ReadOne(std::vector* feat); + + // Read #num_frames frame features. + // Return False if less then #num_frames features are read and the + // input is finished. + // Return True if #num_frames features are read. + // This function is a blocking method when there is no feature + // in feature_queue_ and the input is not finished. + bool Read(int num_frames, std::vector>* feats); + + void Reset(); + bool IsLastFrame(int frame) const { + return input_finished_ && (frame == num_frames_ - 1); + } + + int NumQueuedFrames() const { return feature_queue_.Size(); } + + private: + const FeaturePipelineConfig& config_; + int feature_dim_; + Fbank fbank_; + + BlockingQueue> feature_queue_; + int num_frames_; + bool input_finished_; + + // The feature extraction is done in AcceptWaveform(). + // This wavefrom sample points are consumed by frame size. + // The residual wavefrom sample points after framing are + // kept to be used in next AcceptWaveform() calling. + std::vector remained_wav_; + + // Used to block the Read when there is no feature in feature_queue_ + // and the input is not finished. + mutable std::mutex mutex_; + std::condition_variable finish_condition_; +}; + +} // namespace wenet + +#endif // FRONTEND_FEATURE_PIPELINE_H_ diff --git a/runtime/core/frontend/fft.cc b/runtime/core/frontend/fft.cc new file mode 100644 index 0000000..d293203 --- /dev/null +++ b/runtime/core/frontend/fft.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2016 HR + +#include +#include +#include + +#include "frontend/fft.h" + +namespace wenet { + +void make_sintbl(int n, float* sintbl) { + int i, n2, n4, n8; + float c, s, dc, ds, t; + + n2 = n / 2; + n4 = n / 4; + n8 = n / 8; + t = sin(M_PI / n); + dc = 2 * t * t; + ds = sqrt(dc * (2 - dc)); + t = 2 * dc; + c = sintbl[n4] = 1; + s = sintbl[0] = 0; + for (i = 1; i < n8; ++i) { + c -= dc; + dc += t * c; + s += ds; + ds -= t * s; + sintbl[i] = s; + sintbl[n4 - i] = c; + } + if (n8 != 0) sintbl[n8] = sqrt(0.5); + for (i = 0; i < n4; ++i) sintbl[n2 - i] = sintbl[i]; + for (i = 0; i < n2 + n4; ++i) sintbl[i + n2] = -sintbl[i]; +} + +void make_bitrev(int n, int* bitrev) { + int i, j, k, n2; + + n2 = n / 2; + i = j = 0; + for (;;) { + bitrev[i] = j; + if (++i >= n) break; + k = n2; + while (k <= j) { + j -= k; + k /= 2; + } + j += k; + } +} + +// bitrev: bit reversal table +// sintbl: trigonometric function table +// x:real part +// y:image part +// n: fft length +int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n) { + int i, j, k, ik, h, d, k2, n4, inverse; + float t, s, c, dx, dy; + + /* preparation */ + if (n < 0) { + n = -n; + inverse = 1; /* inverse transform */ + } else { + inverse = 0; + } + n4 = n / 4; + if (n == 0) { + return 0; + } + + /* bit reversal */ + for (i = 0; i < n; ++i) { + j = bitrev[i]; + if (i < j) { + t = x[i]; + x[i] = x[j]; + x[j] = t; + t = y[i]; + y[i] = y[j]; + y[j] = t; + } + } + + /* transformation */ + for (k = 1; k < n; k = k2) { + h = 0; + k2 = k + k; + d = n / k2; + for (j = 0; j < k; ++j) { + c = sintbl[h + n4]; + if (inverse) + s = -sintbl[h]; + else + s = sintbl[h]; + for (i = j; i < n; i += k2) { + ik = i + k; + dx = s * y[ik] + c * x[ik]; + dy = c * y[ik] - s * x[ik]; + x[ik] = x[i] - dx; + x[i] += dx; + y[ik] = y[i] - dy; + y[i] += dy; + } + h += d; + } + } + if (inverse) { + /* divide by n in case of the inverse transformation */ + for (i = 0; i < n; ++i) { + x[i] /= n; + y[i] /= n; + } + } + return 0; /* finished successfully */ +} + +} // namespace wenet diff --git a/runtime/core/frontend/fft.h b/runtime/core/frontend/fft.h new file mode 100644 index 0000000..5015311 --- /dev/null +++ b/runtime/core/frontend/fft.h @@ -0,0 +1,25 @@ +// Copyright (c) 2016 HR + +#ifndef FRONTEND_FFT_H_ +#define FRONTEND_FFT_H_ + +#ifndef M_PI +#define M_PI 3.1415926535897932384626433832795 +#endif +#ifndef M_2PI +#define M_2PI 6.283185307179586476925286766559005 +#endif + +namespace wenet { + +// Fast Fourier Transform + +void make_sintbl(int n, float* sintbl); + +void make_bitrev(int n, int* bitrev); + +int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n); + +} // namespace wenet + +#endif // FRONTEND_FFT_H_ diff --git a/runtime/core/frontend/wav.h b/runtime/core/frontend/wav.h new file mode 100644 index 0000000..791a128 --- /dev/null +++ b/runtime/core/frontend/wav.h @@ -0,0 +1,203 @@ +// Copyright (c) 2016 Personal (Binbin Zhang) +// Created on 2016-08-15 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FRONTEND_WAV_H_ +#define FRONTEND_WAV_H_ + +#include +#include +#include +#include +#include + +#include + +#include "utils/log.h" + +namespace wenet { + +struct WavHeader { + char riff[4]; // "riff" + unsigned int size; + char wav[4]; // "WAVE" + char fmt[4]; // "fmt " + unsigned int fmt_size; + uint16_t format; + uint16_t channels; + unsigned int sample_rate; + unsigned int bytes_per_second; + uint16_t block_size; + uint16_t bit; + char data[4]; // "data" + unsigned int data_size; +}; + +class WavReader { + public: + WavReader() : data_(nullptr) {} + explicit WavReader(const std::string& filename) { Open(filename); } + + bool Open(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "rb"); + if (NULL == fp) { + LOG(WARNING) << "Error in read " << filename; + return false; + } + + WavHeader header; + fread(&header, 1, sizeof(header), fp); + if (header.fmt_size < 16) { + fprintf(stderr, + "WaveData: expect PCM format data " + "to have fmt chunk of at least size 16.\n"); + return false; + } else if (header.fmt_size > 16) { + int offset = 44 - 8 + header.fmt_size - 16; + fseek(fp, offset, SEEK_SET); + fread(header.data, 8, sizeof(char), fp); + } + // check "riff" "WAVE" "fmt " "data" + + // Skip any subchunks between "fmt" and "data". Usually there will + // be a single "fact" subchunk, but on Windows there can also be a + // "list" subchunk. + while (0 != strncmp(header.data, "data", 4)) { + // We will just ignore the data in these chunks. + fseek(fp, header.data_size, SEEK_CUR); + // read next subchunk + fread(header.data, 8, sizeof(char), fp); + } + + num_channel_ = header.channels; + sample_rate_ = header.sample_rate; + bits_per_sample_ = header.bit; + int num_data = header.data_size / (bits_per_sample_ / 8); + data_ = new float[num_data]; + num_samples_ = num_data / num_channel_; + + for (int i = 0; i < num_data; ++i) { + switch (bits_per_sample_) { + case 8: { + char sample; + fread(&sample, 1, sizeof(char), fp); + data_[i] = static_cast(sample); + break; + } + case 16: { + int16_t sample; + fread(&sample, 1, sizeof(int16_t), fp); + data_[i] = static_cast(sample); + break; + } + case 32: { + int sample; + fread(&sample, 1, sizeof(int), fp); + data_[i] = static_cast(sample); + break; + } + default: + fprintf(stderr, "unsupported quantization bits"); + exit(1); + } + } + fclose(fp); + return true; + } + + int num_channel() const { return num_channel_; } + int sample_rate() const { return sample_rate_; } + int bits_per_sample() const { return bits_per_sample_; } + int num_samples() const { return num_samples_; } + + ~WavReader() { + if (data_ != NULL) delete[] data_; + } + + const float* data() const { return data_; } + + private: + int num_channel_; + int sample_rate_; + int bits_per_sample_; + int num_samples_; // sample points per channel + float* data_; +}; + +class WavWriter { + public: + WavWriter(const float* data, int num_samples, int num_channel, + int sample_rate, int bits_per_sample) + : data_(data), + num_samples_(num_samples), + num_channel_(num_channel), + sample_rate_(sample_rate), + bits_per_sample_(bits_per_sample) {} + + void Write(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "w"); + // init char 'riff' 'WAVE' 'fmt ' 'data' + WavHeader header; + char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, + 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, + 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00}; + memcpy(&header, wav_header, sizeof(header)); + header.channels = num_channel_; + header.bit = bits_per_sample_; + header.sample_rate = sample_rate_; + header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8); + header.size = sizeof(header) - 8 + header.data_size; + header.bytes_per_second = + sample_rate_ * num_channel_ * (bits_per_sample_ / 8); + header.block_size = num_channel_ * (bits_per_sample_ / 8); + + fwrite(&header, 1, sizeof(header), fp); + + for (int i = 0; i < num_samples_; ++i) { + for (int j = 0; j < num_channel_; ++j) { + switch (bits_per_sample_) { + case 8: { + char sample = static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 16: { + int16_t sample = static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 32: { + int sample = static_cast(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + } + } + } + fclose(fp); + } + + private: + const float* data_; + int num_samples_; // total float points in data_ + int num_channel_; + int sample_rate_; + int bits_per_sample_; +}; + +} // namespace wenet + +#endif // FRONTEND_WAV_H_ diff --git a/runtime/core/kws/keyword_spotting.cc b/runtime/core/kws/keyword_spotting.cc new file mode 100644 index 0000000..cf1df8b --- /dev/null +++ b/runtime/core/kws/keyword_spotting.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "kws/keyword_spotting.h" + +#include +#include +#include +#include +#include + +namespace wekws { + +Ort::Env KeywordSpotting::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, ""); +Ort::SessionOptions KeywordSpotting::session_options_ = Ort::SessionOptions(); + +KeywordSpotting::KeywordSpotting(const std::string& model_path) { + // 1. Load sessions + session_ = std::make_shared(env_, model_path.c_str(), + session_options_); + // 2. Model info + in_names_ = {"input", "cache"}; + out_names_ = {"output", "r_cache"}; + auto metadata = session_->GetModelMetadata(); + Ort::AllocatorWithDefaultOptions allocator; + cache_dim_ = std::stoi(metadata.LookupCustomMetadataMap("cache_dim", + allocator)); + cache_len_ = std::stoi(metadata.LookupCustomMetadataMap("cache_len", + allocator)); + std::cout << "Kws Model Info:" << std::endl + << "\tcache_dim: " << cache_dim_ << std::endl + << "\tcache_len: " << cache_len_ << std::endl; + Reset(); +} + + +void KeywordSpotting::Reset() { + Ort::MemoryInfo memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + cache_.resize(cache_dim_ * cache_len_, 0.0); + const int64_t cache_shape[] = {1, cache_dim_, cache_len_}; + cache_ort_ = Ort::Value::CreateTensor( + memory_info, cache_.data(), cache_.size(), cache_shape, 3); +} + + +void KeywordSpotting::Forward( + const std::vector>& feats, + std::vector>* prob) { + prob->clear(); + if (feats.size() == 0) return; + Ort::MemoryInfo memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + // 1. Prepare input + int num_frames = feats.size(); + int feature_dim = feats[0].size(); + std::vector slice_feats; + for (int i = 0; i < feats.size(); i++) { + slice_feats.insert(slice_feats.end(), feats[i].begin(), feats[i].end()); + } + const int64_t feats_shape[3] = {1, num_frames, feature_dim}; + Ort::Value feats_ort = Ort::Value::CreateTensor( + memory_info, slice_feats.data(), slice_feats.size(), feats_shape, 3); + // 2. Ort forward + std::vector inputs; + inputs.emplace_back(std::move(feats_ort)); + inputs.emplace_back(std::move(cache_ort_)); + // ort_outputs.size() == 2 + std::vector ort_outputs = session_->Run( + Ort::RunOptions{nullptr}, in_names_.data(), inputs.data(), + inputs.size(), out_names_.data(), out_names_.size()); + + // 3. Update cache + cache_ort_ = std::move(ort_outputs[1]); + + // 4. Get keyword prob + float* data = ort_outputs[0].GetTensorMutableData(); + auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo(); + int num_outputs = type_info.GetShape()[1]; + int output_dim = type_info.GetShape()[2]; + prob->resize(num_outputs); + for (int i = 0; i < num_outputs; i++) { + (*prob)[i].resize(output_dim); + memcpy((*prob)[i].data(), data + i * output_dim, + sizeof(float) * output_dim); + } +} + +} // namespace wekws diff --git a/runtime/core/kws/keyword_spotting.h b/runtime/core/kws/keyword_spotting.h new file mode 100644 index 0000000..14bf732 --- /dev/null +++ b/runtime/core/kws/keyword_spotting.h @@ -0,0 +1,61 @@ +// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KWS_KEYWORD_SPOTTING_H_ +#define KWS_KEYWORD_SPOTTING_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace wekws { + +class KeywordSpotting { + public: + explicit KeywordSpotting(const std::string& model_path); + // Call reset if keyword is detected + void Reset(); + + static void InitEngineThreads(int num_threads) { + session_options_.SetIntraOpNumThreads(num_threads); + session_options_.SetInterOpNumThreads(num_threads); + } + + void Forward(const std::vector>& feats, + std::vector>* prob); + + private: + // session + static Ort::Env env_; + static Ort::SessionOptions session_options_; + std::shared_ptr session_ = nullptr; + // node names + std::vector in_names_; + std::vector out_names_; + + // meta info + int cache_dim_ = 0; + int cache_len_ = 0; + // cache info + Ort::Value cache_ort_{nullptr}; + std::vector cache_; +}; + + +} // namespace wekws + +#endif // KWS_KEYWORD_SPOTTING_H_ diff --git a/runtime/core/utils/blocking_queue.h b/runtime/core/utils/blocking_queue.h new file mode 100644 index 0000000..b1748f6 --- /dev/null +++ b/runtime/core/utils/blocking_queue.h @@ -0,0 +1,98 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_BLOCKING_QUEUE_H_ +#define UTILS_BLOCKING_QUEUE_H_ + +#include +#include +#include +#include +#include + +namespace wenet { + +#define WENET_DISALLOW_COPY_AND_ASSIGN(Type) \ + Type(const Type&) = delete; \ + Type& operator=(const Type&) = delete; + +template +class BlockingQueue { + public: + explicit BlockingQueue(size_t capacity = std::numeric_limits::max()) + : capacity_(capacity) {} + + void Push(const T& value) { + { + std::unique_lock lock(mutex_); + while (queue_.size() >= capacity_) { + not_full_condition_.wait(lock); + } + queue_.push(value); + } + not_empty_condition_.notify_one(); + } + + void Push(T&& value) { + { + std::unique_lock lock(mutex_); + while (queue_.size() >= capacity_) { + not_full_condition_.wait(lock); + } + queue_.push(std::move(value)); + } + not_empty_condition_.notify_one(); + } + + T Pop() { + std::unique_lock lock(mutex_); + while (queue_.empty()) { + not_empty_condition_.wait(lock); + } + T t(std::move(queue_.front())); + queue_.pop(); + not_full_condition_.notify_one(); + return t; + } + + bool Empty() const { + std::lock_guard lock(mutex_); + return queue_.empty(); + } + + size_t Size() const { + std::lock_guard lock(mutex_); + return queue_.size(); + } + + void Clear() { + while (!Empty()) { + Pop(); + } + } + + private: + size_t capacity_; + mutable std::mutex mutex_; + std::condition_variable not_full_condition_; + std::condition_variable not_empty_condition_; + std::queue queue_; + + public: + WENET_DISALLOW_COPY_AND_ASSIGN(BlockingQueue); +}; + +} // namespace wenet + +#endif // UTILS_BLOCKING_QUEUE_H_ diff --git a/runtime/core/utils/log.h b/runtime/core/utils/log.h new file mode 100644 index 0000000..9d7601c --- /dev/null +++ b/runtime/core/utils/log.h @@ -0,0 +1,83 @@ +// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef UTILS_LOG_H_ +#define UTILS_LOG_H_ + +#include + +#include +#include + +namespace wenet { + +const int INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3; + +class Logger { + public: + Logger(int severity, const char* func, const char* file, int line) { + severity_ = severity; + switch (severity) { + case INFO: + ss_ << "INFO ("; + break; + case WARNING: + ss_ << "WARNING ("; + break; + case ERROR: + ss_ << "ERROR ("; + break; + case FATAL: + ss_ << "FATAL ("; + break; + default: + severity_ = FATAL; + ss_ << "FATAL ("; + } + ss_ << func << "():" << file << ':' << line << ") "; + } + + ~Logger() { + std::cerr << ss_.str() << std::endl << std::flush; + if (severity_ == FATAL) { + abort(); + } + } + + template Logger& operator<<(const T &val) { + ss_ << val; + return *this; + } + + private: + int severity_; + std::ostringstream ss_; +}; + +#define LOG(severity) ::wenet::Logger( \ + ::wenet::severity, __func__, __FILE__, __LINE__) + +#define CHECK(test) \ +do { \ + if (!(test)) { \ + std::cerr << "CHECK (" << __func__ << "():" << __FILE__ << ":" \ + << __LINE__ << ") " << #test << std::endl; \ + exit(-1); \ + } \ +} while (0) + +} // namespace wenet + +#endif // UTILS_LOG_H_ diff --git a/runtime/onnxruntime/CMakeLists.txt b/runtime/onnxruntime/CMakeLists.txt new file mode 100644 index 0000000..4964dcd --- /dev/null +++ b/runtime/onnxruntime/CMakeLists.txt @@ -0,0 +1,31 @@ +cmake_minimum_required(VERSION 3.13 FATAL_ERROR) + +project(wekws VERSION 0.1) + +set(CMAKE_VERBOSE_MAKEFILE on) + +include(FetchContent) +include(ExternalProject) +set(FETCHCONTENT_QUIET OFF) +get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") +set(FETCHCONTENT_BASE_DIR ${fc_base}) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -g -pthread") +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +FetchContent_Declare(onnxruntime + URL https://github.com/microsoft/onnxruntime/releases/download/v1.11.1/onnxruntime-linux-x64-1.11.1.tgz + URL_HASH SHA256=ddc03b5ae325c675ff76a6f18786ce7d310be6eb6f320087f7a0e9228115f24d +) +FetchContent_MakeAvailable(onnxruntime) +include_directories(${onnxruntime_SOURCE_DIR}/include) +link_directories(${onnxruntime_SOURCE_DIR}/lib) + + +add_executable(kws_main + bin/kws_main.cc + kws/keyword_spotting.cc + frontend/feature_pipeline.cc + frontend/fft.cc +) +target_link_libraries(kws_main PUBLIC onnxruntime) diff --git a/runtime/onnxruntime/README.md b/runtime/onnxruntime/README.md new file mode 100644 index 0000000..9cd2c18 --- /dev/null +++ b/runtime/onnxruntime/README.md @@ -0,0 +1,9 @@ +## How to Build? + +``` sh +mkdir build && cd build && cmake .. && cmake --build . +``` + +## How to Use? + +Type `./build/kws_main --help` for usage. diff --git a/runtime/onnxruntime/bin b/runtime/onnxruntime/bin new file mode 120000 index 0000000..938df72 --- /dev/null +++ b/runtime/onnxruntime/bin @@ -0,0 +1 @@ +../core/bin \ No newline at end of file diff --git a/runtime/onnxruntime/frontend b/runtime/onnxruntime/frontend new file mode 120000 index 0000000..0292335 --- /dev/null +++ b/runtime/onnxruntime/frontend @@ -0,0 +1 @@ +../core/frontend \ No newline at end of file diff --git a/runtime/onnxruntime/kws b/runtime/onnxruntime/kws new file mode 120000 index 0000000..2070d5a --- /dev/null +++ b/runtime/onnxruntime/kws @@ -0,0 +1 @@ +../core/kws \ No newline at end of file diff --git a/runtime/onnxruntime/utils b/runtime/onnxruntime/utils new file mode 120000 index 0000000..9e19e7a --- /dev/null +++ b/runtime/onnxruntime/utils @@ -0,0 +1 @@ +../core/utils \ No newline at end of file