add examples for speech command dataset
This commit is contained in:
parent
1d5f42f0a3
commit
e0f6e1d5ed
49
examples/speechcommand_v1/s0/conf/mdtc.yaml
Normal file
49
examples/speechcommand_v1/s0/conf/mdtc.yaml
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
dataset_conf:
|
||||||
|
filter_conf:
|
||||||
|
max_length: 2048
|
||||||
|
min_length: 0
|
||||||
|
resample_conf:
|
||||||
|
resample_rate: 16000
|
||||||
|
speed_perturb: false
|
||||||
|
feature_extraction_conf:
|
||||||
|
feature_type: 'mfcc'
|
||||||
|
num_ceps: 80
|
||||||
|
num_mel_bins: 80
|
||||||
|
frame_shift: 10
|
||||||
|
frame_length: 25
|
||||||
|
dither: 1.0
|
||||||
|
feature_dither: 0.0
|
||||||
|
spec_aug: true
|
||||||
|
spec_aug_conf:
|
||||||
|
num_t_mask: 1
|
||||||
|
num_f_mask: 1
|
||||||
|
max_t: 10
|
||||||
|
max_f: 40
|
||||||
|
shuffle: true
|
||||||
|
shuffle_conf:
|
||||||
|
shuffle_size: 1500
|
||||||
|
batch_conf:
|
||||||
|
batch_size: 100
|
||||||
|
|
||||||
|
model:
|
||||||
|
hidden_dim: 64
|
||||||
|
preprocessing:
|
||||||
|
type: none
|
||||||
|
backbone:
|
||||||
|
type: mdtc
|
||||||
|
num_stack: 4
|
||||||
|
stack_size: 4
|
||||||
|
kernel_size: 5
|
||||||
|
hidden_dim: 64
|
||||||
|
classifier:
|
||||||
|
type: last
|
||||||
|
|
||||||
|
optim: adam
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.001
|
||||||
|
|
||||||
|
training_config:
|
||||||
|
grad_clip: 5
|
||||||
|
max_epoch: 100
|
||||||
|
log_interval: 10
|
||||||
|
criterion: CE
|
||||||
1
examples/speechcommand_v1/s0/kws
Symbolic link
1
examples/speechcommand_v1/s0/kws
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../kws
|
||||||
30
examples/speechcommand_v1/s0/local/data_download.sh
Executable file
30
examples/speechcommand_v1/s0/local/data_download.sh
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Copyright 2021 Jingyong Hou (houjingyong@gmail.com)
|
||||||
|
[ -f ./path.sh ] && . ./path.sh
|
||||||
|
|
||||||
|
dl_dir=./data/local
|
||||||
|
|
||||||
|
. tools/parse_options.sh || exit 1;
|
||||||
|
data_dir=$dl_dir
|
||||||
|
file_name=speech_commands_v0.01.tar.gz
|
||||||
|
speech_command_dir=$data_dir/speech_commands_v1
|
||||||
|
audio_dir=$data_dir/speech_commands_v1/audio
|
||||||
|
url=http://download.tensorflow.org/data/$file_name
|
||||||
|
mkdir -p $data_dir
|
||||||
|
if [ ! -f $data_dir/$file_name ]; then
|
||||||
|
echo "downloading $url..."
|
||||||
|
wget -O $data_dir/$file_name $url
|
||||||
|
else
|
||||||
|
echo "$file_name exist in $data_dir, skip download it"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $speech_command_dir/.extracted ]; then
|
||||||
|
mkdir -p $audio_dir
|
||||||
|
tar -xzvf $data_dir/$file_name -C $audio_dir
|
||||||
|
touch $speech_command_dir/.extracted
|
||||||
|
else
|
||||||
|
echo "$speech_command_dir/.exatracted exist in $speech_command_dir, skip exatraction"
|
||||||
|
fi
|
||||||
|
|
||||||
|
exit 0
|
||||||
29
examples/speechcommand_v1/s0/local/prepare_speech_command.py
Normal file
29
examples/speechcommand_v1/s0/local/prepare_speech_command.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(', ')
|
||||||
|
CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
|
||||||
|
|
||||||
|
if __name__=='__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='prepare kaldi format file for google speech command dataset ')
|
||||||
|
parser.add_argument('--wav_list', required=True, help='wave list is a file containts full path of a wav file in google speech command dataset')
|
||||||
|
parser.add_argument('--data_dir', required=True, help='folder to write kaldi format files')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
data_dir = args.data_dir
|
||||||
|
f_wav_scp = open(os.path.join(data_dir,'wav.scp'), 'w')
|
||||||
|
f_text = open(os.path.join(data_dir, 'text'), 'w')
|
||||||
|
with open(args.wav_list) as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
keyword, file_name = line.strip().split('/')[-2:]
|
||||||
|
file_name_new = file_name.split('.')[0]
|
||||||
|
wav_id = '_'.join([keyword, file_name_new])
|
||||||
|
file_dir = line.strip()
|
||||||
|
f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n')
|
||||||
|
label = CLASS_TO_IDX[keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
|
||||||
|
f_text.writelines(wav_id + ' ' + str(label) + '\n')
|
||||||
|
f_wav_scp.close()
|
||||||
|
f_text.close()
|
||||||
|
|
||||||
|
|
||||||
37
examples/speechcommand_v1/s0/local/split_dataset.py
Executable file
37
examples/speechcommand_v1/s0/local/split_dataset.py
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
"""Splits the google speech commands into train, validation and test set """
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def move_files(src_folder, to_folder, list_file):
|
||||||
|
with open(list_file) as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
line = line.rstrip()
|
||||||
|
dirname = os.path.dirname(line)
|
||||||
|
dest = os.path.join(to_folder, dirname)
|
||||||
|
if not os.path.exists(dest):
|
||||||
|
os.mkdir(dest)
|
||||||
|
shutil.move(os.path.join(src_folder, line),dest)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Split google command dataset.')
|
||||||
|
parser.add_argument('root', type=str, help='the path to the root folder of the google commands dataset')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
audio_folder = os.path.join(args.root, 'audio')
|
||||||
|
validation_path = os.path.join(audio_folder, 'validation_list.txt')
|
||||||
|
test_path = os.path.join(audio_folder, 'testing_list.txt')
|
||||||
|
|
||||||
|
valid_folder = os.path.join(args.root, 'valid')
|
||||||
|
test_folder = os.path.join(args.root, 'test')
|
||||||
|
train_folder = os.path.join(args.root, 'train')
|
||||||
|
|
||||||
|
os.mkdir(valid_folder)
|
||||||
|
os.mkdir(test_folder)
|
||||||
|
|
||||||
|
move_files(audio_folder, test_folder, test_path)
|
||||||
|
move_files(audio_folder, valid_folder, validation_path)
|
||||||
|
os.rename(audio_folder, train_folder)
|
||||||
|
|
||||||
1
examples/speechcommand_v1/s0/path.sh
Symbolic link
1
examples/speechcommand_v1/s0/path.sh
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../hi_xiaowen/s0/path.sh
|
||||||
107
examples/speechcommand_v1/s0/run.sh
Normal file
107
examples/speechcommand_v1/s0/run.sh
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2021 Binbin Zhang
|
||||||
|
# Jingyong Hou
|
||||||
|
|
||||||
|
. ./path.sh
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
|
||||||
|
stage=2
|
||||||
|
stop_stage=2
|
||||||
|
num_keywords=11
|
||||||
|
|
||||||
|
config=conf/mdtc.yaml
|
||||||
|
norm_mean=false
|
||||||
|
norm_var=false
|
||||||
|
gpu_id=0
|
||||||
|
|
||||||
|
checkpoint=
|
||||||
|
dir=exp/mdtc
|
||||||
|
|
||||||
|
num_average=10
|
||||||
|
score_checkpoint=$dir/avg_${num_average}.pt
|
||||||
|
|
||||||
|
download_dir=./data/local # your data dir
|
||||||
|
speech_command_dir=$download_dir/speech_commands_v1
|
||||||
|
. tools/parse_options.sh || exit 1;
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
|
||||||
|
echo "Download and extract all datasets"
|
||||||
|
local/data_download.sh --dl_dir $download_dir
|
||||||
|
python local/split_dataset.py $download_dir/speech_commands_v1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
echo "Start preparing Kaldi format files"
|
||||||
|
for x in train test valid;
|
||||||
|
do
|
||||||
|
data=data/$x
|
||||||
|
mkdir -p $data
|
||||||
|
# make wav.scp utt2spk text file
|
||||||
|
find $speech_command_dir/$x -name *.wav | grep -v "_background_noise_" > $data/wav.list
|
||||||
|
python local/prepare_speech_command.py --wav_list=$data/wav.list --data_dir=$data
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
echo "Compute CMVN and Format datasets"
|
||||||
|
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
|
||||||
|
--in_scp data/train/wav.scp \
|
||||||
|
--out_cmvn data/train/global_cmvn
|
||||||
|
|
||||||
|
for x in train valid test; do
|
||||||
|
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur
|
||||||
|
tools/make_list.py data/$x/wav.scp data/$x/text \
|
||||||
|
data/$x/wav.dur data/$x/data.list
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
echo "Start training ..."
|
||||||
|
mkdir -p $dir
|
||||||
|
cmvn_opts=
|
||||||
|
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
|
||||||
|
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
|
||||||
|
python kws/bin/train.py --gpu $gpu_id \
|
||||||
|
--config $config \
|
||||||
|
--train_data data/train/data.list \
|
||||||
|
--cv_data data/valid/data.list \
|
||||||
|
--model_dir $dir \
|
||||||
|
--num_workers 8 \
|
||||||
|
--num_keywords $num_keywords \
|
||||||
|
--min_duration 50 \
|
||||||
|
$cmvn_opts \
|
||||||
|
${checkpoint:+--checkpoint $checkpoint}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# Do model average
|
||||||
|
python kws/bin/average_model.py \
|
||||||
|
--dst_model $score_checkpoint \
|
||||||
|
--src_path $dir \
|
||||||
|
--num ${num_average} \
|
||||||
|
--val_best
|
||||||
|
|
||||||
|
# Compute posterior score
|
||||||
|
result_dir=$dir/test_$(basename $score_checkpoint)
|
||||||
|
mkdir -p $result_dir
|
||||||
|
python kws/bin/score.py --gpu 1 \
|
||||||
|
--config $dir/config.yaml \
|
||||||
|
--test_data data/test/data.list \
|
||||||
|
--batch_size 256 \
|
||||||
|
--checkpoint $score_checkpoint \
|
||||||
|
--score_file $result_dir/score.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
||||||
|
python kws/bin/export_jit.py --config $dir/config.yaml \
|
||||||
|
--checkpoint $score_checkpoint \
|
||||||
|
--output_file $dir/final.zip \
|
||||||
|
--output_quant_file $dir/final.quant.zip
|
||||||
|
fi
|
||||||
1
examples/speechcommand_v1/s0/tools
Symbolic link
1
examples/speechcommand_v1/s0/tools
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../tools
|
||||||
40
kws/model/ce.py
Normal file
40
kws/model/ce.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# Copyright (c) 2021 Jingyong Hou
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
def acc_frame(logits: torch.Tensor, target: torch.Tensor, ):
|
||||||
|
if logits is None:
|
||||||
|
return 0
|
||||||
|
pred = logits.max(1, keepdim=True)[1]
|
||||||
|
correct = pred.eq(target.long().view_as(pred)).sum().item()
|
||||||
|
return correct*100.0/logits.size(0)
|
||||||
|
|
||||||
|
def cross_entropy(logits: torch.Tensor, target: torch.Tensor,
|
||||||
|
lengths: torch.Tensor, min_duration: int = 0):
|
||||||
|
""" Cross Entropy Loss
|
||||||
|
Attributes:
|
||||||
|
logits: (B, D), D is the number of keywords plus 1 (non-keyword)
|
||||||
|
target: (B)
|
||||||
|
lengths: (B)
|
||||||
|
min_duration: min duration of the keyword
|
||||||
|
Returns:
|
||||||
|
(float): loss of current batch
|
||||||
|
(float): accuracy of current batch
|
||||||
|
"""
|
||||||
|
cross_entropy = nn.CrossEntropyLoss()
|
||||||
|
loss = cross_entropy(logits, target)
|
||||||
|
acc = acc_frame(logits, target)
|
||||||
|
return loss, acc
|
||||||
28
kws/model/classifier.py
Normal file
28
kws/model/classifier.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class GlobalClassifier(nn.Module):
|
||||||
|
"""Add a global average pooling before the classifier"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
classifier: nn.Module
|
||||||
|
):
|
||||||
|
super(GlobalClassifier, self).__init__()
|
||||||
|
self.classifier = classifier
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = torch.mean(x, dim=1)
|
||||||
|
return self.classifier(x)
|
||||||
|
|
||||||
|
class LastClassifier(nn.Module):
|
||||||
|
"""Select last frame to do the classification"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
classifier: nn.Module
|
||||||
|
):
|
||||||
|
super(LastClassifier, self).__init__()
|
||||||
|
self.classifier = classifier
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
x = x[:, -1, :]
|
||||||
|
return self.classifier(x)
|
||||||
@ -21,9 +21,9 @@ from kws.model.cmvn import GlobalCMVN
|
|||||||
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
|
from kws.model.subsampling import LinearSubsampling1, Conv1dSubsampling1
|
||||||
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
from kws.model.tcn import TCN, CnnBlock, DsCnnBlock
|
||||||
from kws.model.mdtc import MDTC
|
from kws.model.mdtc import MDTC
|
||||||
|
from kws.model.classifier import GlobalClassifier, LastClassifier
|
||||||
from kws.utils.cmvn import load_cmvn
|
from kws.utils.cmvn import load_cmvn
|
||||||
|
|
||||||
|
|
||||||
class KWSModel(torch.nn.Module):
|
class KWSModel(torch.nn.Module):
|
||||||
"""Our model consists of four parts:
|
"""Our model consists of four parts:
|
||||||
1. global_cmvn: Optional, (idim, idim)
|
1. global_cmvn: Optional, (idim, idim)
|
||||||
@ -39,6 +39,7 @@ class KWSModel(torch.nn.Module):
|
|||||||
global_cmvn: Optional[torch.nn.Module],
|
global_cmvn: Optional[torch.nn.Module],
|
||||||
preprocessing: Optional[torch.nn.Module],
|
preprocessing: Optional[torch.nn.Module],
|
||||||
backbone: torch.nn.Module,
|
backbone: torch.nn.Module,
|
||||||
|
classifier: torch.nn.Module
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.idim = idim
|
self.idim = idim
|
||||||
@ -47,7 +48,7 @@ class KWSModel(torch.nn.Module):
|
|||||||
self.global_cmvn = global_cmvn
|
self.global_cmvn = global_cmvn
|
||||||
self.preprocessing = preprocessing
|
self.preprocessing = preprocessing
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.classifier = torch.nn.Linear(hdim, odim)
|
self.classifier = classifier
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if self.global_cmvn is not None:
|
if self.global_cmvn is not None:
|
||||||
@ -56,7 +57,6 @@ class KWSModel(torch.nn.Module):
|
|||||||
x = self.preprocessing(x)
|
x = self.preprocessing(x)
|
||||||
x, _ = self.backbone(x)
|
x, _ = self.backbone(x)
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
x = torch.sigmoid(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -119,9 +119,18 @@ def init_model(configs):
|
|||||||
kernel_size,
|
kernel_size,
|
||||||
causal=True)
|
causal=True)
|
||||||
else:
|
else:
|
||||||
print('Unknown body type {}'.format(backbone_type))
|
print('Unknown backbone type {}'.format(backbone_type))
|
||||||
|
sys.exit(1)
|
||||||
|
classifier_type = configs['classifier']['type']
|
||||||
|
if classifier_type == 'linear':
|
||||||
|
classifier = torch.nn.Linear(hidden_dim, output_dim)
|
||||||
|
elif classifier_type == 'global':
|
||||||
|
classifier = GlobalClassifier(torch.nn.Linear(hidden_dim, output_dim))
|
||||||
|
elif classifier_type == 'last':
|
||||||
|
classifier = LastClassifier(torch.nn.Linear(hidden_dim, output_dim))
|
||||||
|
else:
|
||||||
|
print('Unknown classifier type {}'.format(classifier_type))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
|
||||||
preprocessing, backbone)
|
preprocessing, backbone, classifier)
|
||||||
return kws_model
|
return kws_model
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import torch
|
|||||||
from kws.utils.mask import padding_mask
|
from kws.utils.mask import padding_mask
|
||||||
|
|
||||||
|
|
||||||
def max_polling_loss(logits: torch.Tensor,
|
def max_pooling_loss(logits: torch.Tensor,
|
||||||
target: torch.Tensor,
|
target: torch.Tensor,
|
||||||
lengths: torch.Tensor,
|
lengths: torch.Tensor,
|
||||||
min_duration: int = 0):
|
min_duration: int = 0):
|
||||||
@ -37,6 +37,7 @@ def max_polling_loss(logits: torch.Tensor,
|
|||||||
(float): loss of current batch
|
(float): loss of current batch
|
||||||
(float): accuracy of current batch
|
(float): accuracy of current batch
|
||||||
"""
|
"""
|
||||||
|
logits = torch.sigmoid(logits)
|
||||||
mask = padding_mask(lengths)
|
mask = padding_mask(lengths)
|
||||||
num_utts = logits.size(0)
|
num_utts = logits.size(0)
|
||||||
num_keywords = logits.size(2)
|
num_keywords = logits.size(2)
|
||||||
136
kws/model/max_pooling_RHE.py
Normal file
136
kws/model/max_pooling_RHE.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
# Copyright (c) 2021 Jingyong Hou (houjingyong@gmail.com)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def RHE(indice: torch.Tensor, k: int):
|
||||||
|
"""Regional hard example mining from 'Mining effective negative training samples for keyword spotting'
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
index: indice of
|
||||||
|
k:
|
||||||
|
lengths: (B)
|
||||||
|
min_duration: min duration of the keyword
|
||||||
|
Returns:
|
||||||
|
(torch.Tensor): indice of selected regional hard example
|
||||||
|
"""
|
||||||
|
if k <= 0:
|
||||||
|
return indice
|
||||||
|
lenght = len(indice)
|
||||||
|
available_indice = torch.tensor([1] * (lenght))
|
||||||
|
reserve = []
|
||||||
|
for i in range(lenght):
|
||||||
|
if 1 == available_indice[indice[i]]:
|
||||||
|
reserve.append(indice[i])
|
||||||
|
rm_s = max(indice[i] - k, 0)
|
||||||
|
rm_e = min(indice[i] + k, lenght)
|
||||||
|
available_indice[rm_s : rm_e + 1] = 0
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if torch.sum(available_indice) <= 0:
|
||||||
|
break
|
||||||
|
return torch.tensor(reserve).long()
|
||||||
|
|
||||||
|
|
||||||
|
def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float = 10):
|
||||||
|
num_training = 0
|
||||||
|
loss = 0
|
||||||
|
for i in range(len(logits)):
|
||||||
|
output = torch.cat(logits[i])
|
||||||
|
target = torch.LongTensor(np.concatenate(targets[i]))
|
||||||
|
# how many positive targets
|
||||||
|
positive_index = target >= 1 # the label of positive label is 1
|
||||||
|
negative_index = target < 1 # the label of negative label is 0
|
||||||
|
num_p = torch.sum(positive_index)
|
||||||
|
selected_p_output = output[positive_index]
|
||||||
|
loss += torch.sum(torch.log(selected_p_output))
|
||||||
|
|
||||||
|
all_n_output = output[negative_index]
|
||||||
|
num_n = min(int(ratio * num_p), len(all_n_output))
|
||||||
|
_, sorted_index = torch.sort(all_n_output, descending=True)
|
||||||
|
selected_n_output = all_n_output[sorted_index[:num_n]]
|
||||||
|
num_training += len(selected_p_output) + len(selected_n_output)
|
||||||
|
return loss / num_training
|
||||||
|
|
||||||
|
|
||||||
|
def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio=1):
|
||||||
|
|
||||||
|
"""Max-pooling loss with regional hard example mining
|
||||||
|
For each keyword utterance, select the frame with the highest posterior.
|
||||||
|
The keyword is triggered when any of the frames is triggered.
|
||||||
|
For each non-keyword utterance, select several hard examples using the RHE algorithm.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
logits: (B, T, D), D is the number of keywords
|
||||||
|
target: (B)
|
||||||
|
lengths: (B)
|
||||||
|
RHE_thr: how many neighbor logits we remove each time we find a hard examle
|
||||||
|
Returns:
|
||||||
|
(float): loss of current batch
|
||||||
|
(float): accuracy of current batch
|
||||||
|
"""
|
||||||
|
num_hit = 0
|
||||||
|
# Here we clamp the sigmoid output to prevent NaN problem
|
||||||
|
# When we calculate loss
|
||||||
|
logits = torch.clamp(torch.sigmoid(logits), 1e-8, 1.0 - 1e-8)
|
||||||
|
num_utts = logits.size(0)
|
||||||
|
num_keyword = logits.size(2)
|
||||||
|
|
||||||
|
new_logits = []
|
||||||
|
new_targets = []
|
||||||
|
for j in range(num_keyword):
|
||||||
|
new_logits.append([])
|
||||||
|
new_targets.append([])
|
||||||
|
|
||||||
|
for i in range(num_utts):
|
||||||
|
end_idx = lengths[i]
|
||||||
|
for j in range(num_keyword):
|
||||||
|
if targets[i] == j:
|
||||||
|
max_idx = logits[i, :end_idx].argmax()
|
||||||
|
new_logits[j].append(logits[i, max_idx, j])
|
||||||
|
new_targets[j].append([1])
|
||||||
|
if logits[i, max_idx, j] >= 0.5:
|
||||||
|
num_hit += 1
|
||||||
|
else:
|
||||||
|
sorted_logits, sorted_index = torch.sort(logits[i, :end_idx], dim=0)
|
||||||
|
reversed_index = torch.flip(sorted_index, dims=[0])
|
||||||
|
selected_indexes = RHE(reversed_index[:, j], RHE_thr)
|
||||||
|
new_logits[j].append(logits[i, selected_indexes, j])
|
||||||
|
new_targets[j].append([0] * len(selected_indexes))
|
||||||
|
if torch.sum(sorted_logits[-1, :] >= 0.5) <= 0:
|
||||||
|
# all the binary probilities are smaller than 0.5
|
||||||
|
num_hit += 1
|
||||||
|
|
||||||
|
# Here we select training samples acorrding to max_ratio
|
||||||
|
loss = downsample_training_sample_and_calculate_loss(
|
||||||
|
new_logits,
|
||||||
|
new_targets,
|
||||||
|
ratio=max_ratio,
|
||||||
|
)
|
||||||
|
acc = num_hit / num_utts
|
||||||
|
return loss, acc
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
index = torch.tensor([3, 2, 0, 7, 5, 8, 1, 4, 6])
|
||||||
|
print(RHE(index, 0)) # [3, 2, 0, 7, 5, 8, 1, 4, 6]
|
||||||
|
print(RHE(index, 1)) # [3, 0, 7, 5 ]
|
||||||
|
print(RHE(index, 2)) # [3, 0, 7]
|
||||||
|
print(RHE(index, 3)) # [3, 7]
|
||||||
|
print(RHE(index, 100)) # [3]
|
||||||
@ -17,9 +17,13 @@ import logging
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
|
||||||
from kws.model.loss import max_polling_loss
|
from kws.model.max_pooling import max_pooling_loss
|
||||||
|
from kws.model.ce import cross_entropy
|
||||||
|
|
||||||
|
|
||||||
|
criterion_dict = {'CE': cross_entropy,
|
||||||
|
'max_pooling': max_pooling_loss}
|
||||||
|
|
||||||
class Executor:
|
class Executor:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.step = 0
|
self.step = 0
|
||||||
@ -32,6 +36,7 @@ class Executor:
|
|||||||
log_interval = args.get('log_interval', 10)
|
log_interval = args.get('log_interval', 10)
|
||||||
epoch = args.get('epoch', 0)
|
epoch = args.get('epoch', 0)
|
||||||
min_duration = args.get('min_duration', 0)
|
min_duration = args.get('min_duration', 0)
|
||||||
|
criterion = criterion_dict[args.get('criterion', max_pooling_loss)]
|
||||||
|
|
||||||
num_total_batch = 0
|
num_total_batch = 0
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
@ -44,7 +49,7 @@ class Executor:
|
|||||||
if num_utts == 0:
|
if num_utts == 0:
|
||||||
continue
|
continue
|
||||||
logits = model(feats)
|
logits = model(feats)
|
||||||
loss, acc = max_polling_loss(logits, target, feats_lengths,
|
loss, acc = criterion(logits, target, feats_lengths,
|
||||||
min_duration)
|
min_duration)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
grad_norm = clip_grad_norm_(model.parameters(), clip)
|
||||||
@ -61,6 +66,7 @@ class Executor:
|
|||||||
model.eval()
|
model.eval()
|
||||||
log_interval = args.get('log_interval', 10)
|
log_interval = args.get('log_interval', 10)
|
||||||
epoch = args.get('epoch', 0)
|
epoch = args.get('epoch', 0)
|
||||||
|
criterion = criterion_dict[args.get('criterion', max_pooling_loss)]
|
||||||
# in order to avoid division by 0
|
# in order to avoid division by 0
|
||||||
num_seen_utts = 1
|
num_seen_utts = 1
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
@ -75,7 +81,7 @@ class Executor:
|
|||||||
continue
|
continue
|
||||||
num_seen_utts += num_utts
|
num_seen_utts += num_utts
|
||||||
logits = model(feats)
|
logits = model(feats)
|
||||||
loss, acc = max_polling_loss(logits, target, feats_lengths)
|
loss, acc = criterion(logits, target, feats_lengths)
|
||||||
if torch.isfinite(loss):
|
if torch.isfinite(loss):
|
||||||
num_seen_utts += num_utts
|
num_seen_utts += num_utts
|
||||||
total_loss += loss.item() * num_utts
|
total_loss += loss.item() * num_utts
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user