[wekws] add online noise and rir argumentation (#115)
* [wekws] add online noise and rir argumentation * format * format * update copyright Co-authored-by: menglong.xu <menglong.xu>
This commit is contained in:
parent
5c6088f947
commit
6da85d4662
@ -5,6 +5,8 @@ dataset_conf:
|
||||
resample_conf:
|
||||
resample_rate: 16000
|
||||
speed_perturb: false
|
||||
reverb_prob: 0.2
|
||||
noise_prob: 0.3
|
||||
feature_extraction_conf:
|
||||
feature_type: 'fbank'
|
||||
num_mel_bins: 40
|
||||
|
||||
@ -20,6 +20,8 @@ num_average=30
|
||||
score_checkpoint=$dir/avg_${num_average}.pt
|
||||
|
||||
download_dir=./data/local # your data dir
|
||||
noise_lmdb=
|
||||
reverb_lmdb=
|
||||
|
||||
. tools/parse_options.sh || exit 1;
|
||||
|
||||
@ -78,6 +80,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||
--min_duration 50 \
|
||||
--seed 777 \
|
||||
$cmvn_opts \
|
||||
${reverb_lmdb:+--reverb_lmdb $reverb_lmdb} \
|
||||
${noise_lmdb:+--noise_lmdb $noise_lmdb} \
|
||||
${checkpoint:+--checkpoint $checkpoint}
|
||||
fi
|
||||
|
||||
|
||||
@ -11,3 +11,6 @@ flake8-pyi==20.5.0
|
||||
mccabe
|
||||
pycodestyle==2.6.0
|
||||
pyflakes==2.2.0
|
||||
lmdb
|
||||
scipy
|
||||
tqdm
|
||||
59
tools/make_lmdb.py
Normal file
59
tools/make_lmdb.py
Normal file
@ -0,0 +1,59 @@
|
||||
# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import pickle
|
||||
|
||||
import lmdb
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('in_scp_file', help='input scp file')
|
||||
parser.add_argument('out_lmdb', help='output lmdb')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
db = lmdb.open(args.out_lmdb, map_size=int(math.pow(1024, 4))) # 1TB
|
||||
# txn is for Transaciton
|
||||
txn = db.begin(write=True)
|
||||
keys = []
|
||||
with open(args.in_scp_file, 'r', encoding='utf8') as fin:
|
||||
lines = fin.readlines()
|
||||
for i, line in enumerate(tqdm(lines)):
|
||||
arr = line.strip().split()
|
||||
assert len(arr) == 2
|
||||
key, wav = arr[0], arr[1]
|
||||
keys.append(key)
|
||||
with open(wav, 'rb') as fin:
|
||||
data = fin.read()
|
||||
txn.put(key.encode(), data)
|
||||
# Write flush to disk
|
||||
if i % 100 == 0:
|
||||
txn.commit()
|
||||
txn = db.begin(write=True)
|
||||
txn.commit()
|
||||
with db.begin(write=True) as txn:
|
||||
txn.put(b'__keys__', pickle.dumps(keys))
|
||||
db.sync()
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -77,6 +77,12 @@ def get_args():
|
||||
default=100,
|
||||
type=int,
|
||||
help='prefetch number')
|
||||
parser.add_argument('--reverb_lmdb',
|
||||
default=None,
|
||||
help='reverb lmdb file')
|
||||
parser.add_argument('--noise_lmdb',
|
||||
default=None,
|
||||
help='noise lmdb file')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@ -106,7 +112,10 @@ def main():
|
||||
cv_conf['spec_aug'] = False
|
||||
cv_conf['shuffle'] = False
|
||||
|
||||
train_dataset = Dataset(args.train_data, train_conf)
|
||||
train_dataset = Dataset(args.train_data,
|
||||
train_conf,
|
||||
reverb_lmdb=args.reverb_lmdb,
|
||||
noise_lmdb=args.noise_lmdb)
|
||||
cv_dataset = Dataset(args.cv_data, cv_conf)
|
||||
|
||||
train_data_loader = DataLoader(train_dataset,
|
||||
|
||||
@ -20,6 +20,7 @@ from torch.utils.data import IterableDataset
|
||||
|
||||
import wekws.dataset.processor as processor
|
||||
from wekws.utils.file_utils import read_lists
|
||||
from wekws.dataset.lmdb_data import LmdbData
|
||||
|
||||
|
||||
class Processor(IterableDataset):
|
||||
@ -112,7 +113,10 @@ class DataList(IterableDataset):
|
||||
yield data
|
||||
|
||||
|
||||
def Dataset(data_list_file, conf, partition=True):
|
||||
def Dataset(data_list_file, conf,
|
||||
partition=True,
|
||||
reverb_lmdb=None,
|
||||
noise_lmdb=None):
|
||||
""" Construct dataset from arguments
|
||||
|
||||
We have two shuffle stage in the Dataset. The first is global
|
||||
@ -122,6 +126,8 @@ def Dataset(data_list_file, conf, partition=True):
|
||||
Args:
|
||||
data_type(str): raw/shard
|
||||
partition(bool): whether to do data partition in terms of rank
|
||||
reverb_lmdb: reverb data source lmdb file
|
||||
noise_lmdb: noise data source lmdb file
|
||||
"""
|
||||
lists = read_lists(data_list_file)
|
||||
shuffle = conf.get('shuffle', True)
|
||||
@ -136,6 +142,14 @@ def Dataset(data_list_file, conf, partition=True):
|
||||
speed_perturb = conf.get('speed_perturb', False)
|
||||
if speed_perturb:
|
||||
dataset = Processor(dataset, processor.speed_perturb)
|
||||
if reverb_lmdb and conf.get('reverb_prob', 0) > 0:
|
||||
reverb_data = LmdbData(reverb_lmdb)
|
||||
dataset = Processor(dataset, processor.add_reverb,
|
||||
reverb_data, conf['reverb_prob'])
|
||||
if noise_lmdb and conf.get('noise_prob', 0) > 0:
|
||||
noise_data = LmdbData(noise_lmdb)
|
||||
dataset = Processor(dataset, processor.add_noise,
|
||||
noise_data, conf['noise_prob'])
|
||||
feature_extraction_conf = conf.get('feature_extraction_conf', {})
|
||||
if feature_extraction_conf['feature_type'] == 'mfcc':
|
||||
dataset = Processor(dataset, processor.compute_mfcc,
|
||||
|
||||
53
wekws/dataset/lmdb_data.py
Normal file
53
wekws/dataset/lmdb_data.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import pickle
|
||||
|
||||
import lmdb
|
||||
|
||||
|
||||
class LmdbData:
|
||||
|
||||
def __init__(self, lmdb_file):
|
||||
self.db = lmdb.open(lmdb_file,
|
||||
readonly=True,
|
||||
lock=False,
|
||||
readahead=False)
|
||||
with self.db.begin(write=False) as txn:
|
||||
obj = txn.get(b'__keys__')
|
||||
assert obj is not None
|
||||
self.keys = pickle.loads(obj)
|
||||
assert isinstance(self.keys, list)
|
||||
|
||||
def random_one(self):
|
||||
assert len(self.keys) > 0
|
||||
index = random.randint(0, len(self.keys) - 1)
|
||||
key = self.keys[index]
|
||||
with self.db.begin(write=False) as txn:
|
||||
value = txn.get(key.encode())
|
||||
assert value is not None
|
||||
return key, value
|
||||
|
||||
def __del__(self):
|
||||
self.db.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
db = LmdbData(sys.argv[1])
|
||||
key, _ = db.random_one()
|
||||
print(key)
|
||||
key, _ = db.random_one()
|
||||
print(key)
|
||||
@ -12,10 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import logging
|
||||
import json
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.io import wavfile
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
@ -304,3 +308,58 @@ def padding(data):
|
||||
batch_first=True,
|
||||
padding_value=0)
|
||||
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)
|
||||
|
||||
|
||||
def add_reverb(data, reverb_source, aug_prob):
|
||||
for sample in data:
|
||||
assert 'wav' in sample
|
||||
if aug_prob > random.random():
|
||||
audio = sample['wav'].numpy()[0]
|
||||
audio_len = audio.shape[0]
|
||||
_, rir_data = reverb_source.random_one()
|
||||
rir_io = io.BytesIO(rir_data)
|
||||
_, rir_audio = wavfile.read(rir_io)
|
||||
rir_audio = rir_audio.astype(np.float32)
|
||||
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
|
||||
out_audio = signal.convolve(audio, rir_audio,
|
||||
mode='full')[:audio_len]
|
||||
out_audio = torch.from_numpy(out_audio)
|
||||
out_audio = torch.unsqueeze(out_audio, 0)
|
||||
sample['wav'] = out_audio
|
||||
yield sample
|
||||
|
||||
|
||||
def add_noise(data, noise_source, aug_prob):
|
||||
for sample in data:
|
||||
assert 'wav' in sample
|
||||
assert 'key' in sample
|
||||
if aug_prob > random.random():
|
||||
audio = sample['wav'].numpy()[0]
|
||||
audio_len = audio.shape[0]
|
||||
audio_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
|
||||
key, noise_data = noise_source.random_one()
|
||||
if key.startswith('noise'):
|
||||
snr_range = [0, 15]
|
||||
elif key.startswith('speech'):
|
||||
snr_range = [5, 30]
|
||||
elif key.startswith('music'):
|
||||
snr_range = [5, 15]
|
||||
else:
|
||||
snr_range = [0, 15]
|
||||
_, noise_audio = wavfile.read(io.BytesIO(noise_data))
|
||||
noise_audio = noise_audio.astype(np.float32)
|
||||
if noise_audio.shape[0] > audio_len:
|
||||
start = random.randint(0, noise_audio.shape[0] - audio_len)
|
||||
noise_audio = noise_audio[start:start + audio_len]
|
||||
else:
|
||||
# Resize will repeat copy
|
||||
noise_audio = np.resize(noise_audio, (audio_len, ))
|
||||
noise_snr = random.uniform(snr_range[0], snr_range[1])
|
||||
noise_db = 10 * np.log10(np.mean(noise_audio**2) + 1e-4)
|
||||
noise_audio = np.sqrt(10**(
|
||||
(audio_db - noise_db - noise_snr) / 10)) * noise_audio
|
||||
out_audio = audio + noise_audio
|
||||
out_audio = torch.from_numpy(out_audio)
|
||||
out_audio = torch.unsqueeze(out_audio, 0)
|
||||
sample['wav'] = out_audio
|
||||
yield sample
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user