427 lines
14 KiB
Python
427 lines
14 KiB
Python
# Copyright (c) 2021 Binbin Zhang
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import io
|
|
import logging
|
|
import json
|
|
import random
|
|
|
|
import numpy as np
|
|
from scipy import signal
|
|
from scipy.io import wavfile
|
|
import torch
|
|
import torchaudio
|
|
import torchaudio.compliance.kaldi as kaldi
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
|
|
def parse_raw(data):
    """ Parse key/wav/txt from json line

    Args:
        data: Iterable[{src}], where `src` is a json line that holds the
            `key`, `wav` and `txt` fields

    Returns:
        Iterable[{key, wav, label, sample_rate}]; entries whose wav cannot
        be loaded are logged and skipped.
    """
    for sample in data:
        assert 'src' in sample
        json_line = sample['src']
        obj = json.loads(json_line)
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        # Keep the try body limited to the load itself, so that errors
        # raised by the consumer at the yield point are not swallowed.
        try:
            waveform, sample_rate = torchaudio.load(wav_file)
        except Exception as ex:
            # Best effort: one unreadable file must not stop the pipeline,
            # but the reason is logged so it can be diagnosed.
            logging.warning('Failed to read {}: {}'.format(wav_file, ex))
            continue
        yield dict(key=key,
                   label=txt,
                   wav=waveform,
                   sample_rate=sample_rate)
|
|
|
|
|
|
def filter(data, max_length=10240, min_length=10):
    """ Drop utterances whose duration is outside the allowed range

    The duration is measured in frames, assuming 100 frames per second
    (one frame every 10ms). Kept samples are passed through unchanged.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        max_length: drop utterance which has more than max_length frames
        min_length: drop utterance which has fewer than min_length frames

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    for item in data:
        assert 'sample_rate' in item
        assert 'wav' in item
        # item['wav'] is a torch.Tensor; size(1) is its sample count
        frame_count = item['wav'].size(1) / item['sample_rate'] * 100
        if min_length <= frame_count <= max_length:
            yield item
|
|
|
|
|
|
def resample(data, resample_rate=16000):
    """ Resample data.
        Inplace operation.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        resample_rate: target resample rate

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    # One transform per source rate: building a Resample object for every
    # utterance repeats the same set-up work for no benefit.
    resamplers = {}
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        if sample_rate != resample_rate:
            resampler = resamplers.get(sample_rate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=resample_rate)
                resamplers[sample_rate] = resampler
            # Update the rate only after the waveform was converted, so a
            # failure cannot leave the sample in an inconsistent state.
            sample['wav'] = resampler(sample['wav'])
            sample['sample_rate'] = resample_rate
        yield sample
|
|
|
|
|
|
def speed_perturb(data, speeds=None):
    """ Randomly change the speed of each utterance.
        Inplace operation.

    One factor is drawn per sample from `speeds`; a factor of 1.0 leaves
    the sample untouched.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        speeds(List[float]): candidate speed factors, defaults to
            [0.9, 1.0, 1.1]

    Returns:
        Iterable[{key, wav, label, sample_rate}]
    """
    candidates = [0.9, 1.0, 1.1] if speeds is None else speeds
    for item in data:
        assert 'sample_rate' in item
        assert 'wav' in item
        rate = item['sample_rate']
        source_wav = item['wav']
        factor = random.choice(candidates)
        if factor == 1.0:
            yield item
            continue
        # 'speed' changes the tempo, 'rate' brings the result back to the
        # original sample rate.
        effects = [['speed', str(factor)], ['rate', str(rate)]]
        perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
            source_wav, rate, effects)
        item['wav'] = perturbed
        yield item
