formatting
This commit is contained in:
parent
e0f6e1d5ed
commit
8d6aee9307
@ -2,17 +2,28 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(', ')
|
|
||||||
|
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
|
||||||
|
', ')
|
||||||
CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
|
CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
|
||||||
|
|
||||||
if __name__=='__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='prepare kaldi format file for google speech command dataset ')
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--wav_list', required=True, help='wave list is a file containts full path of a wav file in google speech command dataset')
|
description=
|
||||||
parser.add_argument('--data_dir', required=True, help='folder to write kaldi format files')
|
'prepare kaldi format file for google speech command dataset ')
|
||||||
|
parser.add_argument(
|
||||||
|
'--wav_list',
|
||||||
|
required=True,
|
||||||
|
help=
|
||||||
|
'wave list is a file containts full path of a wav file in google speech command dataset'
|
||||||
|
)
|
||||||
|
parser.add_argument('--data_dir',
|
||||||
|
required=True,
|
||||||
|
help='folder to write kaldi format files')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
data_dir = args.data_dir
|
data_dir = args.data_dir
|
||||||
f_wav_scp = open(os.path.join(data_dir,'wav.scp'), 'w')
|
f_wav_scp = open(os.path.join(data_dir, 'wav.scp'), 'w')
|
||||||
f_text = open(os.path.join(data_dir, 'text'), 'w')
|
f_text = open(os.path.join(data_dir, 'text'), 'w')
|
||||||
with open(args.wav_list) as f:
|
with open(args.wav_list) as f:
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
@ -21,9 +32,8 @@ if __name__=='__main__':
|
|||||||
wav_id = '_'.join([keyword, file_name_new])
|
wav_id = '_'.join([keyword, file_name_new])
|
||||||
file_dir = line.strip()
|
file_dir = line.strip()
|
||||||
f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n')
|
f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n')
|
||||||
label = CLASS_TO_IDX[keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
|
label = CLASS_TO_IDX[
|
||||||
|
keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
|
||||||
f_text.writelines(wav_id + ' ' + str(label) + '\n')
|
f_text.writelines(wav_id + ' ' + str(label) + '\n')
|
||||||
f_wav_scp.close()
|
f_wav_scp.close()
|
||||||
f_text.close()
|
f_text.close()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
def move_files(src_folder, to_folder, list_file):
|
def move_files(src_folder, to_folder, list_file):
|
||||||
with open(list_file) as f:
|
with open(list_file) as f:
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
@ -12,14 +13,18 @@ def move_files(src_folder, to_folder, list_file):
|
|||||||
dest = os.path.join(to_folder, dirname)
|
dest = os.path.join(to_folder, dirname)
|
||||||
if not os.path.exists(dest):
|
if not os.path.exists(dest):
|
||||||
os.mkdir(dest)
|
os.mkdir(dest)
|
||||||
shutil.move(os.path.join(src_folder, line),dest)
|
shutil.move(os.path.join(src_folder, line), dest)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='Split google command dataset.')
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('root', type=str, help='the path to the root folder of the google commands dataset')
|
description='Split google command dataset.')
|
||||||
|
parser.add_argument(
|
||||||
|
'root',
|
||||||
|
type=str,
|
||||||
|
help='the path to the root folder of the google commands dataset')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
audio_folder = os.path.join(args.root, 'audio')
|
audio_folder = os.path.join(args.root, 'audio')
|
||||||
validation_path = os.path.join(audio_folder, 'validation_list.txt')
|
validation_path = os.path.join(audio_folder, 'validation_list.txt')
|
||||||
test_path = os.path.join(audio_folder, 'testing_list.txt')
|
test_path = os.path.join(audio_folder, 'testing_list.txt')
|
||||||
@ -34,4 +39,3 @@ if __name__ == '__main__':
|
|||||||
move_files(audio_folder, test_folder, test_path)
|
move_files(audio_folder, test_folder, test_path)
|
||||||
move_files(audio_folder, valid_folder, validation_path)
|
move_files(audio_folder, valid_folder, validation_path)
|
||||||
os.rename(audio_folder, train_folder)
|
os.rename(audio_folder, train_folder)
|
||||||
|
|
||||||
|
|||||||
@ -15,15 +15,22 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
def acc_frame(logits: torch.Tensor, target: torch.Tensor, ):
|
|
||||||
|
def acc_frame(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
):
|
||||||
if logits is None:
|
if logits is None:
|
||||||
return 0
|
return 0
|
||||||
pred = logits.max(1, keepdim=True)[1]
|
pred = logits.max(1, keepdim=True)[1]
|
||||||
correct = pred.eq(target.long().view_as(pred)).sum().item()
|
correct = pred.eq(target.long().view_as(pred)).sum().item()
|
||||||
return correct*100.0/logits.size(0)
|
return correct * 100.0 / logits.size(0)
|
||||||
|
|
||||||
def cross_entropy(logits: torch.Tensor, target: torch.Tensor,
|
|
||||||
lengths: torch.Tensor, min_duration: int = 0):
|
def cross_entropy(logits: torch.Tensor,
|
||||||
|
target: torch.Tensor,
|
||||||
|
lengths: torch.Tensor,
|
||||||
|
min_duration: int = 0):
|
||||||
""" Cross Entropy Loss
|
""" Cross Entropy Loss
|
||||||
Attributes:
|
Attributes:
|
||||||
logits: (B, D), D is the number of keywords plus 1 (non-keyword)
|
logits: (B, D), D is the number of keywords plus 1 (non-keyword)
|
||||||
@ -37,4 +44,4 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor,
|
|||||||
cross_entropy = nn.CrossEntropyLoss()
|
cross_entropy = nn.CrossEntropyLoss()
|
||||||
loss = cross_entropy(logits, target)
|
loss = cross_entropy(logits, target)
|
||||||
acc = acc_frame(logits, target)
|
acc = acc_frame(logits, target)
|
||||||
return loss, acc
|
return loss, acc
|
||||||
|
|||||||
@ -1,12 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class GlobalClassifier(nn.Module):
|
class GlobalClassifier(nn.Module):
|
||||||
"""Add a global average pooling before the classifier"""
|
"""Add a global average pooling before the classifier"""
|
||||||
def __init__(
|
def __init__(self, classifier: nn.Module):
|
||||||
self,
|
|
||||||
classifier: nn.Module
|
|
||||||
):
|
|
||||||
super(GlobalClassifier, self).__init__()
|
super(GlobalClassifier, self).__init__()
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
|
|
||||||
@ -14,15 +12,13 @@ class GlobalClassifier(nn.Module):
|
|||||||
x = torch.mean(x, dim=1)
|
x = torch.mean(x, dim=1)
|
||||||
return self.classifier(x)
|
return self.classifier(x)
|
||||||
|
|
||||||
|
|
||||||
class LastClassifier(nn.Module):
|
class LastClassifier(nn.Module):
|
||||||
"""Select last frame to do the classification"""
|
"""Select last frame to do the classification"""
|
||||||
def __init__(
|
def __init__(self, classifier: nn.Module):
|
||||||
self,
|
|
||||||
classifier: nn.Module
|
|
||||||
):
|
|
||||||
super(LastClassifier, self).__init__()
|
super(LastClassifier, self).__init__()
|
||||||
self.classifier = classifier
|
self.classifier = classifier
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
x = x[:, -1, :]
|
x = x[:, -1, :]
|
||||||
return self.classifier(x)
|
return self.classifier(x)
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from kws.model.mdtc import MDTC
|
|||||||
from kws.model.classifier import GlobalClassifier, LastClassifier
|
from kws.model.classifier import GlobalClassifier, LastClassifier
|
||||||
from kws.utils.cmvn import load_cmvn
|
from kws.utils.cmvn import load_cmvn
|
||||||
|
|
||||||
|
|
||||||
class KWSModel(torch.nn.Module):
|
class KWSModel(torch.nn.Module):
|
||||||
"""Our model consists of four parts:
|
"""Our model consists of four parts:
|
||||||
1. global_cmvn: Optional, (idim, idim)
|
1. global_cmvn: Optional, (idim, idim)
|
||||||
@ -31,16 +32,10 @@ class KWSModel(torch.nn.Module):
|
|||||||
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
||||||
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(self, idim: int, odim: int, hdim: int,
|
||||||
self,
|
global_cmvn: Optional[torch.nn.Module],
|
||||||
idim: int,
|
preprocessing: Optional[torch.nn.Module],
|
||||||
odim: int,
|
backbone: torch.nn.Module, classifier: torch.nn.Module):
|
||||||
hdim: int,
|
|
||||||
global_cmvn: Optional[torch.nn.Module],
|
|
||||||
preprocessing: Optional[torch.nn.Module],
|
|
||||||
backbone: torch.nn.Module,
|
|
||||||
classifier: torch.nn.Module
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.idim = idim
|
self.idim = idim
|
||||||
self.odim = odim
|
self.odim = odim
|
||||||
|
|||||||
@ -39,7 +39,7 @@ def RHE(indice: torch.Tensor, k: int):
|
|||||||
reserve.append(indice[i])
|
reserve.append(indice[i])
|
||||||
rm_s = max(indice[i] - k, 0)
|
rm_s = max(indice[i] - k, 0)
|
||||||
rm_e = min(indice[i] + k, lenght)
|
rm_e = min(indice[i] + k, lenght)
|
||||||
available_indice[rm_s : rm_e + 1] = 0
|
available_indice[rm_s:rm_e + 1] = 0
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -48,7 +48,9 @@ def RHE(indice: torch.Tensor, k: int):
|
|||||||
return torch.tensor(reserve).long()
|
return torch.tensor(reserve).long()
|
||||||
|
|
||||||
|
|
||||||
def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float = 10):
|
def downsample_training_sample_and_calculate_loss(logits,
|
||||||
|
targets,
|
||||||
|
ratio: float = 10):
|
||||||
num_training = 0
|
num_training = 0
|
||||||
loss = 0
|
loss = 0
|
||||||
for i in range(len(logits)):
|
for i in range(len(logits)):
|
||||||
@ -69,8 +71,11 @@ def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float
|
|||||||
return loss / num_training
|
return loss / num_training
|
||||||
|
|
||||||
|
|
||||||
def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio=1):
|
def max_pooling_RHE_binary_CE(logits,
|
||||||
|
targets,
|
||||||
|
lengths,
|
||||||
|
RHE_thr=10000,
|
||||||
|
max_ratio=1):
|
||||||
"""Max-pooling loss with regional hard example mining
|
"""Max-pooling loss with regional hard example mining
|
||||||
For each keyword utterance, select the frame with the highest posterior.
|
For each keyword utterance, select the frame with the highest posterior.
|
||||||
The keyword is triggered when any of the frames is triggered.
|
The keyword is triggered when any of the frames is triggered.
|
||||||
@ -108,7 +113,8 @@ def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio
|
|||||||
if logits[i, max_idx, j] >= 0.5:
|
if logits[i, max_idx, j] >= 0.5:
|
||||||
num_hit += 1
|
num_hit += 1
|
||||||
else:
|
else:
|
||||||
sorted_logits, sorted_index = torch.sort(logits[i, :end_idx], dim=0)
|
sorted_logits, sorted_index = torch.sort(logits[i, :end_idx],
|
||||||
|
dim=0)
|
||||||
reversed_index = torch.flip(sorted_index, dims=[0])
|
reversed_index = torch.flip(sorted_index, dims=[0])
|
||||||
selected_indexes = RHE(reversed_index[:, j], RHE_thr)
|
selected_indexes = RHE(reversed_index[:, j], RHE_thr)
|
||||||
new_logits[j].append(logits[i, selected_indexes, j])
|
new_logits[j].append(logits[i, selected_indexes, j])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user