formatting
This commit is contained in:
parent
e0f6e1d5ed
commit
8d6aee9307
@ -2,17 +2,28 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
# Ordered label vocabulary; a label's position in this list is its class index.
CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(
    ', ')
# Maps each label to its index rendered as a string, e.g. 'yes' -> '1'.
CLASS_TO_IDX = {label: str(index) for index, label in enumerate(CLASSES)}
|
||||
|
||||
if __name__=='__main__':
|
||||
parser = argparse.ArgumentParser(description='prepare kaldi format file for google speech command dataset ')
|
||||
parser.add_argument('--wav_list', required=True, help='wave list is a file containts full path of a wav file in google speech command dataset')
|
||||
parser.add_argument('--data_dir', required=True, help='folder to write kaldi format files')
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description=
|
||||
'prepare kaldi format file for google speech command dataset ')
|
||||
parser.add_argument(
|
||||
'--wav_list',
|
||||
required=True,
|
||||
help=
|
||||
'wave list is a file containts full path of a wav file in google speech command dataset'
|
||||
)
|
||||
parser.add_argument('--data_dir',
|
||||
required=True,
|
||||
help='folder to write kaldi format files')
|
||||
args = parser.parse_args()
|
||||
|
||||
data_dir = args.data_dir
|
||||
f_wav_scp = open(os.path.join(data_dir,'wav.scp'), 'w')
|
||||
f_wav_scp = open(os.path.join(data_dir, 'wav.scp'), 'w')
|
||||
f_text = open(os.path.join(data_dir, 'text'), 'w')
|
||||
with open(args.wav_list) as f:
|
||||
for line in f.readlines():
|
||||
@ -21,9 +32,8 @@ if __name__=='__main__':
|
||||
wav_id = '_'.join([keyword, file_name_new])
|
||||
file_dir = line.strip()
|
||||
f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n')
|
||||
label = CLASS_TO_IDX[keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
|
||||
label = CLASS_TO_IDX[
|
||||
keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"]
|
||||
f_text.writelines(wav_id + ' ' + str(label) + '\n')
|
||||
f_wav_scp.close()
|
||||
f_text.close()
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import os
|
||||
import shutil
|
||||
import argparse
|
||||
|
||||
|
||||
def move_files(src_folder, to_folder, list_file):
|
||||
with open(list_file) as f:
|
||||
for line in f.readlines():
|
||||
@ -12,14 +13,18 @@ def move_files(src_folder, to_folder, list_file):
|
||||
dest = os.path.join(to_folder, dirname)
|
||||
if not os.path.exists(dest):
|
||||
os.mkdir(dest)
|
||||
shutil.move(os.path.join(src_folder, line),dest)
|
||||
shutil.move(os.path.join(src_folder, line), dest)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Split google command dataset.')
|
||||
parser.add_argument('root', type=str, help='the path to the root folder of the google commands dataset')
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Split google command dataset.')
|
||||
parser.add_argument(
|
||||
'root',
|
||||
type=str,
|
||||
help='the path to the root folder of the google commands dataset')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
audio_folder = os.path.join(args.root, 'audio')
|
||||
validation_path = os.path.join(audio_folder, 'validation_list.txt')
|
||||
test_path = os.path.join(audio_folder, 'testing_list.txt')
|
||||
@ -34,4 +39,3 @@ if __name__ == '__main__':
|
||||
move_files(audio_folder, test_folder, test_path)
|
||||
move_files(audio_folder, valid_folder, validation_path)
|
||||
os.rename(audio_folder, train_folder)
|
||||
|
||||
|
||||
@ -15,15 +15,22 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def acc_frame(
    logits: torch.Tensor,
    target: torch.Tensor,
):
    """Return the percentage of rows in ``logits`` whose argmax matches ``target``.

    Args:
        logits: Score tensor. The predicted class is the argmax over dim 1,
            and the result is normalised by ``logits.size(0)``. May be
            ``None``.
        target: Labels. They are cast to long and reshaped to the shape of
            the prediction tensor, so there must be exactly one label per
            prediction.

    Returns:
        ``correct * 100.0 / logits.size(0)``, or ``0`` when ``logits`` is
        ``None`` or has no rows.
    """
    # Nothing to score. Without the emptiness check an empty batch would
    # raise ZeroDivisionError in the final division.
    if logits is None or logits.size(0) == 0:
        return 0
    # max(...)[1] holds the indices of the maxima along dim 1; keepdim keeps
    # the reduced dimension so ``target`` can be reshaped to match.
    pred = logits.max(1, keepdim=True)[1]
    correct = pred.eq(target.long().view_as(pred)).sum().item()
    return correct * 100.0 / logits.size(0)
|
||||
|
||||
def cross_entropy(logits: torch.Tensor, target: torch.Tensor,
|
||||
lengths: torch.Tensor, min_duration: int = 0):
|
||||
|
||||
def cross_entropy(logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
min_duration: int = 0):
|
||||
""" Cross Entropy Loss
|
||||
Attributes:
|
||||
logits: (B, D), D is the number of keywords plus 1 (non-keyword)
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class GlobalClassifier(nn.Module):
|
||||
"""Add a global average pooling before the classifier"""
|
||||
def __init__(
|
||||
self,
|
||||
classifier: nn.Module
|
||||
):
|
||||
def __init__(self, classifier: nn.Module):
|
||||
super(GlobalClassifier, self).__init__()
|
||||
self.classifier = classifier
|
||||
|
||||
@ -14,12 +12,10 @@ class GlobalClassifier(nn.Module):
|
||||
x = torch.mean(x, dim=1)
|
||||
return self.classifier(x)
|
||||
|
||||
|
||||
class LastClassifier(nn.Module):
|
||||
"""Select last frame to do the classification"""
|
||||
def __init__(
|
||||
self,
|
||||
classifier: nn.Module
|
||||
):
|
||||
def __init__(self, classifier: nn.Module):
|
||||
super(LastClassifier, self).__init__()
|
||||
self.classifier = classifier
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ from kws.model.mdtc import MDTC
|
||||
from kws.model.classifier import GlobalClassifier, LastClassifier
|
||||
from kws.utils.cmvn import load_cmvn
|
||||
|
||||
|
||||
class KWSModel(torch.nn.Module):
|
||||
"""Our model consists of four parts:
|
||||
1. global_cmvn: Optional, (idim, idim)
|
||||
@ -31,16 +32,10 @@ class KWSModel(torch.nn.Module):
|
||||
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
|
||||
4. classifier: output layer or classifier of KWS model, (hdim, odim)
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
idim: int,
|
||||
odim: int,
|
||||
hdim: int,
|
||||
global_cmvn: Optional[torch.nn.Module],
|
||||
preprocessing: Optional[torch.nn.Module],
|
||||
backbone: torch.nn.Module,
|
||||
classifier: torch.nn.Module
|
||||
):
|
||||
def __init__(self, idim: int, odim: int, hdim: int,
|
||||
global_cmvn: Optional[torch.nn.Module],
|
||||
preprocessing: Optional[torch.nn.Module],
|
||||
backbone: torch.nn.Module, classifier: torch.nn.Module):
|
||||
super().__init__()
|
||||
self.idim = idim
|
||||
self.odim = odim
|
||||
|
||||
@ -39,7 +39,7 @@ def RHE(indice: torch.Tensor, k: int):
|
||||
reserve.append(indice[i])
|
||||
rm_s = max(indice[i] - k, 0)
|
||||
rm_e = min(indice[i] + k, lenght)
|
||||
available_indice[rm_s : rm_e + 1] = 0
|
||||
available_indice[rm_s:rm_e + 1] = 0
|
||||
else:
|
||||
continue
|
||||
|
||||
@ -48,7 +48,9 @@ def RHE(indice: torch.Tensor, k: int):
|
||||
return torch.tensor(reserve).long()
|
||||
|
||||
|
||||
def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float = 10):
|
||||
def downsample_training_sample_and_calculate_loss(logits,
|
||||
targets,
|
||||
ratio: float = 10):
|
||||
num_training = 0
|
||||
loss = 0
|
||||
for i in range(len(logits)):
|
||||
@ -69,8 +71,11 @@ def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float
|
||||
return loss / num_training
|
||||
|
||||
|
||||
def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio=1):
|
||||
|
||||
def max_pooling_RHE_binary_CE(logits,
|
||||
targets,
|
||||
lengths,
|
||||
RHE_thr=10000,
|
||||
max_ratio=1):
|
||||
"""Max-pooling loss with regional hard example mining
|
||||
For each keyword utterance, select the frame with the highest posterior.
|
||||
The keyword is triggered when any of the frames is triggered.
|
||||
@ -108,7 +113,8 @@ def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio
|
||||
if logits[i, max_idx, j] >= 0.5:
|
||||
num_hit += 1
|
||||
else:
|
||||
sorted_logits, sorted_index = torch.sort(logits[i, :end_idx], dim=0)
|
||||
sorted_logits, sorted_index = torch.sort(logits[i, :end_idx],
|
||||
dim=0)
|
||||
reversed_index = torch.flip(sorted_index, dims=[0])
|
||||
selected_indexes = RHE(reversed_index[:, j], RHE_thr)
|
||||
new_logits[j].append(logits[i, selected_indexes, j])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user