|
|
|
|
|
|
def compute_mfcc(
    data,
    feature_type='mfcc',
    num_ceps=80,
    num_mel_bins=80,
    frame_length=25,
    frame_shift=10,
    dither=0.0,
):
    """ Extract mfcc features with torchaudio's kaldi-compatible front end

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        feature_type: accepted but not used by this function
        num_ceps: forwarded to kaldi.mfcc
        num_mel_bins: forwarded to kaldi.mfcc
        frame_length: forwarded to kaldi.mfcc
        frame_shift: forwarded to kaldi.mfcc
        dither: forwarded to kaldi.mfcc

    Returns:
        Iterable[{key, feat, label}]
    """
    options = dict(num_ceps=num_ceps,
                   num_mel_bins=num_mel_bins,
                   frame_length=frame_length,
                   frame_shift=frame_shift,
                   dither=dither,
                   energy_floor=0.0)
    for item in data:
        assert 'sample_rate' in item
        assert 'wav' in item
        assert 'key' in item
        assert 'label' in item
        # Scale by 2**15; presumably this maps normalized float audio to
        # the 16-bit range kaldi-style features expect -- TODO confirm.
        scaled = item['wav'] * (1 << 15)
        feats = kaldi.mfcc(scaled,
                           sample_frequency=item['sample_rate'],
                           **options)
        # Only key, feat and label are carried forward
        yield dict(key=item['key'], label=item['label'], feat=feats)
