diff --git a/examples/speechcommand_v1/s0/local/prepare_speech_command.py b/examples/speechcommand_v1/s0/local/prepare_speech_command.py index 0cad173..f7898f3 100644 --- a/examples/speechcommand_v1/s0/local/prepare_speech_command.py +++ b/examples/speechcommand_v1/s0/local/prepare_speech_command.py @@ -2,17 +2,28 @@ import os import sys import argparse -CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split(', ') + +CLASSES = 'unknown, yes, no, up, down, left, right, on, off, stop, go'.split( + ', ') CLASS_TO_IDX = {CLASSES[i]: str(i) for i in range(len(CLASSES))} -if __name__=='__main__': - parser = argparse.ArgumentParser(description='prepare kaldi format file for google speech command dataset ') - parser.add_argument('--wav_list', required=True, help='wave list is a file containts full path of a wav file in google speech command dataset') - parser.add_argument('--data_dir', required=True, help='folder to write kaldi format files') +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description= + 'prepare kaldi format file for google speech command dataset ') + parser.add_argument( + '--wav_list', + required=True, + help= + 'wave list is a file containts full path of a wav file in google speech command dataset' + ) + parser.add_argument('--data_dir', + required=True, + help='folder to write kaldi format files') args = parser.parse_args() data_dir = args.data_dir - f_wav_scp = open(os.path.join(data_dir,'wav.scp'), 'w') + f_wav_scp = open(os.path.join(data_dir, 'wav.scp'), 'w') f_text = open(os.path.join(data_dir, 'text'), 'w') with open(args.wav_list) as f: for line in f.readlines(): @@ -21,9 +32,8 @@ if __name__=='__main__': wav_id = '_'.join([keyword, file_name_new]) file_dir = line.strip() f_wav_scp.writelines(wav_id + ' ' + file_dir + '\n') - label = CLASS_TO_IDX[keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"] + label = CLASS_TO_IDX[ + keyword] if keyword in CLASS_TO_IDX else CLASS_TO_IDX["unknown"] f_text.writelines(wav_id + ' ' + str(label) + '\n') f_wav_scp.close() f_text.close() - - diff --git a/examples/speechcommand_v1/s0/local/split_dataset.py b/examples/speechcommand_v1/s0/local/split_dataset.py index 6fce2d2..8d625db 100755 --- a/examples/speechcommand_v1/s0/local/split_dataset.py +++ b/examples/speechcommand_v1/s0/local/split_dataset.py @@ -4,6 +4,7 @@ import os import shutil import argparse + def move_files(src_folder, to_folder, list_file): with open(list_file) as f: for line in f.readlines(): @@ -12,14 +13,18 @@ def move_files(src_folder, to_folder, list_file): dest = os.path.join(to_folder, dirname) if not os.path.exists(dest): os.mkdir(dest) - shutil.move(os.path.join(src_folder, line),dest) + shutil.move(os.path.join(src_folder, line), dest) + if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Split google command dataset.') - parser.add_argument('root', type=str, help='the path to the root folder of the google commands dataset') + parser = argparse.ArgumentParser( + description='Split google command dataset.') + parser.add_argument( + 'root', + type=str, + help='the path to the root folder of the google commands dataset') args = parser.parse_args() - audio_folder = os.path.join(args.root, 'audio') validation_path = os.path.join(audio_folder, 'validation_list.txt') test_path = os.path.join(audio_folder, 'testing_list.txt') @@ -34,4 +39,3 @@ if __name__ == '__main__': move_files(audio_folder, test_folder, test_path) move_files(audio_folder, valid_folder, validation_path) os.rename(audio_folder, train_folder) - diff --git a/kws/model/ce.py b/kws/model/ce.py index 3ce0252..10668d0 100644 --- a/kws/model/ce.py +++ b/kws/model/ce.py @@ -15,15 +15,22 @@ import torch import torch.nn as nn -def acc_frame(logits: torch.Tensor, target: torch.Tensor, ): + +def acc_frame( + logits: torch.Tensor, + target: torch.Tensor, +): if logits is None: return 0 pred = logits.max(1, keepdim=True)[1] correct = pred.eq(target.long().view_as(pred)).sum().item() - return correct*100.0/logits.size(0) + return correct * 100.0 / logits.size(0) -def cross_entropy(logits: torch.Tensor, target: torch.Tensor, - lengths: torch.Tensor, min_duration: int = 0): + +def cross_entropy(logits: torch.Tensor, + target: torch.Tensor, + lengths: torch.Tensor, + min_duration: int = 0): """ Cross Entropy Loss Attributes: logits: (B, D), D is the number of keywords plus 1 (non-keyword) @@ -37,4 +44,4 @@ def cross_entropy(logits: torch.Tensor, target: torch.Tensor, cross_entropy = nn.CrossEntropyLoss() loss = cross_entropy(logits, target) acc = acc_frame(logits, target) - return loss, acc \ No newline at end of file + return loss, acc diff --git a/kws/model/classifier.py b/kws/model/classifier.py index 732c42d..cf81ab5 100644 --- a/kws/model/classifier.py +++ b/kws/model/classifier.py @@ -1,12 +1,10 @@ import torch import torch.nn as nn + class GlobalClassifier(nn.Module): """Add a global average pooling before the classifier""" - def __init__( - self, - classifier: nn.Module - ): + def __init__(self, classifier: nn.Module): super(GlobalClassifier, self).__init__() self.classifier = classifier @@ -14,15 +12,13 @@ class GlobalClassifier(nn.Module): x = torch.mean(x, dim=1) return self.classifier(x) + class LastClassifier(nn.Module): """Select last frame to do the classification""" - def __init__( - self, - classifier: nn.Module - ): + def __init__(self, classifier: nn.Module): super(LastClassifier, self).__init__() self.classifier = classifier def forward(self, x: torch.Tensor): x = x[:, -1, :] - return self.classifier(x) \ No newline at end of file + return self.classifier(x) diff --git a/kws/model/kws_model.py b/kws/model/kws_model.py index d5a1840..15fb08e 100644 --- a/kws/model/kws_model.py +++ b/kws/model/kws_model.py @@ -24,6 +24,7 @@ from kws.model.mdtc import MDTC from kws.model.classifier import GlobalClassifier, LastClassifier from kws.utils.cmvn import load_cmvn + class KWSModel(torch.nn.Module): """Our model consists of four parts: 1. global_cmvn: Optional, (idim, idim) @@ -31,16 +32,10 @@ class KWSModel(torch.nn.Module): 3. backbone: backbone or feature extractor of the whole network, (hdim, hdim) 4. classifier: output layer or classifier of KWS model, (hdim, odim) """ - def __init__( - self, - idim: int, - odim: int, - hdim: int, - global_cmvn: Optional[torch.nn.Module], - preprocessing: Optional[torch.nn.Module], - backbone: torch.nn.Module, - classifier: torch.nn.Module - ): + def __init__(self, idim: int, odim: int, hdim: int, + global_cmvn: Optional[torch.nn.Module], + preprocessing: Optional[torch.nn.Module], + backbone: torch.nn.Module, classifier: torch.nn.Module): super().__init__() self.idim = idim self.odim = odim diff --git a/kws/model/max_pooling_RHE.py b/kws/model/max_pooling_RHE.py index 17a0b58..7963ec1 100644 --- a/kws/model/max_pooling_RHE.py +++ b/kws/model/max_pooling_RHE.py @@ -39,7 +39,7 @@ def RHE(indice: torch.Tensor, k: int): reserve.append(indice[i]) rm_s = max(indice[i] - k, 0) rm_e = min(indice[i] + k, lenght) - available_indice[rm_s : rm_e + 1] = 0 + available_indice[rm_s:rm_e + 1] = 0 else: continue @@ -48,7 +48,9 @@ def RHE(indice: torch.Tensor, k: int): return torch.tensor(reserve).long() -def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float = 10): +def downsample_training_sample_and_calculate_loss(logits, + targets, + ratio: float = 10): num_training = 0 loss = 0 for i in range(len(logits)): @@ -69,8 +71,11 @@ def downsample_training_sample_and_calculate_loss(logits, targets, ratio: float return loss / num_training -def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio=1): - +def max_pooling_RHE_binary_CE(logits, + targets, + lengths, + RHE_thr=10000, + max_ratio=1): """Max-pooling loss with regional hard example mining For each keyword utterance, select the frame with the highest posterior. The keyword is triggered when any of the frames is triggered. @@ -108,7 +113,8 @@ def max_pooling_RHE_binary_CE(logits, targets, lengths, RHE_thr=10000, max_ratio if logits[i, max_idx, j] >= 0.5: num_hit += 1 else: - sorted_logits, sorted_index = torch.sort(logits[i, :end_idx], dim=0) + sorted_logits, sorted_index = torch.sort(logits[i, :end_idx], + dim=0) reversed_index = torch.flip(sorted_index, dims=[0]) selected_indexes = RHE(reversed_index[:, j], RHE_thr) new_logits[j].append(logits[i, selected_indexes, j])