Align Stream FSMN and Non-Stream FSMN, both in feature extraction and model forward.

This commit is contained in:
dujing 2023-06-27 16:28:39 +08:00
parent 6f8207267e
commit ceb59e87a6
8 changed files with 139 additions and 46 deletions

View File

@ -203,7 +203,9 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--test_data data/test/data.list \ --test_data data/test/data.list \
--window_shift $window_shift \ --window_shift $window_shift \
--step 0.001 \ --step 0.001 \
--score_file $result_dir/score.txt --score_file $result_dir/score.txt \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
fi fi

View File

@ -154,17 +154,21 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--test_data data/test/data.list \ --test_data data/test/data.list \
--window_shift $window_shift \ --window_shift $window_shift \
--step 0.001 \ --step 0.001 \
--score_file $result_dir/score.txt --score_file $result_dir/score.txt \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
fi fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g') # NOTE: FSMN now is not support export to jit, beacuse of nn.Sequential with tuple input
# This issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
# jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
# python wekws/bin/export_jit.py \
# --config $dir/config.yaml \
# --checkpoint $score_checkpoint \
# --jit_model $dir/$jit_model
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g') onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
python wekws/bin/export_jit.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--jit_model $dir/$jit_model
python wekws/bin/export_onnx.py \ python wekws/bin/export_onnx.py \
--config $dir/config.yaml \ --config $dir/config.yaml \
--checkpoint $score_checkpoint \ --checkpoint $score_checkpoint \

View File

@ -23,6 +23,7 @@ import os
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pypinyin # for Chinese Character import pypinyin # for Chinese Character
from tools.make_list import query_token_set, read_lexicon, read_token
def split_mixed_label(input_str): def split_mixed_label(input_str):
tokens = [] tokens = []
@ -43,7 +44,7 @@ def space_mixed_label(input_str):
space_str = ''.join(f'{sub} ' for sub in splits) space_str = ''.join(f'{sub} ' for sub in splits)
return space_str.strip() return space_str.strip()
def load_label_and_score(keywords_list, label_file, score_file): def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
score_table = {} score_table = {}
with open(score_file, 'r', encoding='utf8') as fin: with open(score_file, 'r', encoding='utf8') as fin:
# read score file and store in table # read score file and store in table
@ -52,10 +53,11 @@ def load_label_and_score(keywords_list, label_file, score_file):
key = arr[0] key = arr[0]
is_detected = arr[1] is_detected = arr[1]
if is_detected == 'detected': if is_detected == 'detected':
keyword=true_keywords[arr[2]]
if key not in score_table: if key not in score_table:
score_table.update({ score_table.update({
key: { key: {
'kw': space_mixed_label(arr[2]), 'kw': space_mixed_label(keyword),
'confi': float(arr[3]) 'confi': float(arr[3])
} }
}) })
@ -72,6 +74,7 @@ def load_label_and_score(keywords_list, label_file, score_file):
# build empty structure for keyword-filler infos # build empty structure for keyword-filler infos
keyword_filler_table = {} keyword_filler_table = {}
for keyword in keywords_list: for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword) keyword = space_mixed_label(keyword)
keyword_filler_table[keyword] = {} keyword_filler_table[keyword] = {}
keyword_filler_table[keyword]['keyword_table'] = {} keyword_filler_table[keyword]['keyword_table'] = {}
@ -93,6 +96,7 @@ def load_label_and_score(keywords_list, label_file, score_file):
assert key in score_table assert key in score_table
for keyword in keywords_list: for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword) keyword = space_mixed_label(keyword)
keyword_regstr_lrblk = ' ' + keyword + ' ' keyword_regstr_lrblk = ' ' + keyword + ' '
if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1: if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
@ -154,6 +158,8 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description='compute det curve') parser = argparse.ArgumentParser(description='compute det curve')
parser.add_argument('--test_data', required=True, help='label file') parser.add_argument('--test_data', required=True, help='label file')
parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)') parser.add_argument('--keywords', type=str, default=None, help='keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None, help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None, help='the path of lexicon.txt')
parser.add_argument('--score_file', required=True, help='score file') parser.add_argument('--score_file', required=True, help='score file')
parser.add_argument('--step', type=float, default=0.01, parser.add_argument('--step', type=float, default=0.01,
help='threshold step') help='threshold step')
@ -183,9 +189,18 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
window_shift = args.window_shift window_shift = args.window_shift
keywords_list = args.keywords.strip().split(',') keywords_list = args.keywords.strip().split(',')
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file)
token_table = read_token(args.token_file)
lexicon_table = read_lexicon(args.lexicon_file)
true_keywords = {}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
true_keywords[keyword] = ''.join(strs)
keyword_filler_table = load_label_and_score(keywords_list, args.test_data, args.score_file, true_keywords)
for keyword in keywords_list: for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword) keyword = space_mixed_label(keyword)
keyword_dur = keyword_filler_table[keyword]['keyword_duration'] keyword_dur = keyword_filler_table[keyword]['keyword_duration']
keyword_num = len(keyword_filler_table[keyword]['keyword_table']) keyword_num = len(keyword_filler_table[keyword]['keyword_table'])