|
|
|
|
|
|
def compute_fbank(data,
                  feature_type='fbank',
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    """ Extract fbank features with torchaudio's kaldi-compatible front end

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        feature_type: accepted but not used by this function
        num_mel_bins: forwarded to kaldi.fbank
        frame_length: forwarded to kaldi.fbank
        frame_shift: forwarded to kaldi.fbank
        dither: forwarded to kaldi.fbank

    Returns:
        Iterable[{key, feat, label}]
    """
    options = dict(num_mel_bins=num_mel_bins,
                   frame_length=frame_length,
                   frame_shift=frame_shift,
                   dither=dither,
                   energy_floor=0.0)
    for item in data:
        assert 'sample_rate' in item
        assert 'wav' in item
        assert 'key' in item
        assert 'label' in item
        # Scale by 2**15; presumably this maps normalized float audio to
        # the 16-bit range kaldi-style features expect -- TODO confirm.
        scaled = item['wav'] * (1 << 15)
        feats = kaldi.fbank(scaled,
                            sample_frequency=item['sample_rate'],
                            **options)
        # Only key, feat and label are carried forward
        yield dict(key=item['key'], label=item['label'], feat=feats)
|
|
|
|
|
|
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10):
    """ Do spec augmentation
        Inplace operation on the sample dict; the feature tensor itself is
        cloned before masking, so the input tensor is left untouched.

    Args:
        data: Iterable[{key, feat, label}]
        num_t_mask: number of time mask to apply
        num_f_mask: number of freq mask to apply
        max_t: max width of time mask
        max_f: max width of freq mask

    Returns:
        Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        if y.numel() == 0:
            # Nothing to mask. Without this guard an empty feature tensor
            # makes random.randint(0, -1) (or size(1)) raise.
            sample['feat'] = y
            yield sample
            continue
        max_frames = y.size(0)
        max_freq = y.size(1)
        # time mask
        for _ in range(num_t_mask):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            y[start:end, :] = 0
        # freq mask
        for _ in range(num_f_mask):
            start = random.randint(0, max_freq - 1)
            length = random.randint(1, max_f)
            end = min(max_freq, start + length)
            y[:, start:end] = 0
        sample['feat'] = y
        yield sample
|
|
|
|
|
|
def shuffle(data, shuffle_size=1000):
    """ Shuffle the stream locally, one buffer at a time

    Samples are collected until `shuffle_size` of them are buffered, then
    the buffer is shuffled and emitted; the final partial buffer is
    shuffled and emitted as well.

    Args:
        data: Iterable[{key, feat, label}]
        shuffle_size: buffer size for shuffle

    Returns:
        Iterable[{key, feat, label}]
    """
    pending = []
    for item in data:
        pending.append(item)
        if len(pending) < shuffle_size:
            continue
        random.shuffle(pending)
        yield from pending
        pending = []
    # Whatever is left once the input is exhausted
    random.shuffle(pending)
    yield from pending
|
|
|
|
def context_expansion(data, left=1, right=1):
    """ Splice each frame with its left and right neighbours

    Every output row is the concatenation of the frames at offsets
    -left .. +right. The first `left` rows reuse frame 0 where a left
    neighbour does not exist, and the last `right` rows are dropped.

    Args:
        data: Iterable[{key, feat, label}]
        left (int): feature left context frames
        right (int): feature right context frames

    Returns:
        data: Iterable[{key, feat, label}]
    """
    for item in data:
        frames = item['feat']
        num_frames = frames.shape[0]
        feat_dim = frames.shape[1]
        window = left + right + 1
        expanded = torch.zeros(num_frames,
                               feat_dim * window,
                               dtype=torch.float32)
        # Column slot k holds the frames shifted by lag = k - left; the roll
        # wraps around, which is repaired (left) or cut off (right) below.
        for slot, lag in enumerate(range(-left, right + 1)):
            begin = slot * feat_dim
            expanded[:, begin:begin + feat_dim] = torch.roll(frames, -lag, 0)

        # Replication pad of the left margin: row `left`, slot 0 holds
        # frame 0, which replaces the wrapped-around entries.
        for row in range(left):
            for slot in range(left - row):
                begin = slot * feat_dim
                expanded[row, begin:begin + feat_dim] = \
                    expanded[left, :feat_dim]

        item['feat'] = expanded[:num_frames - right]
        yield item
|
|
|
|
|
|
def frame_skip(data, skip_rate=1):
    """ Subsample the frames of every sample
        Inplace operation.

    Args:
        data: Iterable[{key, feat, label}]
        skip_rate (int): keep one frame out of every `skip_rate` frames,
            starting with the first one

    Returns:
        data: Iterable[{key, feat, label}]
    """
    for item in data:
        item['feat'] = item['feat'][::skip_rate, :]
        yield item
|
|
|
|
def batch(data, batch_size=16):
    """ Group the stream into static batches of `batch_size` samples

    The last batch may hold fewer samples; an empty batch is never emitted.

    Args:
        data: Iterable[{key, feat, label}]
        batch_size: batch size

    Returns:
        Iterable[List[{key, feat, label}]]
    """
    group = []
    for item in data:
        group.append(item)
        if len(group) < batch_size:
            continue
        yield group
        group = []
    if group:
        yield group
|
|
|
|
|
|
def padding(data):
    """ Padding the data into training data

    Each batch is sorted by decreasing number of feature frames. Features
    are zero padded; sequence labels are padded with -1. Integer labels
    are stacked as they are and get a label length of 1.

    Args:
        data: Iterable[List[{key, feat, label}]]

    Returns:
        Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        lengths = torch.tensor([x['feat'].size(0) for x in sample],
                               dtype=torch.int32)
        order = torch.argsort(lengths, descending=True)
        # Reuse the lengths computed above instead of querying every
        # feature again, and index the list with plain ints.
        feats_lengths = lengths[order]
        ordered = [sample[i] for i in order.tolist()]

        sorted_keys = [x['key'] for x in ordered]
        padded_feats = pad_sequence([x['feat'] for x in ordered],
                                    batch_first=True,
                                    padding_value=0)

        if isinstance(sample[0]['label'], int):
            # One class id per utterance
            padded_labels = torch.tensor([x['label'] for x in ordered],
                                         dtype=torch.int32)
            label_lengths = torch.tensor([1 for _ in ordered],
                                         dtype=torch.int32)
        else:
            # One label sequence per utterance
            sorted_labels = [
                torch.tensor(x['label'], dtype=torch.int32) for x in ordered
            ]
            label_lengths = torch.tensor([len(x['label']) for x in ordered],
                                         dtype=torch.int32)
            padded_labels = pad_sequence(sorted_labels,
                                         batch_first=True,
                                         padding_value=-1)
        yield (sorted_keys, padded_feats, padded_labels, feats_lengths,
               label_lengths)
