[wekws] add online noise and rir augmentation (#115)

* [wekws] add online noise and rir augmentation

* format

* format

* update copyright

Co-authored-by: menglong.xu <menglong.xu>
This commit is contained in:
Menglong Xu 2022-11-28 21:12:26 +08:00 committed by GitHub
parent 5c6088f947
commit 6da85d4662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 205 additions and 2 deletions

View File

@ -5,6 +5,8 @@ dataset_conf:
resample_conf:
resample_rate: 16000
speed_perturb: false
reverb_prob: 0.2
noise_prob: 0.3
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40

View File

@ -20,6 +20,8 @@ num_average=30
score_checkpoint=$dir/avg_${num_average}.pt
download_dir=./data/local # your data dir
noise_lmdb=
reverb_lmdb=
. tools/parse_options.sh || exit 1;
@ -78,6 +80,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
--min_duration 50 \
--seed 777 \
$cmvn_opts \
${reverb_lmdb:+--reverb_lmdb $reverb_lmdb} \
${noise_lmdb:+--noise_lmdb $noise_lmdb} \
${checkpoint:+--checkpoint $checkpoint}
fi

View File

@ -11,3 +11,6 @@ flake8-pyi==20.5.0
mccabe
pycodestyle==2.6.0
pyflakes==2.2.0
lmdb
scipy
tqdm

59
tools/make_lmdb.py Normal file
View File

@ -0,0 +1,59 @@
# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import pickle
import lmdb
from tqdm import tqdm
def get_args():
    """Parse the command-line arguments of the LMDB builder.

    Returns:
        argparse.Namespace with two positional attributes:
        ``in_scp_file`` (input scp file) and ``out_lmdb`` (output lmdb).
    """
    arg_parser = argparse.ArgumentParser(description='')
    positionals = (
        ('in_scp_file', 'input scp file'),
        ('out_lmdb', 'output lmdb'),
    )
    for arg_name, help_text in positionals:
        arg_parser.add_argument(arg_name, help=help_text)
    return arg_parser.parse_args()
def main():
    """Pack every file listed in an scp file into an LMDB database.

    Each non-empty line of ``in_scp_file`` must have exactly two
    whitespace-separated fields: ``<key> <wav_path>``. The raw bytes of
    ``wav_path`` are stored under ``key`` (UTF-8 encoded), and the pickled
    list of all keys is stored under ``b'__keys__'``.

    Raises:
        ValueError: if a non-empty line does not have exactly two fields.
        OSError: if the scp file or one of the listed files cannot be read.
    """
    args = get_args()
    db = lmdb.open(args.out_lmdb, map_size=int(math.pow(1024, 4)))  # 1TB
    try:
        with open(args.in_scp_file, 'r', encoding='utf8') as fin:
            lines = fin.readlines()
        keys = []
        # txn is for Transaction
        txn = db.begin(write=True)
        for line_no, line in enumerate(tqdm(lines), start=1):
            arr = line.split()
            if not arr:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            if len(arr) != 2:
                raise ValueError(
                    '{}:{}: expected "<key> <wav_path>", got {!r}'.format(
                        args.in_scp_file, line_no, line.rstrip('\n')))
            key, wav = arr
            with open(wav, 'rb') as wav_file:
                data = wav_file.read()
            txn.put(key.encode(), data)
            keys.append(key)
            # Flush to disk after every 100 stored entries so a single
            # transaction never grows unbounded.
            if len(keys) % 100 == 0:
                txn.commit()
                txn = db.begin(write=True)
        # The key index goes into the final transaction together with the
        # last partial batch.
        txn.put(b'__keys__', pickle.dumps(keys))
        txn.commit()
        db.sync()
    finally:
        # Always release the environment, even when a file is missing or a
        # line is malformed.
        db.close()
# Build the LMDB only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

View File

@ -77,6 +77,12 @@ def get_args():
default=100,
type=int,
help='prefetch number')
parser.add_argument('--reverb_lmdb',
default=None,
help='reverb lmdb file')
parser.add_argument('--noise_lmdb',
default=None,
help='noise lmdb file')
args = parser.parse_args()
return args
@ -106,7 +112,10 @@ def main():
cv_conf['spec_aug'] = False
cv_conf['shuffle'] = False
train_dataset = Dataset(args.train_data, train_conf)
train_dataset = Dataset(args.train_data,
train_conf,
reverb_lmdb=args.reverb_lmdb,
noise_lmdb=args.noise_lmdb)
cv_dataset = Dataset(args.cv_data, cv_conf)
train_data_loader = DataLoader(train_dataset,

View File

@ -20,6 +20,7 @@ from torch.utils.data import IterableDataset
import wekws.dataset.processor as processor
from wekws.utils.file_utils import read_lists
from wekws.dataset.lmdb_data import LmdbData
class Processor(IterableDataset):
@ -112,7 +113,10 @@ class DataList(IterableDataset):
yield data
def Dataset(data_list_file, conf, partition=True):
def Dataset(data_list_file, conf,
partition=True,
reverb_lmdb=None,
noise_lmdb=None):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
@ -122,6 +126,8 @@ def Dataset(data_list_file, conf, partition=True):
Args:
data_type(str): raw/shard
partition(bool): whether to do data partition in terms of rank
reverb_lmdb: reverb data source lmdb file
noise_lmdb: noise data source lmdb file
"""
lists = read_lists(data_list_file)
shuffle = conf.get('shuffle', True)
@ -136,6 +142,14 @@ def Dataset(data_list_file, conf, partition=True):
speed_perturb = conf.get('speed_perturb', False)
if speed_perturb:
dataset = Processor(dataset, processor.speed_perturb)
if reverb_lmdb and conf.get('reverb_prob', 0) > 0:
reverb_data = LmdbData(reverb_lmdb)
dataset = Processor(dataset, processor.add_reverb,
reverb_data, conf['reverb_prob'])
if noise_lmdb and conf.get('noise_prob', 0) > 0:
noise_data = LmdbData(noise_lmdb)
dataset = Processor(dataset, processor.add_noise,
noise_data, conf['noise_prob'])
feature_extraction_conf = conf.get('feature_extraction_conf', {})
if feature_extraction_conf['feature_type'] == 'mfcc':
dataset = Processor(dataset, processor.compute_mfcc,

View File

@ -0,0 +1,53 @@
# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import pickle
import lmdb
class LmdbData:
    """Read-only random access to blobs stored in an LMDB database.

    The database must contain a ``b'__keys__'`` entry holding a pickled
    ``list`` of string keys; each key, UTF-8 encoded, maps to one stored
    value (the layout written by ``tools/make_lmdb.py``).
    """

    def __init__(self, lmdb_file):
        """Open ``lmdb_file`` read-only and load its key index.

        Args:
            lmdb_file: path of the LMDB environment to open.

        Raises:
            AssertionError: if the ``b'__keys__'`` entry is missing or does
                not unpickle to a list.
        """
        self.db = lmdb.open(lmdb_file,
                            readonly=True,
                            lock=False,
                            readahead=False)
        with self.db.begin(write=False) as txn:
            obj = txn.get(b'__keys__')
            assert obj is not None
            # NOTE(security): pickle.loads can execute arbitrary code; only
            # open LMDB files that come from a trusted source.
            self.keys = pickle.loads(obj)
            assert isinstance(self.keys, list)

    def random_one(self):
        """Return a uniformly chosen ``(key, value)`` pair.

        Returns:
            tuple of the key (as stored in the key index) and the bytes
            stored under it.

        Raises:
            AssertionError: if the key index is empty or the chosen key has
                no value in the database.
        """
        assert len(self.keys) > 0
        key = random.choice(self.keys)
        with self.db.begin(write=False) as txn:
            value = txn.get(key.encode())
        assert value is not None
        return key, value

    def __del__(self):
        # __init__ may have raised before self.db was assigned (e.g. the
        # path does not exist); there is nothing to close in that case.
        db = getattr(self, 'db', None)
        if db is not None:
            db.close()
if __name__ == '__main__':
    # Manual smoke test: open the LMDB given as the first command-line
    # argument and print the keys of two random draws.
    import sys
    db = LmdbData(sys.argv[1])
    for _draw in range(2):
        key, _ = db.random_one()
        print(key)

View File

@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import json
import random
import numpy as np
from scipy import signal
from scipy.io import wavfile
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
@ -304,3 +308,58 @@ def padding(data):
batch_first=True,
padding_value=0)
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)
def add_reverb(data, reverb_source, aug_prob):
    """ Randomly convolve waveforms with a room impulse response (RIR).

        Args:
            data: Iterable[dict]; every dict must contain 'wav', a tensor
                whose first row is the waveform to augment
            reverb_source: object whose ``random_one()`` returns
                ``(key, wav_file_bytes)`` of one RIR recording
            aug_prob: probability of applying reverb to a sample

        Returns:
            Iterable[dict]: the same dicts; for augmented samples 'wav' is
            replaced by a (1, num_samples) tensor of the original length
    """
    for sample in data:
        assert 'wav' in sample
        if aug_prob > random.random():
            audio = sample['wav'].numpy()[0]
            audio_len = audio.shape[0]
            _, rir_data = reverb_source.random_one()
            _, rir_audio = wavfile.read(io.BytesIO(rir_data))
            rir_audio = rir_audio.astype(np.float32)
            if rir_audio.ndim > 1:
                # wavfile.read returns (num_samples, num_channels) for
                # multi-channel files; convolve needs a 1-D response.
                rir_audio = rir_audio[:, 0]
            rir_norm = np.sqrt(np.sum(rir_audio**2))
            # A silent or empty RIR cannot be normalised (division by zero
            # would fill the waveform with NaN); leave the sample untouched.
            if rir_norm > 0:
                # Scale the RIR to unit energy before convolving.
                rir_audio = rir_audio / rir_norm
                # Drop the reverb tail so the length stays unchanged.
                out_audio = signal.convolve(audio, rir_audio,
                                            mode='full')[:audio_len]
                out_audio = out_audio.astype(audio.dtype, copy=False)
                out_audio = torch.from_numpy(out_audio)
                sample['wav'] = torch.unsqueeze(out_audio, 0)
        yield sample
def add_noise(data, noise_source, aug_prob):
    """ Randomly mix an additive noise recording into waveforms.

        The SNR (in dB) is drawn uniformly from a range selected by the
        prefix of the noise key: 'noise' -> [0, 15], 'speech' -> [5, 30],
        'music' -> [5, 15], anything else -> [0, 15].

        Args:
            data: Iterable[dict]; every dict must contain 'key' and 'wav',
                where 'wav' is a tensor whose first row is the waveform
            noise_source: object whose ``random_one()`` returns
                ``(key, wav_file_bytes)`` of one noise recording
            aug_prob: probability of adding noise to a sample

        Returns:
            Iterable[dict]: the same dicts; for augmented samples 'wav' is
            replaced by a (1, num_samples) tensor of the original length
    """
    snr_ranges = {
        'noise': (0, 15),
        'speech': (5, 30),
        'music': (5, 15),
    }
    default_snr_range = (0, 15)
    for sample in data:
        assert 'wav' in sample
        assert 'key' in sample
        if aug_prob > random.random():
            audio = sample['wav'].numpy()[0]
            audio_len = audio.shape[0]
            # The 1e-4 floor keeps log10 finite for all-zero signals.
            audio_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
            key, noise_data = noise_source.random_one()
            snr_low, snr_high = next(
                (snr_range for prefix, snr_range in snr_ranges.items()
                 if key.startswith(prefix)), default_snr_range)
            _, noise_audio = wavfile.read(io.BytesIO(noise_data))
            noise_audio = noise_audio.astype(np.float32)
            if noise_audio.ndim > 1:
                # wavfile.read returns (num_samples, num_channels) for
                # multi-channel files; mix only the first channel so the
                # noise has the same 1-D shape as the waveform.
                noise_audio = noise_audio[:, 0]
            if noise_audio.shape[0] > audio_len:
                # Take a random window of the noise, as long as the audio.
                start = random.randint(0, noise_audio.shape[0] - audio_len)
                noise_audio = noise_audio[start:start + audio_len]
            else:
                # Resize will repeat copy
                noise_audio = np.resize(noise_audio, (audio_len, ))
            noise_snr = random.uniform(snr_low, snr_high)
            noise_db = 10 * np.log10(np.mean(noise_audio**2) + 1e-4)
            # Amplitude gain that puts the noise noise_snr dB below the
            # audio power.
            noise_gain = np.sqrt(10**((audio_db - noise_db - noise_snr) / 10))
            out_audio = audio + noise_gain * noise_audio
            # Keep the waveform dtype stable whatever the scalar promotion
            # rules of the installed numpy version are.
            out_audio = out_audio.astype(audio.dtype, copy=False)
            out_audio = torch.from_numpy(out_audio)
            sample['wav'] = torch.unsqueeze(out_audio, 0)
        yield sample