diff --git a/runtime/android/app/src/main/cpp/wekws.cc b/runtime/android/app/src/main/cpp/wekws.cc index d7fe24f..f50e622 100644 --- a/runtime/android/app/src/main/cpp/wekws.cc +++ b/runtime/android/app/src/main/cpp/wekws.cc @@ -1,119 +1,119 @@ -// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include - -#include -#include - -#include "frontend/feature_pipeline.h" -#include "kws/keyword_spotting.h" -#include "utils/log.h" - -namespace wekws { -std::shared_ptr spotter; -std::shared_ptr feature_config; -std::shared_ptr feature_pipeline; -std::string result; // NOLINT -int offset; - -void init(JNIEnv* env, jobject, jstring jModelDir) { - const char* pModelDir = env->GetStringUTFChars(jModelDir, nullptr); - - std::string modelPath = std::string(pModelDir) + "/wenwen.ort"; - spotter = std::make_shared(modelPath); - - feature_config = std::make_shared(40, 16000); - feature_pipeline = std::make_shared(*feature_config); -} - -void reset(JNIEnv *env, jobject) { - offset = 0; - result = ""; - spotter->Reset(); -} - -void accept_waveform(JNIEnv *env, jobject, jshortArray jWaveform) { - jsize size = env->GetArrayLength(jWaveform); - int16_t* waveform = env->GetShortArrayElements(jWaveform, 0); - std::vector v(waveform, waveform + size); - feature_pipeline->AcceptWaveform(v); - LOG(INFO) << "wekws accept waveform in ms: " << int(size / 16); -} - -void set_input_finished() { - LOG(INFO) << "wekws input finished"; - feature_pipeline->set_input_finished(); -} - -void spot_thread_func() { - while (true) { - std::vector> feats; - feature_pipeline->Read(80, &feats); - std::vector> prob; - spotter->Forward(feats, &prob); - - float max_hi_xiaowen = 0; - float max_nihao_wenwen = 0; - for (int t = 0; t < prob.size(); t++) { - max_hi_xiaowen = std::max(prob[t][0], max_hi_xiaowen); - max_nihao_wenwen = std::max(prob[t][1], max_nihao_wenwen); - } - float max_prob = max_hi_xiaowen + max_nihao_wenwen; - result = std::to_string(offset) + " prob: " + std::to_string(max_prob); - offset += prob.size(); - } -} - -void start_spot() { - std::thread decode_thread(spot_thread_func); - decode_thread.detach(); -} - -jstring get_result(JNIEnv *env, jobject) { - LOG(INFO) << "wekws ui result: " << result; - return env->NewStringUTF(result.c_str()); -} -} // namespace wekws - -JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) { - JNIEnv *env; - if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { - return JNI_ERR; - } - - jclass c = env->FindClass("cn/org/wenet/wekws/Spot"); - if (c == nullptr) { - return JNI_ERR; - } - - static const JNINativeMethod methods[] = { - {"init", "(Ljava/lang/String;)V", reinterpret_cast(wekws::init)}, - {"reset", "()V", reinterpret_cast(wekws::reset)}, - {"acceptWaveform", "([S)V", - reinterpret_cast(wekws::accept_waveform)}, - {"setInputFinished", "()V", - reinterpret_cast(wekws::set_input_finished)}, - {"startSpot", "()V", reinterpret_cast(wekws::start_spot)}, - {"getResult", "()Ljava/lang/String;", - reinterpret_cast(wekws::get_result)}, - }; - int rc = env->RegisterNatives(c, methods, - sizeof(methods) / sizeof(JNINativeMethod)); - - if (rc != JNI_OK) { - return rc; - } - - return JNI_VERSION_1_6; -} +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include +#include + +#include "frontend/feature_pipeline.h" +#include "kws/keyword_spotting.h" +#include "utils/log.h" + +namespace wekws { +std::shared_ptr spotter; +std::shared_ptr feature_config; +std::shared_ptr feature_pipeline; +std::string result; // NOLINT +int offset; + +void init(JNIEnv* env, jobject, jstring jModelDir) { + const char* pModelDir = env->GetStringUTFChars(jModelDir, nullptr); + + std::string modelPath = std::string(pModelDir) + "/kws.ort"; + spotter = std::make_shared(modelPath); + + feature_config = std::make_shared(40, 16000); + feature_pipeline = std::make_shared(*feature_config); +} + +void reset(JNIEnv* env, jobject) { + offset = 0; + result = ""; + spotter->Reset(); +} + +void accept_waveform(JNIEnv* env, jobject, jshortArray jWaveform) { + jsize size = env->GetArrayLength(jWaveform); + int16_t* waveform = env->GetShortArrayElements(jWaveform, 0); + std::vector v(waveform, waveform + size); + feature_pipeline->AcceptWaveform(v); + LOG(INFO) << "wekws accept waveform in ms: " << int(size / 16); +} + +void set_input_finished() { + LOG(INFO) << "wekws input finished"; + feature_pipeline->set_input_finished(); +} + +void spot_thread_func() { + while (true) { + std::vector> feats; + feature_pipeline->Read(80, &feats); + std::vector> prob; + spotter->Forward(feats, &prob); + + float max_hi_xiaowen = 0; + float max_nihao_wenwen = 0; + for (int t = 0; t < prob.size(); t++) { + max_hi_xiaowen = std::max(prob[t][0], max_hi_xiaowen); + max_nihao_wenwen = std::max(prob[t][1], max_nihao_wenwen); + } + float max_prob = max_hi_xiaowen + max_nihao_wenwen; + result = std::to_string(offset) + " prob: " + std::to_string(max_prob); + offset += prob.size(); + } +} + +void start_spot() { + std::thread decode_thread(spot_thread_func); + decode_thread.detach(); +} + +jstring get_result(JNIEnv* env, jobject) { + LOG(INFO) << "wekws ui result: " << result; + return env->NewStringUTF(result.c_str()); +} +} // namespace wekws + +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { + return JNI_ERR; + } + + jclass c = env->FindClass("cn/org/wenet/wekws/Spot"); + if (c == nullptr) { + return JNI_ERR; + } + + static const JNINativeMethod methods[] = { + {"init", "(Ljava/lang/String;)V", reinterpret_cast(wekws::init)}, + {"reset", "()V", reinterpret_cast(wekws::reset)}, + {"acceptWaveform", "([S)V", + reinterpret_cast(wekws::accept_waveform)}, + {"setInputFinished", "()V", + reinterpret_cast(wekws::set_input_finished)}, + {"startSpot", "()V", reinterpret_cast(wekws::start_spot)}, + {"getResult", "()Ljava/lang/String;", + reinterpret_cast(wekws::get_result)}, + }; + int rc = env->RegisterNatives(c, methods, + sizeof(methods) / sizeof(JNINativeMethod)); + + if (rc != JNI_OK) { + return rc; + } + + return JNI_VERSION_1_6; +} diff --git a/runtime/android/app/src/main/java/cn/org/wenet/wekws/MainActivity.java b/runtime/android/app/src/main/java/cn/org/wenet/wekws/MainActivity.java index e1b0d08..1653cf9 100644 --- a/runtime/android/app/src/main/java/cn/org/wenet/wekws/MainActivity.java +++ b/runtime/android/app/src/main/java/cn/org/wenet/wekws/MainActivity.java @@ -34,7 +34,7 @@ public class MainActivity extends AppCompatActivity { private static final String LOG_TAG = "WEKWS"; private static final int SAMPLE_RATE = 16000; // The sampling rate private static final int MAX_QUEUE_SIZE = 2500; // 100 seconds audio, 1 / 0.04 * 100 - private static final List resource = Arrays.asList("wenwen.ort"); + private static final List resource = Arrays.asList("kws.ort"); private boolean startRecord = false; private AudioRecord record = null; diff --git a/runtime/android/build.gradle b/runtime/android/build.gradle new file mode 100644 index 0000000..5ae9a7b --- /dev/null +++ b/runtime/android/build.gradle @@ -0,0 +1,9 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. +plugins { + id 'com.android.application' version '7.2.2' apply false + id 'com.android.library' version '7.2.2' apply false +} + +task clean(type: Delete) { + delete rootProject.buildDir +} \ No newline at end of file