diff --git a/examples/hey_snips/s0/conf/ds_tcn.yaml b/examples/hey_snips/s0/conf/ds_tcn.yaml index 6ef4407..73487ee 100644 --- a/examples/hey_snips/s0/conf/ds_tcn.yaml +++ b/examples/hey_snips/s0/conf/ds_tcn.yaml @@ -5,6 +5,8 @@ dataset_conf: resample_conf: resample_rate: 16000 speed_perturb: false + reverb_prob: 0.2 + noise_prob: 0.3 feature_extraction_conf: feature_type: 'fbank' num_mel_bins: 40 diff --git a/examples/hey_snips/s0/run.sh b/examples/hey_snips/s0/run.sh index 28e4fc5..1194267 100755 --- a/examples/hey_snips/s0/run.sh +++ b/examples/hey_snips/s0/run.sh @@ -20,6 +20,8 @@ num_average=30 score_checkpoint=$dir/avg_${num_average}.pt download_dir=./data/local # your data dir +noise_lmdb= +reverb_lmdb= . tools/parse_options.sh || exit 1; @@ -78,6 +80,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then --min_duration 50 \ --seed 777 \ $cmvn_opts \ + ${reverb_lmdb:+--reverb_lmdb $reverb_lmdb} \ + ${noise_lmdb:+--noise_lmdb $noise_lmdb} \ ${checkpoint:+--checkpoint $checkpoint} fi diff --git a/requirements.txt b/requirements.txt index 2585d10..fd9745e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,6 @@ flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0 +lmdb +scipy +tqdm \ No newline at end of file diff --git a/tools/make_lmdb.py b/tools/make_lmdb.py new file mode 100644 index 0000000..2b1ce38 --- /dev/null +++ b/tools/make_lmdb.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import math +import pickle + +import lmdb +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser(description='') + parser.add_argument('in_scp_file', help='input scp file') + parser.add_argument('out_lmdb', help='output lmdb') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + db = lmdb.open(args.out_lmdb, map_size=int(math.pow(1024, 4))) # 1TB + # txn is for Transaciton + txn = db.begin(write=True) + keys = [] + with open(args.in_scp_file, 'r', encoding='utf8') as fin: + lines = fin.readlines() + for i, line in enumerate(tqdm(lines)): + arr = line.strip().split() + assert len(arr) == 2 + key, wav = arr[0], arr[1] + keys.append(key) + with open(wav, 'rb') as fin: + data = fin.read() + txn.put(key.encode(), data) + # Write flush to disk + if i % 100 == 0: + txn.commit() + txn = db.begin(write=True) + txn.commit() + with db.begin(write=True) as txn: + txn.put(b'__keys__', pickle.dumps(keys)) + db.sync() + db.close() + + +if __name__ == '__main__': + main() diff --git a/wekws/bin/train.py b/wekws/bin/train.py index d3f57fc..632a240 100644 --- a/wekws/bin/train.py +++ b/wekws/bin/train.py @@ -77,6 +77,12 @@ def get_args(): default=100, type=int, help='prefetch number') + parser.add_argument('--reverb_lmdb', + default=None, + help='reverb lmdb file') + parser.add_argument('--noise_lmdb', + default=None, + help='noise lmdb file') args = parser.parse_args() return args @@ -106,7 +112,10 @@ def main(): cv_conf['spec_aug'] = False cv_conf['shuffle'] = False - train_dataset = Dataset(args.train_data, train_conf) + train_dataset = Dataset(args.train_data, + train_conf, + reverb_lmdb=args.reverb_lmdb, + noise_lmdb=args.noise_lmdb) cv_dataset = Dataset(args.cv_data, cv_conf) train_data_loader = DataLoader(train_dataset, diff --git a/wekws/dataset/dataset.py b/wekws/dataset/dataset.py index f7d12c6..74aaf4d 100644 --- a/wekws/dataset/dataset.py +++ b/wekws/dataset/dataset.py @@ -20,6 +20,7 @@ from torch.utils.data import IterableDataset import wekws.dataset.processor as processor from wekws.utils.file_utils import read_lists +from wekws.dataset.lmdb_data import LmdbData class Processor(IterableDataset): @@ -112,7 +113,10 @@ class DataList(IterableDataset): yield data -def Dataset(data_list_file, conf, partition=True): +def Dataset(data_list_file, conf, + partition=True, + reverb_lmdb=None, + noise_lmdb=None): """ Construct dataset from arguments We have two shuffle stage in the Dataset. The first is global @@ -122,6 +126,8 @@ def Dataset(data_list_file, conf, partition=True): Args: data_type(str): raw/shard partition(bool): whether to do data partition in terms of rank + reverb_lmdb: reverb data source lmdb file + noise_lmdb: noise data source lmdb file """ lists = read_lists(data_list_file) shuffle = conf.get('shuffle', True) @@ -136,6 +142,14 @@ def Dataset(data_list_file, conf, partition=True): speed_perturb = conf.get('speed_perturb', False) if speed_perturb: dataset = Processor(dataset, processor.speed_perturb) + if reverb_lmdb and conf.get('reverb_prob', 0) > 0: + reverb_data = LmdbData(reverb_lmdb) + dataset = Processor(dataset, processor.add_reverb, + reverb_data, conf['reverb_prob']) + if noise_lmdb and conf.get('noise_prob', 0) > 0: + noise_data = LmdbData(noise_lmdb) + dataset = Processor(dataset, processor.add_noise, + noise_data, conf['noise_prob']) feature_extraction_conf = conf.get('feature_extraction_conf', {}) if feature_extraction_conf['feature_type'] == 'mfcc': dataset = Processor(dataset, processor.compute_mfcc, diff --git a/wekws/dataset/lmdb_data.py b/wekws/dataset/lmdb_data.py new file mode 100644 index 0000000..7c5757c --- /dev/null +++ b/wekws/dataset/lmdb_data.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import pickle + +import lmdb + + +class LmdbData: + + def __init__(self, lmdb_file): + self.db = lmdb.open(lmdb_file, + readonly=True, + lock=False, + readahead=False) + with self.db.begin(write=False) as txn: + obj = txn.get(b'__keys__') + assert obj is not None + self.keys = pickle.loads(obj) + assert isinstance(self.keys, list) + + def random_one(self): + assert len(self.keys) > 0 + index = random.randint(0, len(self.keys) - 1) + key = self.keys[index] + with self.db.begin(write=False) as txn: + value = txn.get(key.encode()) + assert value is not None + return key, value + + def __del__(self): + self.db.close() + + +if __name__ == '__main__': + import sys + db = LmdbData(sys.argv[1]) + key, _ = db.random_one() + print(key) + key, _ = db.random_one() + print(key) diff --git a/wekws/dataset/processor.py b/wekws/dataset/processor.py index 0fd1d84..988b3a2 100644 --- a/wekws/dataset/processor.py +++ b/wekws/dataset/processor.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io import logging import json import random +import numpy as np +from scipy import signal +from scipy.io import wavfile import torch import torchaudio import torchaudio.compliance.kaldi as kaldi @@ -304,3 +308,58 @@ def padding(data): batch_first=True, padding_value=0) yield (sorted_keys, padded_feats, sorted_labels, feats_lengths) + + +def add_reverb(data, reverb_source, aug_prob): + for sample in data: + assert 'wav' in sample + if aug_prob > random.random(): + audio = sample['wav'].numpy()[0] + audio_len = audio.shape[0] + _, rir_data = reverb_source.random_one() + rir_io = io.BytesIO(rir_data) + _, rir_audio = wavfile.read(rir_io) + rir_audio = rir_audio.astype(np.float32) + rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2)) + out_audio = signal.convolve(audio, rir_audio, + mode='full')[:audio_len] + out_audio = torch.from_numpy(out_audio) + out_audio = torch.unsqueeze(out_audio, 0) + sample['wav'] = out_audio + yield sample + + +def add_noise(data, noise_source, aug_prob): + for sample in data: + assert 'wav' in sample + assert 'key' in sample + if aug_prob > random.random(): + audio = sample['wav'].numpy()[0] + audio_len = audio.shape[0] + audio_db = 10 * np.log10(np.mean(audio**2) + 1e-4) + key, noise_data = noise_source.random_one() + if key.startswith('noise'): + snr_range = [0, 15] + elif key.startswith('speech'): + snr_range = [5, 30] + elif key.startswith('music'): + snr_range = [5, 15] + else: + snr_range = [0, 15] + _, noise_audio = wavfile.read(io.BytesIO(noise_data)) + noise_audio = noise_audio.astype(np.float32) + if noise_audio.shape[0] > audio_len: + start = random.randint(0, noise_audio.shape[0] - audio_len) + noise_audio = noise_audio[start:start + audio_len] + else: + # Resize will repeat copy + noise_audio = np.resize(noise_audio, (audio_len, )) + noise_snr = random.uniform(snr_range[0], snr_range[1]) + noise_db = 10 * np.log10(np.mean(noise_audio**2) + 1e-4) + noise_audio = np.sqrt(10**( + (audio_db - noise_db - noise_snr) / 10)) * noise_audio + out_audio = audio + noise_audio + out_audio = torch.from_numpy(out_audio) + out_audio = torch.unsqueeze(out_audio, 0) + sample['wav'] = out_audio + yield sample