formatting

This commit is contained in:
Jingyong Hou 2021-11-12 11:02:00 +08:00
parent e0f6e1d5ed
commit 8d6aee9307
6 changed files with 61 additions and 43 deletions

View File

@ -2,17 +2,28 @@
import os
import sys
import argparse
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(', ')
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
', ')
CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))}
if __name__=='__main__':
parser = argparse.ArgumentParser(description='prepare kaldi format file for google speech command dataset ')
parser.add_argument('--wav_list', required=True, help='wave list is a file containts full path of a wav file in google speech command dataset')
parser.add_argument('--data_dir', required=True, help='folder to write kaldi format files')
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description=
'prepare kaldi format file for google speech command dataset ')
parser.add_argument(
'--wav_list',
required=True,
help=
'wave list is a file containts full path of a wav file in google speech command dataset'
)
parser.add_argument('--data_dir',
required=True,
help='folder to write kaldi format files')
args = parser.parse_args()
data_dir = args.data_dir
f_wav_scp = open(os.path.join(data_dir,'wav.scp'), 'w')
f_wav_scp = open(os.path.join(data_dir, 'wav.scp'), 'w')
f_text = open(os.path.join(data_dir, 'text'), 'w')
with open(args.wav_list) as f:
for line in f.readlines():
@ -21,9 +32,8 @@ if __name__=='__main__':
wav_id = '_'.join([keyword, file_name_new])
file_dir = line.strip()
f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n')
label = CLASS_TO_IDX[keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
label = CLASS_TO_IDX[
keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
f_text.writelines(wav_id + ' ' + str(label) + '\n')
f_wav_scp.close()
f_text.close()

View File

@ -4,6 +4,7 @@ import os
import shutil
import argparse
def move_files(src_folder, to_folder, list_file):
with open(list_file) as f:
for line in f.readlines():
@ -12,14 +13,18 @@ def move_files(src_folder, to_folder, list_file):
dest = os.path.join(to_folder, dirname)
if not os.path.exists(dest):
os.mkdir(dest)
shutil.move(os.path.join(src_folder, line),dest)
shutil.move(os.path.join(src_folder, line), dest)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Split google command dataset.')
parser.add_argument('root', type=str, help='the path to the root folder of the google commands dataset')
parser = argparse.ArgumentParser(
description='Split google command dataset.')
parser.add_argument(
'root',
type=str,
help='the path to the root folder of the google commands dataset')
args = parser.parse_args()
audio_folder = os.path.join(args.root, 'audio')
validation_path = os.path.join(audio_folder, 'validation_list.txt')
test_path = os.path.join(audio_folder, 'testing_list.txt')
@ -34,4 +39,3 @@ if __name__ == '__main__':
move_files(audio_folder, test_folder, test_path)
move_files(audio_folder, valid_folder, validation_path)
os.rename(audio_folder, train_folder)

View File

@ -15,15 +15,22 @@
import torch
import torch.nn as nn
def acc_frame(logits: torch.Tensor, target: torch.Tensor, ):
def acc_frame(
logits: torch.Tensor,
target: torch.Tensor,
):
if logits is None:
return 0
pred = logits.max(1, keepdim=True)[1]
correct = pred.eq(target.long().view_as(pred)).sum().item()
return correct*100.0/logits.size(0)
return correct * 100.0 / logits.size(0)
def cross_entropy(logits: torch.Tensor, target: torch.Tensor,
lengths: torch.Tensor, min_duration: int = 0):
def cross_entropy(logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
""" Cross Entropy Loss
Attributes:
logits: (B, D), D is the number of keywords plus 1 (non-keyword)

View File

@ -1,12 +1,10 @@
import torch
import torch.nn as nn
class GlobalClassifier(nn.Module):
"""Add a global average pooling before the classifier"""
def __init__(
self,
classifier: nn.Module
):
def __init__(self, classifier: nn.Module):
super(GlobalClassifier, self).__init__()
self.classifier = classifier
@ -14,12 +12,10 @@ class GlobalClassifier(nn.Module):
x = torch.mean(x, dim=1)
return self.classifier(x)
class LastClassifier(nn.Module):
"""Select last frame to do the classification"""
def __init__(
self,
classifier: nn.Module
):
def __init__(self, classifier: nn.Module):
super(LastClassifier, self).__init__()
self.classifier = classifier

View File

@ -24,6 +24,7 @@ from kws.model.mdtc import MDTC
from kws.model.classifier import GlobalClassifier, LastClassifier
from kws.utils.cmvn import load_cmvn
class KWSModel(torch.nn.Module):
"""Our model consists of four parts:
1. global_cmvn: Optional, (idim, idim)
@ -31,16 +32,10 @@ class KWSModel(torch.nn.Module):
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
4. classifier: output layer or classifier of KWS model, (hdim, odim)
"""
def __init__(
self,
idim: int,
odim: int,
hdim: int,
global_cmvn: Optional[torch.nn.Module],
preprocessing: Optional[torch.nn.Module],
backbone: torch.nn.Module,
classifier: torch.nn.Module
):
def __init__(self, idim: int, odim: int, hdim: int,
global_cmvn: Optional[torch.nn.Module],
preprocessing: Optional[torch.nn.Module],
backbone: torch.nn.Module, classifier: torch.nn.Module):
super().__init__()
self.idim = idim
self.odim = odim

View File

@ -39,7 +39,7 @@ def RHE(indice: torch.Tensor, k: int):
reserve.append(indice[i])
rm_s = max(indice[i] - k, 0)
rm_e = min(indice[i] + k, lenght)
available_indice[rm_s : rm_e + 1] = 0
available_indice[rm_s:rm_e + 1] = 0
else:
continue
@ -48,7 +48,9 @@ def RHE(indice: torch.Tensor, k: int):
return torch.tensor(reserve).long()
def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float = 10):
def downsample_training_sample_and_calculate_loss(logits,
targets,
ratio: float = 10):
num_training = 0
loss = 0
for i in range(len(logits)):
@ -69,8 +71,11 @@ def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float
return loss / num_training
def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio=1):
def max_pooling_RHE_binary_CE(logits,
targets,
lengths,
RHE_thr=10000,
max_ratio=1):
"""Max-pooling loss with regional hard example mining
For each keyword utterance, select the frame with the highest posterior.
The keyword is triggered when any of the frames is triggered.
@ -108,7 +113,8 @@ def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio
if logits[i, max_idx, j] >= 0.5:
num_hit += 1
else:
sorted_logits, sorted_index = torch.sort(logits[i, :end_idx], dim=0)
sorted_logits, sorted_index = torch.sort(logits[i, :end_idx],
dim=0)
reversed_index = torch.flip(sorted_index, dims=[0])
selected_indexes = RHE(reversed_index[:, j], RHE_thr)
new_logits[j].append(logits[i, selected_indexes, j])