[wekws] add online noise and rir argumentation (#115)
* [wekws] add online noise and rir argumentation * format * format * update copyright Co-authored-by: menglong.xu <menglong.xu>
This commit is contained in:
parent
5c6088f947
commit
6da85d4662
@ -5,6 +5,8 @@ dataset_conf:
|
|||||||
resample_conf:
|
resample_conf:
|
||||||
resample_rate: 16000
|
resample_rate: 16000
|
||||||
speed_perturb: false
|
speed_perturb: false
|
||||||
|
reverb_prob: 0.2
|
||||||
|
noise_prob: 0.3
|
||||||
feature_extraction_conf:
|
feature_extraction_conf:
|
||||||
feature_type: 'fbank'
|
feature_type: 'fbank'
|
||||||
num_mel_bins: 40
|
num_mel_bins: 40
|
||||||
|
|||||||
@ -20,6 +20,8 @@ num_average=30
|
|||||||
score_checkpoint=$dir/avg_${num_average}.pt
|
score_checkpoint=$dir/avg_${num_average}.pt
|
||||||
|
|
||||||
download_dir=./data/local # your data dir
|
download_dir=./data/local # your data dir
|
||||||
|
noise_lmdb=
|
||||||
|
reverb_lmdb=
|
||||||
|
|
||||||
. tools/parse_options.sh || exit 1;
|
. tools/parse_options.sh || exit 1;
|
||||||
|
|
||||||
@ -78,6 +80,8 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
|||||||
--min_duration 50 \
|
--min_duration 50 \
|
||||||
--seed 777 \
|
--seed 777 \
|
||||||
$cmvn_opts \
|
$cmvn_opts \
|
||||||
|
${reverb_lmdb:+--reverb_lmdb $reverb_lmdb} \
|
||||||
|
${noise_lmdb:+--noise_lmdb $noise_lmdb} \
|
||||||
${checkpoint:+--checkpoint $checkpoint}
|
${checkpoint:+--checkpoint $checkpoint}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@ -11,3 +11,6 @@ flake8-pyi==20.5.0
|
|||||||
mccabe
|
mccabe
|
||||||
pycodestyle==2.6.0
|
pycodestyle==2.6.0
|
||||||
pyflakes==2.2.0
|
pyflakes==2.2.0
|
||||||
|
lmdb
|
||||||
|
scipy
|
||||||
|
tqdm
|
||||||
59
tools/make_lmdb.py
Normal file
59
tools/make_lmdb.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import lmdb
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description='')
|
||||||
|
parser.add_argument('in_scp_file', help='input scp file')
|
||||||
|
parser.add_argument('out_lmdb', help='output lmdb')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
db = lmdb.open(args.out_lmdb, map_size=int(math.pow(1024, 4))) # 1TB
|
||||||
|
# txn is for Transaciton
|
||||||
|
txn = db.begin(write=True)
|
||||||
|
keys = []
|
||||||
|
with open(args.in_scp_file, 'r', encoding='utf8') as fin:
|
||||||
|
lines = fin.readlines()
|
||||||
|
for i, line in enumerate(tqdm(lines)):
|
||||||
|
arr = line.strip().split()
|
||||||
|
assert len(arr) == 2
|
||||||
|
key, wav = arr[0], arr[1]
|
||||||
|
keys.append(key)
|
||||||
|
with open(wav, 'rb') as fin:
|
||||||
|
data = fin.read()
|
||||||
|
txn.put(key.encode(), data)
|
||||||
|
# Write flush to disk
|
||||||
|
if i % 100 == 0:
|
||||||
|
txn.commit()
|
||||||
|
txn = db.begin(write=True)
|
||||||
|
txn.commit()
|
||||||
|
with db.begin(write=True) as txn:
|
||||||
|
txn.put(b'__keys__', pickle.dumps(keys))
|
||||||
|
db.sync()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@ -77,6 +77,12 @@ def get_args():
|
|||||||
default=100,
|
default=100,
|
||||||
type=int,
|
type=int,
|
||||||
help='prefetch number')
|
help='prefetch number')
|
||||||
|
parser.add_argument('--reverb_lmdb',
|
||||||
|
default=None,
|
||||||
|
help='reverb lmdb file')
|
||||||
|
parser.add_argument('--noise_lmdb',
|
||||||
|
default=None,
|
||||||
|
help='noise lmdb file')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
@ -106,7 +112,10 @@ def main():
|
|||||||
cv_conf['spec_aug'] = False
|
cv_conf['spec_aug'] = False
|
||||||
cv_conf['shuffle'] = False
|
cv_conf['shuffle'] = False
|
||||||
|
|
||||||
train_dataset = Dataset(args.train_data, train_conf)
|
train_dataset = Dataset(args.train_data,
|
||||||
|
train_conf,
|
||||||
|
reverb_lmdb=args.reverb_lmdb,
|
||||||
|
noise_lmdb=args.noise_lmdb)
|
||||||
cv_dataset = Dataset(args.cv_data, cv_conf)
|
cv_dataset = Dataset(args.cv_data, cv_conf)
|
||||||
|
|
||||||
train_data_loader = DataLoader(train_dataset,
|
train_data_loader = DataLoader(train_dataset,
|
||||||
|
|||||||
@ -20,6 +20,7 @@ from torch.utils.data import IterableDataset
|
|||||||
|
|
||||||
import wekws.dataset.processor as processor
|
import wekws.dataset.processor as processor
|
||||||
from wekws.utils.file_utils import read_lists
|
from wekws.utils.file_utils import read_lists
|
||||||
|
from wekws.dataset.lmdb_data import LmdbData
|
||||||
|
|
||||||
|
|
||||||
class Processor(IterableDataset):
|
class Processor(IterableDataset):
|
||||||
@ -112,7 +113,10 @@ class DataList(IterableDataset):
|
|||||||
yield data
|
yield data
|
||||||
|
|
||||||
|
|
||||||
def Dataset(data_list_file, conf, partition=True):
|
def Dataset(data_list_file, conf,
|
||||||
|
partition=True,
|
||||||
|
reverb_lmdb=None,
|
||||||
|
noise_lmdb=None):
|
||||||
""" Construct dataset from arguments
|
""" Construct dataset from arguments
|
||||||
|
|
||||||
We have two shuffle stage in the Dataset. The first is global
|
We have two shuffle stage in the Dataset. The first is global
|
||||||
@ -122,6 +126,8 @@ def Dataset(data_list_file, conf, partition=True):
|
|||||||
Args:
|
Args:
|
||||||
data_type(str): raw/shard
|
data_type(str): raw/shard
|
||||||
partition(bool): whether to do data partition in terms of rank
|
partition(bool): whether to do data partition in terms of rank
|
||||||
|
reverb_lmdb: reverb data source lmdb file
|
||||||
|
noise_lmdb: noise data source lmdb file
|
||||||
"""
|
"""
|
||||||
lists = read_lists(data_list_file)
|
lists = read_lists(data_list_file)
|
||||||
shuffle = conf.get('shuffle', True)
|
shuffle = conf.get('shuffle', True)
|
||||||
@ -136,6 +142,14 @@ def Dataset(data_list_file, conf, partition=True):
|
|||||||
speed_perturb = conf.get('speed_perturb', False)
|
speed_perturb = conf.get('speed_perturb', False)
|
||||||
if speed_perturb:
|
if speed_perturb:
|
||||||
dataset = Processor(dataset, processor.speed_perturb)
|
dataset = Processor(dataset, processor.speed_perturb)
|
||||||
|
if reverb_lmdb and conf.get('reverb_prob', 0) > 0:
|
||||||
|
reverb_data = LmdbData(reverb_lmdb)
|
||||||
|
dataset = Processor(dataset, processor.add_reverb,
|
||||||
|
reverb_data, conf['reverb_prob'])
|
||||||
|
if noise_lmdb and conf.get('noise_prob', 0) > 0:
|
||||||
|
noise_data = LmdbData(noise_lmdb)
|
||||||
|
dataset = Processor(dataset, processor.add_noise,
|
||||||
|
noise_data, conf['noise_prob'])
|
||||||
feature_extraction_conf = conf.get('feature_extraction_conf', {})
|
feature_extraction_conf = conf.get('feature_extraction_conf', {})
|
||||||
if feature_extraction_conf['feature_type'] == 'mfcc':
|
if feature_extraction_conf['feature_type'] == 'mfcc':
|
||||||
dataset = Processor(dataset, processor.compute_mfcc,
|
dataset = Processor(dataset, processor.compute_mfcc,
|
||||||
|
|||||||
53
wekws/dataset/lmdb_data.py
Normal file
53
wekws/dataset/lmdb_data.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import lmdb
|
||||||
|
|
||||||
|
|
||||||
|
class LmdbData:
|
||||||
|
|
||||||
|
def __init__(self, lmdb_file):
|
||||||
|
self.db = lmdb.open(lmdb_file,
|
||||||
|
readonly=True,
|
||||||
|
lock=False,
|
||||||
|
readahead=False)
|
||||||
|
with self.db.begin(write=False) as txn:
|
||||||
|
obj = txn.get(b'__keys__')
|
||||||
|
assert obj is not None
|
||||||
|
self.keys = pickle.loads(obj)
|
||||||
|
assert isinstance(self.keys, list)
|
||||||
|
|
||||||
|
def random_one(self):
|
||||||
|
assert len(self.keys) > 0
|
||||||
|
index = random.randint(0, len(self.keys) - 1)
|
||||||
|
key = self.keys[index]
|
||||||
|
with self.db.begin(write=False) as txn:
|
||||||
|
value = txn.get(key.encode())
|
||||||
|
assert value is not None
|
||||||
|
return key, value
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.db.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import sys
|
||||||
|
db = LmdbData(sys.argv[1])
|
||||||
|
key, _ = db.random_one()
|
||||||
|
print(key)
|
||||||
|
key, _ = db.random_one()
|
||||||
|
print(key)
|
||||||
@ -12,10 +12,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy import signal
|
||||||
|
from scipy.io import wavfile
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import torchaudio.compliance.kaldi as kaldi
|
import torchaudio.compliance.kaldi as kaldi
|
||||||
@ -304,3 +308,58 @@ def padding(data):
|
|||||||
batch_first=True,
|
batch_first=True,
|
||||||
padding_value=0)
|
padding_value=0)
|
||||||
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)
|
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)
|
||||||
|
|
||||||
|
|
||||||
|
def add_reverb(data, reverb_source, aug_prob):
|
||||||
|
for sample in data:
|
||||||
|
assert 'wav' in sample
|
||||||
|
if aug_prob > random.random():
|
||||||
|
audio = sample['wav'].numpy()[0]
|
||||||
|
audio_len = audio.shape[0]
|
||||||
|
_, rir_data = reverb_source.random_one()
|
||||||
|
rir_io = io.BytesIO(rir_data)
|
||||||
|
_, rir_audio = wavfile.read(rir_io)
|
||||||
|
rir_audio = rir_audio.astype(np.float32)
|
||||||
|
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
|
||||||
|
out_audio = signal.convolve(audio, rir_audio,
|
||||||
|
mode='full')[:audio_len]
|
||||||
|
out_audio = torch.from_numpy(out_audio)
|
||||||
|
out_audio = torch.unsqueeze(out_audio, 0)
|
||||||
|
sample['wav'] = out_audio
|
||||||
|
yield sample
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise(data, noise_source, aug_prob):
|
||||||
|
for sample in data:
|
||||||
|
assert 'wav' in sample
|
||||||
|
assert 'key' in sample
|
||||||
|
if aug_prob > random.random():
|
||||||
|
audio = sample['wav'].numpy()[0]
|
||||||
|
audio_len = audio.shape[0]
|
||||||
|
audio_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
|
||||||
|
key, noise_data = noise_source.random_one()
|
||||||
|
if key.startswith('noise'):
|
||||||
|
snr_range = [0, 15]
|
||||||
|
elif key.startswith('speech'):
|
||||||
|
snr_range = [5, 30]
|
||||||
|
elif key.startswith('music'):
|
||||||
|
snr_range = [5, 15]
|
||||||
|
else:
|
||||||
|
snr_range = [0, 15]
|
||||||
|
_, noise_audio = wavfile.read(io.BytesIO(noise_data))
|
||||||
|
noise_audio = noise_audio.astype(np.float32)
|
||||||
|
if noise_audio.shape[0] > audio_len:
|
||||||
|
start = random.randint(0, noise_audio.shape[0] - audio_len)
|
||||||
|
noise_audio = noise_audio[start:start + audio_len]
|
||||||
|
else:
|
||||||
|
# Resize will repeat copy
|
||||||
|
noise_audio = np.resize(noise_audio, (audio_len, ))
|
||||||
|
noise_snr = random.uniform(snr_range[0], snr_range[1])
|
||||||
|
noise_db = 10 * np.log10(np.mean(noise_audio**2) + 1e-4)
|
||||||
|
noise_audio = np.sqrt(10**(
|
||||||
|
(audio_db - noise_db - noise_snr) / 10)) * noise_audio
|
||||||
|
out_audio = audio + noise_audio
|
||||||
|
out_audio = torch.from_numpy(out_audio)
|
||||||
|
out_audio = torch.unsqueeze(out_audio, 0)
|
||||||
|
sample['wav'] = out_audio
|
||||||
|
yield sample
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user