diff --git a/examples/hey_snips/s0/run.sh b/examples/hey_snips/s0/run.sh index 7ff7f61..28e4fc5 100755 --- a/examples/hey_snips/s0/run.sh +++ b/examples/hey_snips/s0/run.sh @@ -117,10 +117,15 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then - python wekws/bin/export_jit.py --config $dir/config.yaml \ +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') + onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') + python wekws/bin/export_jit.py \ + --config $dir/config.yaml \ --checkpoint $score_checkpoint \ - --output_file $dir/final.zip \ - --output_quant_file $dir/final.quant.zip + --jit_model $dir/$jit_model + python wekws/bin/export_onnx.py \ + --config $dir/config.yaml \ + --checkpoint $score_checkpoint \ + --onnx_model $dir/$onnx_model fi - diff --git a/examples/hi_xiaowen/s0/run.sh b/examples/hi_xiaowen/s0/run.sh index a92c62b..9573302 100755 --- a/examples/hi_xiaowen/s0/run.sh +++ b/examples/hi_xiaowen/s0/run.sh @@ -115,39 +115,6 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then - echo "Static quantization, compute FRR/FAR..." - # Apply static quantization - quantize_score_checkpoint=$(basename $score_checkpoint | sed -e 's:.pt$:.quant.zip:g') - cat data/train/data.list | python tools/shuffle_list.py --seed 777 | \ - head -n 10000 > $dir/calibration.list - python wekws/bin/static_quantize.py \ - --config $dir/config.yaml \ - --test_data $dir/calibration.list \ - --checkpoint $score_checkpoint \ - --num_workers 8 \ - --script_model $dir/$quantize_score_checkpoint - - result_dir=$dir/test_$(basename $quantize_score_checkpoint) - mkdir -p $result_dir - python wekws/bin/score.py \ - --config $dir/config.yaml \ - --test_data data/test/data.list \ - --batch_size 256 \ - --jit_model \ - --checkpoint $dir/$quantize_score_checkpoint \ - --score_file $result_dir/score.txt \ - --num_workers 8 - for keyword in 0 1; do - python wekws/bin/compute_det.py \ - --keyword $keyword \ - --test_data data/test/data.list \ - --score_file $result_dir/score.txt \ - --stats_file $result_dir/stats.${keyword}.txt - done -fi - - -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') python wekws/bin/export_jit.py \ diff --git a/runtime/core/bin/kws_main.cc b/runtime/core/bin/kws_main.cc index 24c06d2..16e5bb5 100644 --- a/runtime/core/bin/kws_main.cc +++ b/runtime/core/bin/kws_main.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. + #include #include @@ -49,10 +50,11 @@ int main(int argc, char* argv[]) { std::vector> prob; spotter.Forward(feats, &prob); for (int i = 0; i < prob.size(); i++) { - if (prob[i][0] > 0.1 || prob[i][1] > 0.1) { - std::cout << "frame " << offset + i << " prob " << prob[i][0] << " " - << prob[i][1] << std::endl; + std::cout << "frame " << offset + i << " prob"; + for (int j = 0; j < prob[i].size(); j++) { + std::cout << " " << prob[i][j]; } + std::cout << std::endl; } // Reach the end of feature pipeline if (!ok) break;