formatting

This commit is contained in:
Jingyong Hou 2021-11-12 11:02:00 +08:00
parent e0f6e1d5ed
commit 8d6aee9307
6 changed files with 61 additions and 43 deletions

View File

@ -2,17 +2,28 @@
import os import os
import sys import sys
import argparse import argparse
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(', ')
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
', ')
CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))} CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
if __name__=='__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='prepare kaldi format file for google speech command dataset ') parser = argparse.ArgumentParser(
parser.add_argument('--wav_list', required=True, help='wave list is a file containts full path of a wav file in google speech command dataset') description=
parser.add_argument('--data_dir', required=True, help='folder to write kaldi format files') 'prepare kaldi format file for google speech command dataset ')
parser.add_argument(
'--wav_list',
required=True,
help=
'wave list is a file containts full path of a wav file in google speech command dataset'
)
parser.add_argument('--data_dir',
required=True,
help='folder to write kaldi format files')
args = parser.parse_args() args = parser.parse_args()
data_dir = args.data_dir data_dir = args.data_dir
f_wav_scp = open(os.path.join(data_dir,'wav.scp'), 'w') f_wav_scp = open(os.path.join(data_dir, 'wav.scp'), 'w')
f_text = open(os.path.join(data_dir, 'text'), 'w') f_text = open(os.path.join(data_dir, 'text'), 'w')
with open(args.wav_list) as f: with open(args.wav_list) as f:
for line in f.readlines(): for line in f.readlines():
@ -21,9 +32,8 @@ if __name__=='__main__':
wav_id = '_'.join([keyword, file_name_new]) wav_id = '_'.join([keyword, file_name_new])
file_dir = line.strip() file_dir = line.strip()
f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n') f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n')
label = CLASS_TO_IDX[keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"] label = CLASS_TO_IDX[
keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
f_text.writelines(wav_id + ' ' + str(label) + '\n') f_text.writelines(wav_id + ' ' + str(label) + '\n')
f_wav_scp.close() f_wav_scp.close()
f_text.close() f_text.close()

View File

@ -4,6 +4,7 @@ import os
import shutil import shutil
import argparse import argparse
def move_files(src_folder, to_folder, list_file): def move_files(src_folder, to_folder, list_file):
with open(list_file) as f: with open(list_file) as f:
for line in f.readlines(): for line in f.readlines():
@ -12,14 +13,18 @@ def move_files(src_folder, to_folder, list_file):
dest = os.path.join(to_folder, dirname) dest = os.path.join(to_folder, dirname)
if not os.path.exists(dest): if not os.path.exists(dest):
os.mkdir(dest) os.mkdir(dest)
shutil.move(os.path.join(src_folder, line),dest) shutil.move(os.path.join(src_folder, line), dest)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Split google command dataset.') parser = argparse.ArgumentParser(
parser.add_argument('root', type=str, help='the path to the root folder of the google commands dataset') description='Split google command dataset.')
parser.add_argument(
'root',
type=str,
help='the path to the root folder of the google commands dataset')
args = parser.parse_args() args = parser.parse_args()
audio_folder = os.path.join(args.root, 'audio') audio_folder = os.path.join(args.root, 'audio')
validation_path = os.path.join(audio_folder, 'validation_list.txt') validation_path = os.path.join(audio_folder, 'validation_list.txt')
test_path = os.path.join(audio_folder, 'testing_list.txt') test_path = os.path.join(audio_folder, 'testing_list.txt')
@ -34,4 +39,3 @@ if __name__ == '__main__':
move_files(audio_folder, test_folder, test_path) move_files(audio_folder, test_folder, test_path)
move_files(audio_folder, valid_folder, validation_path) move_files(audio_folder, valid_folder, validation_path)
os.rename(audio_folder, train_folder) os.rename(audio_folder, train_folder)

View File

