From d91cc95edb904b591efa073cb6a7fb932ec4d35c Mon Sep 17 00:00:00 2001
From: pengzhendong <275331498@qq.com>
Date: Sun, 4 Sep 2022 18:13:36 +0800
Subject: [PATCH] [android] finished

---
 .../android/app/src/main/cpp/CMakeLists.txt   |  19 ++
 runtime/android/app/src/main/cpp/wekws.cc     | 119 +++++++++++
 .../java/cn/org/wenet/wekws/MainActivity.java | 198 ++++++++++++++++++
 .../main/java/cn/org/wenet/wekws/Spot.java    |  15 ++
 .../app/src/main/res/layout/activity_main.xml |  38 +++-
 runtime/core/bin/kws_main.cc                  |   8 +-
 6 files changed, 389 insertions(+), 8 deletions(-)
 create mode 100644 runtime/android/app/src/main/cpp/wekws.cc
 create mode 100644 runtime/android/app/src/main/java/cn/org/wenet/wekws/Spot.java

diff --git a/runtime/android/app/src/main/cpp/CMakeLists.txt b/runtime/android/app/src/main/cpp/CMakeLists.txt
index e69de29..589d530 100644
--- a/runtime/android/app/src/main/cpp/CMakeLists.txt
+++ b/runtime/android/app/src/main/cpp/CMakeLists.txt
@@ -0,0 +1,19 @@
+cmake_minimum_required(VERSION 3.4.1)
+project(wekws CXX)
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_VERBOSE_MAKEFILE on)  # echo the full compile/link commands during the build
+
+set(build_DIR ${CMAKE_SOURCE_DIR}/../../../build)  # "build" directory three levels above this one
+file(GLOB ONNXRUNTIME_INCLUDE_DIRS "${build_DIR}/onnxruntime*.aar/headers")  # headers under any onnxruntime*.aar directory (presumably an unpacked AAR - confirm)
+file(GLOB ONNXRUNTIME_LINK_DIRS "${build_DIR}/onnxruntime*.aar/jni/${ANDROID_ABI}")  # libraries for the ABI currently being built
+link_directories(${ONNXRUNTIME_LINK_DIRS})
+include_directories(${ONNXRUNTIME_INCLUDE_DIRS})
+
+include_directories(${CMAKE_SOURCE_DIR})  # sources include "frontend/..." and "kws/..." relative to this directory
+add_library(wekws SHARED  # shared library built from the four sources below
+  frontend/feature_pipeline.cc
+  frontend/fft.cc
+  kws/keyword_spotting.cc
+  wekws.cc
+)
+target_link_libraries(wekws PUBLIC onnxruntime)  # resolved through the link directories globbed above
diff --git a/runtime/android/app/src/main/cpp/wekws.cc b/runtime/android/app/src/main/cpp/wekws.cc
new file mode 100644
index 0000000..5fcbee8
--- /dev/null
+++ b/runtime/android/app/src/main/cpp/wekws.cc
@@ -0,0 +1,119 @@
+// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file
except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include +#include + +#include "frontend/feature_pipeline.h" +#include "kws/keyword_spotting.h" +#include "utils/log.h" + +namespace wekws { +std::shared_ptr spotter; +std::shared_ptr feature_config; +std::shared_ptr feature_pipeline; +std::string result; // NOLINT +int offset; + +void init(JNIEnv* env, jobject, jstring jModelDir) { + const char* pModelDir = env->GetStringUTFChars(jModelDir, nullptr); + + std::string modelPath = std::string(pModelDir) + "/wenwen.ort"; + spotter = std::make_shared(modelPath); + + feature_config = std::make_shared(40, 16000); + feature_pipeline = std::make_shared(*feature_config); +} + +void reset(JNIEnv *env, jobject) { + offset = 0; + result = ""; + spotter->Reset(); +} + +void accept_waveform(JNIEnv *env, jobject, jshortArray jWaveform) { + jsize size = env->GetArrayLength(jWaveform); + int16_t* waveform = env->GetShortArrayElements(jWaveform, 0); + std::vector v(waveform, waveform + size); + feature_pipeline->AcceptWaveform(v); + LOG(INFO) << "wekws accept waveform in ms: " << int(size / 16); +} + +void set_input_finished() { + LOG(INFO) << "wekws input finished"; + feature_pipeline->set_input_finished(); +} + +void spot_thread_func() { + while (true) { + std::vector> feats; + feature_pipeline->Read(80, &feats); + std::vector> prob; + spotter->Forward(feats, &prob); + + float max_hi_xiaowen = 0; + float max_nihao_wenwen = 0; + for (int t = 0; t < prob.size(); t++) { + max_hi_xiaowen = std::max(prob[t][0], max_hi_xiaowen); + max_nihao_wenwen = 
std::max(prob[t][1], max_nihao_wenwen); + } + float detect_prob = max_hi_xiaowen + max_nihao_wenwen; + result = std::to_string(offset) + "prob: " + std::to_string(detect_prob); + offset += prob.size(); + } +} + +void start_spot() { + std::thread decode_thread(spot_thread_func); + decode_thread.detach(); +} + +jstring get_result(JNIEnv *env, jobject) { + LOG(INFO) << "wekws ui result: " << result; + return env->NewStringUTF(result.c_str()); +} +} // namespace wekws + +JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *) { + JNIEnv *env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { + return JNI_ERR; + } + + jclass c = env->FindClass("cn/org/wenet/wekws/Spot"); + if (c == nullptr) { + return JNI_ERR; + } + + static const JNINativeMethod methods[] = { + {"init", "(Ljava/lang/String;)V", reinterpret_cast(wekws::init)}, + {"reset", "()V", reinterpret_cast(wekws::reset)}, + {"acceptWaveform", "([S)V", + reinterpret_cast(wekws::accept_waveform)}, + {"setInputFinished", "()V", + reinterpret_cast(wekws::set_input_finished)}, + {"startSpot", "()V", reinterpret_cast(wekws::start_spot)}, + {"getResult", "()Ljava/lang/String;", + reinterpret_cast(wekws::get_result)}, + }; + int rc = env->RegisterNatives(c, methods, + sizeof(methods) / sizeof(JNINativeMethod)); + + if (rc != JNI_OK) { + return rc; + } + + return JNI_VERSION_1_6; +} diff --git a/runtime/android/app/src/main/java/cn/org/wenet/wekws/MainActivity.java b/runtime/android/app/src/main/java/cn/org/wenet/wekws/MainActivity.java index 7457091..e1b0d08 100644 --- a/runtime/android/app/src/main/java/cn/org/wenet/wekws/MainActivity.java +++ b/runtime/android/app/src/main/java/cn/org/wenet/wekws/MainActivity.java @@ -1,14 +1,212 @@ package cn.org.wenet.wekws; import androidx.appcompat.app.AppCompatActivity; +import androidx.core.app.ActivityCompat; +import androidx.core.content.ContextCompat; +import android.Manifest; +import android.content.Context; +import android.content.pm.PackageManager; +import 
android.content.res.AssetManager; +import android.media.AudioFormat; +import android.media.AudioRecord; +import android.media.MediaRecorder; import android.os.Bundle; +import android.os.Process; +import android.util.Log; +import android.widget.Button; +import android.widget.TextView; +import android.widget.Toast; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; public class MainActivity extends AppCompatActivity { + private final int MY_PERMISSIONS_RECORD_AUDIO = 1; + private static final String LOG_TAG = "WEKWS"; + private static final int SAMPLE_RATE = 16000; // The sampling rate + private static final int MAX_QUEUE_SIZE = 2500; // 100 seconds audio, 1 / 0.04 * 100 + private static final List resource = Arrays.asList("wenwen.ort"); + + private boolean startRecord = false; + private AudioRecord record = null; + private int miniBufferSize = 0; // 1280 bytes 648 byte 40ms, 0.04s + private final BlockingQueue bufferQueue = new ArrayBlockingQueue<>(MAX_QUEUE_SIZE); + + public static void assetsInit(Context context) throws IOException { + AssetManager assetMgr = context.getAssets(); + // Unzip all files in resource from assets to context. + // Note: Uninstall the APP will remove the resource files in the context. 
+ for (String file : assetMgr.list("")) { + if (resource.contains(file)) { + File dst = new File(context.getFilesDir(), file); + if (!dst.exists() || dst.length() == 0) { + Log.i(LOG_TAG, "Unzipping " + file + " to " + dst.getAbsolutePath()); + InputStream is = assetMgr.open(file); + OutputStream os = new FileOutputStream(dst); + byte[] buffer = new byte[4 * 1024]; + int read; + while ((read = is.read(buffer)) != -1) { + os.write(buffer, 0, read); + } + os.flush(); + } + } + } + } + + @Override + public void onRequestPermissionsResult(int requestCode, + String[] permissions, int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + if (requestCode == MY_PERMISSIONS_RECORD_AUDIO) { + if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) { + Log.i(LOG_TAG, "record permission is granted"); + initRecorder(); + } else { + Toast.makeText(this, "Permissions denied to record audio", Toast.LENGTH_LONG).show(); + Button button = findViewById(R.id.button); + button.setEnabled(false); + } + } + } + @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); + requestAudioPermissions(); + try { + assetsInit(this); + } catch (IOException e) { + Log.e(LOG_TAG, "Error process asset files to file path"); + } + + TextView textView = findViewById(R.id.textView); + textView.setText(""); + Spot.init(getFilesDir().getPath()); + + Button button = findViewById(R.id.button); + button.setText("Start Record"); + button.setOnClickListener(view -> { + if (!startRecord) { + startRecord = true; + startRecordThread(); + startSpotThread(); + Spot.reset(); + Spot.startSpot(); + button.setText("Stop Record"); + } else { + startRecord = false; + button.setText("Start Record"); + } + }); + } + + private void requestAudioPermissions() { + if (ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) + != 
PackageManager.PERMISSION_GRANTED) { + ActivityCompat.requestPermissions(this, + new String[]{Manifest.permission.RECORD_AUDIO}, + MY_PERMISSIONS_RECORD_AUDIO); + } else { + initRecorder(); + } + } + + private void initRecorder() { + // buffer size in bytes 1280 + miniBufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT); + if (miniBufferSize == AudioRecord.ERROR || miniBufferSize == AudioRecord.ERROR_BAD_VALUE) { + Log.e(LOG_TAG, "Audio buffer can't initialize!"); + return; + } + if (ActivityCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) != PackageManager.PERMISSION_GRANTED) { + // TODO: Consider calling + // ActivityCompat#requestPermissions + // here to request the missing permissions, and then overriding + // public void onRequestPermissionsResult(int requestCode, String[] permissions, + // int[] grantResults) + // to handle the case where the user grants the permission. See the documentation + // for ActivityCompat#requestPermissions for more details. 
+ return; + } + record = new AudioRecord(MediaRecorder.AudioSource.DEFAULT, + SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + miniBufferSize); + if (record.getState() != AudioRecord.STATE_INITIALIZED) { + Log.e(LOG_TAG, "Audio Record can't initialize!"); + return; + } + Log.i(LOG_TAG, "Record init okay"); + } + + private void startRecordThread() { + new Thread(() -> { + VoiceRectView voiceView = findViewById(R.id.voiceRectView); + record.startRecording(); + Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO); + while (startRecord) { + short[] buffer = new short[miniBufferSize / 2]; + int read = record.read(buffer, 0, buffer.length); + voiceView.add(calculateDb(buffer)); + try { + if (AudioRecord.ERROR_INVALID_OPERATION != read) { + bufferQueue.put(buffer); + } + } catch (InterruptedException e) { + Log.e(LOG_TAG, e.getMessage()); + } + Button button = findViewById(R.id.button); + if (!button.isEnabled() && startRecord) { + runOnUiThread(() -> button.setEnabled(true)); + } + } + record.stop(); + voiceView.zero(); + }).start(); + } + + private double calculateDb(short[] buffer) { + double energy = 0.0; + for (short value : buffer) { + energy += value * value; + } + energy /= buffer.length; + energy = (10 * Math.log10(1 + energy)) / 100; + energy = Math.min(energy, 1.0); + return energy; + } + + private void startSpotThread() { + new Thread(() -> { + // Send all data + while (startRecord || bufferQueue.size() > 0) { + try { + short[] data = bufferQueue.take(); + // 1. add data to C++ interface + Spot.acceptWaveform(data); + // 2. 
get partial result + runOnUiThread(() -> { + TextView textView = findViewById(R.id.textView); + textView.setText(Spot.getResult()); + }); + } catch (InterruptedException e) { + Log.e(LOG_TAG, e.getMessage()); + } + } + }).start(); } } \ No newline at end of file diff --git a/runtime/android/app/src/main/java/cn/org/wenet/wekws/Spot.java b/runtime/android/app/src/main/java/cn/org/wenet/wekws/Spot.java new file mode 100644 index 0000000..fc38ec8 --- /dev/null +++ b/runtime/android/app/src/main/java/cn/org/wenet/wekws/Spot.java @@ -0,0 +1,15 @@ +package cn.org.wenet.wekws; + +public class Spot { + + static { + System.loadLibrary("wekws"); + } + + public static native void init(String modelDir); + public static native void reset(); + public static native void acceptWaveform(short[] waveform); + public static native void setInputFinished(); + public static native void startSpot(); + public static native String getResult(); +} diff --git a/runtime/android/app/src/main/res/layout/activity_main.xml b/runtime/android/app/src/main/res/layout/activity_main.xml index 17eab17..c61b7d0 100644 --- a/runtime/android/app/src/main/res/layout/activity_main.xml +++ b/runtime/android/app/src/main/res/layout/activity_main.xml @@ -2,17 +2,49 @@ + app:layout_constraintTop_toTopOf="parent" + app:layout_constraintVertical_bias="0.08" /> + + + +