View File

@ -35,7 +35,10 @@ from tools.make_list import query_token_set, read_lexicon, read_token
def get_args(): def get_args():
parser = argparse.ArgumentParser(description='detect keywords online.') parser = argparse.ArgumentParser(description='detect keywords online.')
parser.add_argument('--config', required=True, help='config file') parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--wav_path', required=True, help='test wave path.') parser.add_argument('--wav_path', required=False, default=None, help='test wave path.')
parser.add_argument('--wav_scp', required=False, default=None, help='test wave scp.')
parser.add_argument('--result_file', required=False, default=None, help='test result.')
parser.add_argument('--gpu', parser.add_argument('--gpu',
type=int, type=int,
default=-1, default=-1,
@ -209,7 +212,7 @@ class KeyWordSpotter(torch.nn.Module):
self.left_context = dataset_conf['context_expansion_conf']['left'] self.left_context = dataset_conf['context_expansion_conf']['left']
self.right_context = dataset_conf['context_expansion_conf']['right'] self.right_context = dataset_conf['context_expansion_conf']['right']
self.feature_remained = None self.feature_remained = None
self.feature_context_offset = 0 # after downsample, offset exist. self.feats_ctx_offset = 0 # after downsample, offset exist.
# model related # model related
@ -289,6 +292,9 @@ class KeyWordSpotter(torch.nn.Module):
wave = np.array(data) wave = np.array(data)
wave = np.append(self.wave_remained, wave) wave = np.append(self.wave_remained, wave)
if wave.size < (self.frame_length * self.sample_rate / 1000) * self.right_context :
self.wave_remained = wave
return None
wave_tensor = torch.from_numpy(wave).float().to(self.device) wave_tensor = torch.from_numpy(wave).float().to(self.device)
wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension wave_tensor = wave_tensor.unsqueeze(0) # add a channel dimension
feats = kaldi.fbank(wave_tensor, feats = kaldi.fbank(wave_tensor,
@ -312,7 +318,7 @@ class KeyWordSpotter(torch.nn.Module):
else: else:
feats_pad = torch.cat((self.feature_remained, feats)) feats_pad = torch.cat((self.feature_remained, feats))
ctx_frm = feats.shape[0] - self.right_context ctx_frm = feats_pad.shape[0] - (self.right_context+self.right_context)
ctx_win = (self.left_context + self.right_context + 1) ctx_win = (self.left_context + self.right_context + 1)
ctx_dim = feats.shape[1] * ctx_win ctx_dim = feats.shape[1] * ctx_win
feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32) feats_ctx = torch.zeros(ctx_frm, ctx_dim, dtype=torch.float32)
@ -320,12 +326,13 @@ class KeyWordSpotter(torch.nn.Module):
feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0) feats_ctx[i] = torch.cat(tuple(feats_pad[i: i + ctx_win])).unsqueeze(0)
# update feature remained, and feats # update feature remained, and feats
self.feature_remained = feats[-self.left_context:] self.feature_remained = feats[-(self.left_context+self.right_context):]
feats = feats_ctx.to(self.device) feats = feats_ctx.to(self.device)
if self.downsampling > 1: if self.downsampling > 1:
feats = feats[self.feature_context_offset::self.downsampling, :] last_remainder = 0 if self.feats_ctx_offset==0 else self.downsampling-self.feats_ctx_offset
complement = feats.size(1) % self.downsampling remainder = (feats.size(0)+last_remainder) % self.downsampling
self.feature_context_offset = complement if complement == 0 else self.downsampling-complement feats = feats[self.feats_ctx_offset::self.downsampling, :]
self.feats_ctx_offset = remainder if remainder == 0 else self.downsampling-remainder
return feats return feats
def decode_keywords(self, t, probs): def decode_keywords(self, t, probs):
@ -396,11 +403,14 @@ class KeyWordSpotter(torch.nn.Module):
def forward(self, wave_chunk): def forward(self, wave_chunk):
feature = self.accept_wave(wave_chunk) feature = self.accept_wave(wave_chunk)
if feature is None or feature.size(0) < 1:
return self.result
feature = feature.unsqueeze(0) # add a batch dimension feature = feature.unsqueeze(0) # add a batch dimension
logits, self.in_cache = self.model(feature, self.in_cache) logits, self.in_cache = self.model(feature, self.in_cache)
probs = logits.softmax(2) # (batch_size, maxlen, vocab_size) probs = logits.softmax(2) # (batch_size, maxlen, vocab_size)
probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search probs = probs[0].cpu() # remove batch dimension, move to cpu for ctc_prefix_beam_search
for (t, prob) in enumerate(probs): for (t, prob) in enumerate(probs):
t *= self.downsampling
self.decode_keywords(t, prob) self.decode_keywords(t, prob)
self.execute_detection(t) self.execute_detection(t)
@ -409,7 +419,7 @@ class KeyWordSpotter(torch.nn.Module):
# since a chunk include about 30 frames, once activated, we can jump the latter frames. # since a chunk include about 30 frames, once activated, we can jump the latter frames.
# TODO: there should give another method to update result, avoiding self.result being cleared. # TODO: there should give another method to update result, avoiding self.result being cleared.
break break
self.total_frames += len(probs) # update frame offset self.total_frames += len(probs) * self.downsampling # update frame offset
return self.result return self.result
def reset(self): def reset(self):
@ -417,6 +427,15 @@ class KeyWordSpotter(torch.nn.Module):
self.activated = False self.activated = False
self.hit_score = 1.0 self.hit_score = 1.0
def reset_all(self):
self.reset()
self.wave_remained = np.array([])
self.feature_remained = None
self.feats_ctx_offset = 0 # after downsample, offset exist.
self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
self.total_frames = 0 # frame offset, for absolute time
self.result = {}
def demo(): def demo():
args = get_args() args = get_args()
logging.basicConfig(level=logging.DEBUG, logging.basicConfig(level=logging.DEBUG,
@ -435,22 +454,60 @@ def demo():
args.gpu, args.gpu,
args.jit_model) args.jit_model)
# actually this could be done in __init__ method, we pull it outside for changing keywords. # actually this could be done in __init__ method, we pull it outside for changing keywords more freely.
kws.set_keywords(args.keywords) kws.set_keywords(args.keywords)
# Caution: input WAV should be standard 16k, 16 bits, 1 channel if args.wav_path:
# In demo we read wave in non-streaming fashion. # Caution: input WAV should be standard 16k, 16 bits, 1 channel
with wave.open(args.wav_path, 'rb') as fin: # In demo we read wave in non-streaming fashion.
assert fin.getnchannels() == 1 with wave.open(args.wav_path, 'rb') as fin:
wav = fin.readframes(fin.getnframes()) assert fin.getnchannels() == 1
wav = fin.readframes(fin.getnframes())
# We inference every 0.3 seconds, in streaming fashion. # We inference every 0.3 seconds, in streaming fashion.
interval = int(0.3 * 16000) * 2 interval = int(0.3 * 16000) * 2
for i in range(0, len(wav), interval): for i in range(0, len(wav), interval):
chunk_wav = wav[i: min(i + interval, len(wav))] chunk_wav = wav[i: min(i + interval, len(wav))]
result = kws.forward(chunk_wav) result = kws.forward(chunk_wav)
print(result) print(result)
fout = None
if args.result_file:
fout = open(args.result_file, 'w', encoding='utf-8')
if args.wav_scp:
with open(args.wav_scp, 'r') as fscp:
for line in fscp:
line = line.strip().split()
assert len(line) == 2, f"The scp should be in kaldi format: \"utt_name wav_path\", but got {line}"
utt_name, wav_path = line[0], line[1]
with wave.open(wav_path, 'rb') as fin:
assert fin.getnchannels() == 1
wav = fin.readframes(fin.getnframes())
kws.reset_all()
activated = False
# We inference every 0.3 seconds, in streaming fashion.
interval = int(0.3 * 16000) * 2
for i in range(0, len(wav), interval):
chunk_wav = wav[i: min(i + interval, len(wav))]
result = kws.forward(chunk_wav)
if 'state' in result and result['state'] == 1:
activated = True
if fout:
hit_keyword = result['keyword']
hit_score = result['score']
fout.write('{} detected {} {:.3f}\n'.format(utt_name, hit_keyword, hit_score))
if not activated:
if fout:
fout.write('{} rejected\n'.format(utt_name))
if fout:
fout.close()
if __name__ == '__main__': if __name__ == '__main__':
demo() demo()

View File

@ -44,11 +44,11 @@ def get_args():
help='gpu id for this rank, -1 for cpu') help='gpu id for this rank, -1 for cpu')
parser.add_argument('--checkpoint', required=True, help='checkpoint model') parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--batch_size', parser.add_argument('--batch_size',
default=16, default=1,
type=int, type=int,
help='batch size for inference') help='batch size for inference')
parser.add_argument('--num_workers', parser.add_argument('--num_workers',
default=0, default=1,
type=int, type=int,
help='num of subprocess workers for reading') help='num of subprocess workers for reading')
parser.add_argument('--pin_memory', parser.add_argument('--pin_memory',
@ -131,6 +131,8 @@ def main():
test_conf['feature_extraction_conf']['dither'] = 0.0 test_conf['feature_extraction_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_size'] = args.batch_size test_conf['batch_conf']['batch_size'] = args.batch_size
downsampling_factor = test_conf.get('frame_skip', 1)
test_dataset = Dataset(args.test_data, test_conf) test_dataset = Dataset(args.test_data, test_conf)
test_data_loader = DataLoader(test_dataset, test_data_loader = DataLoader(test_dataset,
batch_size=None, batch_size=None,
@ -205,6 +207,7 @@ def main():
# 2. CTC beam search step by step # 2. CTC beam search step by step
for t in range(0, maxlen): for t in range(0, maxlen):
probs = ctc_probs[t] # (vocab_size,) probs = ctc_probs[t] # (vocab_size,)
t *= downsampling_factor # the real time
# key: prefix, value (pb, pnb), default value(-inf, -inf) # key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (0.0, 0.0, [])) next_hyps = defaultdict(lambda: (0.0, 0.0, []))

View File

@ -157,8 +157,12 @@ def main():
# Try to export the model by script, if fails, we should refine # Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements # the code to satisfy the script export requirements
if rank == 0: if rank == 0:
script_model = torch.jit.script(model) pass
script_model.save(os.path.join(args.model_dir, 'init.zip')) # TODO: for now streaming FSMN do not support export to JITScript,
# TODO: because there is nn.Sequential with Tuple input in current FSMN modules.
# the issue is in https://stackoverflow.com/questions/75714299/pytorch-jit-script-error-when-sequential-container-takes-a-tuple-input/76553450#76553450
# script_model = torch.jit.script(model)
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor() executor = Executor()
# If specify checkpoint, load some info from checkpoint # If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None: if args.checkpoint is not None:

View File

@ -377,6 +377,8 @@ def add_reverb(data, reverb_source, aug_prob):
rir_io = io.BytesIO(rir_data) rir_io = io.BytesIO(rir_data)
_, rir_audio = wavfile.read(rir_io) _, rir_audio = wavfile.read(rir_io)
rir_audio = rir_audio.astype(np.float32) rir_audio = rir_audio.astype(np.float32)
if len(rir_audio.shape) > 1:
rir_audio = rir_audio[:, 0]
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2)) rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
out_audio = signal.convolve(audio, rir_audio, out_audio = signal.convolve(audio, rir_audio,
mode='full')[:audio_len] mode='full')[:audio_len]
@ -405,6 +407,8 @@ def add_noise(data, noise_source, aug_prob):
snr_range = [0, 15] snr_range = [0, 15]
_, noise_audio = wavfile.read(io.BytesIO(noise_data)) _, noise_audio = wavfile.read(io.BytesIO(noise_data))
noise_audio = noise_audio.astype(np.float32) noise_audio = noise_audio.astype(np.float32)
if len(noise_audio.shape) > 1:
noise_audio = noise_audio[:, 0]
if noise_audio.shape[0] > audio_len: if noise_audio.shape[0] > audio_len:
start = random.randint(0, noise_audio.shape[0] - audio_len) start = random.randint(0, noise_audio.shape[0] - audio_len)
noise_audio = noise_audio[start:start + audio_len] noise_audio = noise_audio[start:start + audio_len]

View File

@ -39,11 +39,12 @@ class LinearTransform(nn.Module):
self.quant = torch.quantization.QuantStub() self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub() self.dequant = torch.quantization.DeQuantStub()
def forward(self, input): def forward(self,
input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple): if isinstance(input, tuple):
input, in_cache = input input, in_cache = input
else: else:
in_cache = None in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
output = self.quant(input) output = self.quant(input)
output = self.linear(output) output = self.linear(output)
output = self.dequant(output) output = self.dequant(output)
@ -102,11 +103,12 @@ class AffineTransform(nn.Module):
self.quant = torch.quantization.QuantStub() self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub() self.dequant = torch.quantization.DeQuantStub()
def forward(self, input): def forward(self,
input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple): if isinstance(input, tuple):
input, in_cache = input input, in_cache = input
else: else:
in_cache = None in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
output = self.quant(input) output = self.quant(input)
output = self.linear(output) output = self.linear(output)
output = self.dequant(output) output = self.dequant(output)
@ -213,15 +215,16 @@ class FSMNBlock(nn.Module):
self.quant = torch.quantization.QuantStub() self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub() self.dequant = torch.quantization.DeQuantStub()
def forward(self, input): def forward(self,
input: Tuple[torch.Tensor, torch.Tensor] ):
if isinstance(input, tuple): if isinstance(input, tuple):
input, in_cache = input input, in_cache = input
else : else :
in_cache = None in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
x = torch.unsqueeze(input, 1) x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) x_per = x.permute(0, 3, 2, 1)
if in_cache is None or len(in_cache) == 0 or in_cache[0] is None: if in_cache is None or len(in_cache) == 0 :
x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0]) x_pad = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride + self.rorder * self.rstride, 0])
else: else:
in_cache = in_cache.to(x_per.device) in_cache = in_cache.to(x_per.device)
@ -339,11 +342,12 @@ class RectifiedLinear(nn.Module):
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1) self.dropout = nn.Dropout(0.1)
def forward(self, input): def forward(self,
input: Tuple[torch.Tensor, torch.Tensor]):
if isinstance(input, tuple): if isinstance(input, tuple):
input, in_cache = input input, in_cache = input
else : else :
in_cache = None in_cache = torch.zeros(0, 0, 0, 0, dtype=torch.float)
out = self.relu(input) out = self.relu(input)
# out = self.dropout(out) # out = self.dropout(out)
return (out, in_cache) return (out, in_cache)
@ -456,7 +460,7 @@ class FSMN(nn.Module):
def forward( def forward(
self, self,
input: torch.Tensor, input: torch.Tensor,
in_cache #: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
@ -464,8 +468,8 @@ class FSMN(nn.Module):
in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size in_cache(torch.Tensor): (B, D, C), C is the accumulated cache size
""" """
if in_cache is None or len(in_cache) == 0 or in_cache[0] == None: if in_cache is None or len(in_cache) == 0 :
in_cache = [None for _ in range(len(self.fsmn))] in_cache = [torch.zeros(0, 0, 0, 0, dtype=torch.float) for _ in range(len(self.fsmn))]
input = (input, in_cache) input = (input, in_cache)
x1 = self.in_linear1(input) x1 = self.in_linear1(input)
x2 = self.in_linear2(x1) x2 = self.in_linear2(x1)