@ -15,15 +15,22 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
def acc_frame(logits: torch.Tensor, target: torch.Tensor, ):
def acc_frame(
logits: torch.Tensor,
target: torch.Tensor,
):
if logits is None: if logits is None:
return 0 return 0
pred = logits.max(1, keepdim=True)[1] pred = logits.max(1, keepdim=True)[1]
correct = pred.eq(target.long().view_as(pred)).sum().item() correct = pred.eq(target.long().view_as(pred)).sum().item()
return correct*100.0/logits.size(0) return correct * 100.0 / logits.size(0)
def cross_entropy(logits: torch.Tensor, target: torch.Tensor,
lengths: torch.Tensor, min_duration: int = 0): def cross_entropy(logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
""" Cross Entropy Loss """ Cross Entropy Loss
Attributes: Attributes:
logits: (B, D), D is the number of keywords plus 1 (non-keyword) logits: (B, D), D is the number of keywords plus 1 (non-keyword)
@ -37,4 +44,4 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor,
cross_entropy = nn.CrossEntropyLoss() cross_entropy = nn.CrossEntropyLoss()
loss = cross_entropy(logits, target) loss = cross_entropy(logits, target)
acc = acc_frame(logits, target) acc = acc_frame(logits, target)
return loss, acc return loss, acc

View File

@ -1,12 +1,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
class GlobalClassifier(nn.Module): class GlobalClassifier(nn.Module):
"""Add a global average pooling before the classifier""" """Add a global average pooling before the classifier"""
def __init__( def __init__(self, classifier: nn.Module):
self,
classifier: nn.Module
):
super(GlobalClassifier, self).__init__() super(GlobalClassifier, self).__init__()
self.classifier = classifier self.classifier = classifier
@ -14,15 +12,13 @@ class GlobalClassifier(nn.Module):
x = torch.mean(x, dim=1) x = torch.mean(x, dim=1)
return self.classifier(x) return self.classifier(x)
class LastClassifier(nn.Module): class LastClassifier(nn.Module):
"""Select last frame to do the classification""" """Select last frame to do the classification"""
def __init__( def __init__(self, classifier: nn.Module):
self,
classifier: nn.Module
):
super(LastClassifier, self).__init__() super(LastClassifier, self).__init__()
self.classifier = classifier self.classifier = classifier
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
x = x[:, -1, :] x = x[:, -1, :]
return self.classifier(x) return self.classifier(x)

View File

@ -24,6 +24,7 @@ from kws.model.mdtc import MDTC
from kws.model.classifier import GlobalClassifier, LastClassifier from kws.model.classifier import GlobalClassifier, LastClassifier
from kws.utils.cmvn import load_cmvn from kws.utils.cmvn import load_cmvn
class KWSModel(torch.nn.Module): class KWSModel(torch.nn.Module):
"""Our model consists of four parts: """Our model consists of four parts:
1. global_cmvn: Optional, (idim, idim) 1. global_cmvn: Optional, (idim, idim)
@ -31,16 +32,10 @@ class KWSModel(torch.nn.Module):
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim) 3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
4. classifier: output layer or classifier of KWS model, (hdim, odim) 4. classifier: output layer or classifier of KWS model, (hdim, odim)
""" """
def __init__( def __init__(self, idim: int, odim: int, hdim: int,
self, global_cmvn: Optional[torch.nn.Module],
idim: int, preprocessing: Optional[torch.nn.Module],
odim: int, backbone: torch.nn.Module, classifier: torch.nn.Module):
hdim: int,
global_cmvn: Optional[torch.nn.Module],
preprocessing: Optional[torch.nn.Module],
backbone: torch.nn.Module,
classifier: torch.nn.Module
):
super().__init__() super().__init__()
self.idim = idim self.idim = idim
self.odim = odim self.odim = odim

View File

@ -39,7 +39,7 @@ def RHE(indice: torch.Tensor, k: int):
reserve.append(indice[i]) reserve.append(indice[i])
rm_s = max(indice[i] - k, 0) rm_s = max(indice[i] - k, 0)
rm_e = min(indice[i] + k, lenght) rm_e = min(indice[i] + k, lenght)
available_indice[rm_s : rm_e + 1] = 0 available_indice[rm_s:rm_e + 1] = 0
else: else:
continue continue
@ -48,7 +48,9 @@ def RHE(indice: torch.Tensor, k: int):
return torch.tensor(reserve).long() return torch.tensor(reserve).long()
def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float = 10): def downsample_training_sample_and_calculate_loss(logits,
targets,
ratio: float = 10):
num_training = 0 num_training = 0
loss = 0 loss = 0
for i in range(len(logits)): for i in range(len(logits)):
@ -69,8 +71,11 @@ def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float
return loss / num_training return loss / num_training
def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio=1): def max_pooling_RHE_binary_CE(logits,
targets,
lengths,
RHE_thr=10000,
max_ratio=1):
"""Max-pooling loss with regional hard example mining """Max-pooling loss with regional hard example mining
For each keyword utterance, select the frame with the highest posterior. For each keyword utterance, select the frame with the highest posterior.
The keyword is triggered when any of the frames is triggered. The keyword is triggered when any of the frames is triggered.
@ -108,7 +113,8 @@ def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio
if logits[i, max_idx, j] >= 0.5: if logits[i, max_idx, j] >= 0.5:
num_hit += 1 num_hit += 1
else: else:
sorted_logits, sorted_index = torch.sort(logits[i, :end_idx], dim=0) sorted_logits, sorted_index = torch.sort(logits[i, :end_idx],
dim=0)
reversed_index = torch.flip(sorted_index, dims=[0]) reversed_index = torch.flip(sorted_index, dims=[0])
selected_indexes = RHE(reversed_index[:, j], RHE_thr) selected_indexes = RHE(reversed_index[:, j], RHE_thr)
new_logits[j].append(logits[i, selected_indexes, j]) new_logits[j].append(logits[i, selected_indexes, j])