|
|
|
|
|
|
def add_reverb(data, reverb_source, aug_prob):
    """ Convolve utterances with a random room impulse response
        Inplace operation.

    Args:
        data: Iterable[{key, wav, ...}]
        reverb_source: object whose random_one() returns a pair; the
            second element holds the bytes of a wav file with the
            impulse response
        aug_prob: a sample is augmented when aug_prob is greater than a
            fresh random.random() draw

    Returns:
        Iterable of the same samples; augmented ones carry the
        reverberated first channel as a (1, num_samples) tensor in `wav`
    """
    for item in data:
        assert 'wav' in item
        if aug_prob > random.random():
            # Only the first channel of the utterance is used
            speech = item['wav'].numpy()[0]
            num_samples = speech.shape[0]
            _, rir_bytes = reverb_source.random_one()
            _, rir = wavfile.read(io.BytesIO(rir_bytes))
            rir = rir.astype(np.float32)
            if rir.ndim > 1:
                # Multi-channel impulse response: keep the first channel
                rir = rir[:, 0]
            # Normalize the impulse response to unit energy
            rir = rir / np.sqrt(np.sum(rir**2))
            # Cut the reverberation tail so the length stays the same
            reverbed = signal.convolve(speech, rir, mode='full')[:num_samples]
            item['wav'] = torch.from_numpy(reverbed).unsqueeze(0)
        yield item
|
|
|
|
|
|
def add_noise(data, noise_source, aug_prob):
    """ Mix utterances with a random noise clip at a random SNR
        Inplace operation.

    The SNR range depends on the prefix of the noise key: 'speech' uses
    [5, 30], 'music' uses [5, 15], everything else (including 'noise')
    uses [0, 15].

    Args:
        data: Iterable[{key, wav, ...}]
        noise_source: object whose random_one() returns (key, bytes),
            where bytes is the content of a wav file
        aug_prob: a sample is augmented when aug_prob is greater than a
            fresh random.random() draw

    Returns:
        Iterable of the same samples; augmented ones carry the noisy
        first channel as a (1, num_samples) tensor in `wav`
    """
    for item in data:
        assert 'wav' in item
        assert 'key' in item
        if aug_prob > random.random():
            # Only the first channel of the utterance is used
            speech = item['wav'].numpy()[0]
            num_samples = speech.shape[0]
            speech_db = 10 * np.log10(np.mean(speech**2) + 1e-4)

            noise_key, noise_bytes = noise_source.random_one()
            if noise_key.startswith('speech'):
                snr_low, snr_high = 5, 30
            elif noise_key.startswith('music'):
                snr_low, snr_high = 5, 15
            else:
                # 'noise' keys and any unrecognized prefix
                snr_low, snr_high = 0, 15

            _, noise = wavfile.read(io.BytesIO(noise_bytes))
            noise = noise.astype(np.float32)
            if noise.ndim > 1:
                # Multi-channel noise: keep the first channel
                noise = noise[:, 0]

            total = noise.shape[0]
            if total > num_samples:
                # Random crop of the same length as the utterance
                offset = random.randint(0, total - num_samples)
                noise = noise[offset:offset + num_samples]
            else:
                # Resize will repeat copy
                noise = np.resize(noise, (num_samples, ))

            target_snr = random.uniform(snr_low, snr_high)
            noise_db = 10 * np.log10(np.mean(noise**2) + 1e-4)
            # Gain that puts the noise target_snr dB below the speech
            gain = np.sqrt(10**((speech_db - noise_db - target_snr) / 10))
            mixed = speech + gain * noise
            item['wav'] = torch.from_numpy(mixed).unsqueeze(0)
        yield item
|