
Commit 11d6afc

✍️ update dataset and add notebooks
1 parent 4be4a6e commit 11d6afc

File tree

9 files changed (+966, -24 lines)

notebooks/conformer.ipynb

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
Notebook metadata (nbformat 4, nbformat_minor 2):

{
    "language_info": {
        "codemirror_mode": {"name": "ipython", "version": 3},
        "file_extension": ".py",
        "mimetype": "text/x-python",
        "name": "python",
        "nbconvert_exporter": "python",
        "pygments_lexer": "ipython3",
        "version": "3.8.8-final"
    },
    "orig_nbformat": 2,
    "kernelspec": {
        "name": "python388jvsc74a57bd045f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f",
        "display_name": "Python 3.8.8 64-bit ('tfo': conda)"
    },
    "metadata": {
        "interpreter": {
            "hash": "45f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
        }
    }
}
Cell 1: training configuration

config = {
    "speech_config": {
        "sample_rate": 16000,
        "frame_ms": 25,
        "stride_ms": 10,
        "num_feature_bins": 80,
        "feature_type": "log_mel_spectrogram",
        "preemphasis": 0.97,
        "normalize_signal": True,
        "normalize_feature": True,
        "normalize_per_feature": False,
    },
    "decoder_config": {
        "vocabulary": None,
        "target_vocab_size": 1000,
        "max_subword_length": 10,
        "blank_at_zero": True,
        "beam_width": 0,
        "norm_score": True,
        "corpus_files": None,
    },
    "model_config": {
        "name": "conformer",
        "encoder_subsampling": {
            "type": "conv2d",
            "filters": 144,
            "kernel_size": 3,
            "strides": 2,
        },
        "encoder_positional_encoding": "sinusoid_concat",
        "encoder_dmodel": 144,
        "encoder_num_blocks": 16,
        "encoder_head_size": 36,
        "encoder_num_heads": 4,
        "encoder_mha_type": "relmha",
        "encoder_kernel_size": 32,
        "encoder_fc_factor": 0.5,
        "encoder_dropout": 0.1,
        "prediction_embed_dim": 320,
        "prediction_embed_dropout": 0,
        "prediction_num_rnns": 1,
        "prediction_rnn_units": 320,
        "prediction_rnn_type": "lstm",
        "prediction_rnn_implementation": 2,
        "prediction_layer_norm": True,
        "prediction_projection_units": 0,
        "joint_dim": 320,
        "prejoint_linear": True,
        "joint_activation": "tanh",
        "joint_mode": "add",
    },
    "learning_config": {
        "train_dataset_config": {
            "use_tf": True,
            "augmentation_config": {
                "feature_augment": {
                    "time_masking": {
                        "num_masks": 10,
                        "mask_factor": 100,
                        "p_upperbound": 0.05,
                    },
                    "freq_masking": {"num_masks": 1, "mask_factor": 27},
                }
            },
            "data_paths": [
                "/mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv"
            ],
            "tfrecords_dir": None,
            "shuffle": True,
            "cache": True,
            "buffer_size": 100,
            "drop_remainder": True,
            "stage": "train",
        },
        "eval_dataset_config": {
            "use_tf": True,
            "data_paths": None,
            "tfrecords_dir": None,
            "shuffle": False,
            "cache": True,
            "buffer_size": 100,
            "drop_remainder": True,
            "stage": "eval",
        },
        "test_dataset_config": {
            "use_tf": True,
            "data_paths": None,
            "tfrecords_dir": None,
            "shuffle": False,
            "cache": True,
            "buffer_size": 100,
            "drop_remainder": True,
            "stage": "test",
        },
        "optimizer_config": {
            "warmup_steps": 40000,
            "beta_1": 0.9,
            "beta_2": 0.98,
            "epsilon": 1e-09,
        },
        "running_config": {
            "batch_size": 2,
            "num_epochs": 50,
            "checkpoint": {
                "filepath": "/mnt/e/Models/local/conformer/checkpoints/{epoch:02d}.h5",
                "save_best_only": True,
                "save_weights_only": True,
                "save_freq": "epoch",
            },
            "states_dir": "/mnt/e/Models/local/conformer/states",
            "tensorboard": {
                "log_dir": "/mnt/e/Models/local/conformer/tensorboard",
                "histogram_freq": 1,
                "write_graph": True,
                "write_images": True,
                "update_freq": "epoch",
                "profile_batch": 2,
            },
        },
    },
}
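A quick sanity check on the time resolution these settings imply: at 16 kHz with 25 ms windows and a 10 ms stride, the featurizer emits roughly 100 frames per second, and a conv2d subsampling block in the style of the Conformer paper (two stride-2 convolutions, an assumption about this implementation) cuts that by 4x. A minimal sketch of the arithmetic, using the standard framing formula rather than TFSpeechFeaturizer's exact code:

# Rough time-resolution check for the speech_config above (standard framing
# formula; TFSpeechFeaturizer's exact edge handling may differ slightly).
sample_rate = 16000
frame_len = sample_rate * 25 // 1000   # 400 samples per 25 ms window
frame_step = sample_rate * 10 // 1000  # 160 samples per 10 ms stride

def num_frames(duration_s: float) -> int:
    samples = int(duration_s * sample_rate)
    return 1 + (samples - frame_len) // frame_step

frames = num_frames(10.0)    # 998 frames for a 10 s utterance
print(frames, frames // 4)   # -> 998 249: ~250 encoder steps after 4x subsampling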
Cell 2: precomputed dataset metadata

metadata = {
    "train": {"max_input_length": 2974, "max_label_length": 194, "num_entries": 281241},
    "eval": {"max_input_length": 3516, "max_label_length": 186, "num_entries": 5567},
}
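These values are precomputed over the corpus so the pipeline can pad every batch to fixed shapes. A hedged sketch of how such numbers could be produced from a transcripts.tsv follows; the PATH/DURATION/TRANSCRIPT column layout and the 100-frames-per-second conversion are assumptions here, not something this commit specifies.

import csv

FRAMES_PER_SECOND = 100  # assumed: 1000 ms / stride_ms, matching speech_config

def compute_metadata(tsv_path: str) -> dict:
    # Hypothetical helper: scan a transcripts.tsv (assumed header line, then
    # PATH \t DURATION \t TRANSCRIPT rows) and track corpus-wide maxima.
    max_input_length = max_label_length = num_entries = 0
    with open(tsv_path, encoding="utf-8") as f:
        rows = csv.reader(f, delimiter="\t")
        next(rows)  # skip header
        for _path, duration, transcript in rows:
            max_input_length = max(max_input_length, int(float(duration) * FRAMES_PER_SECOND))
            max_label_length = max(max_label_length, len(transcript))  # char-level labels
            num_entries += 1
    return {
        "max_input_length": max_input_length,
        "max_label_length": max_label_length,
        "num_entries": num_entries,
    }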
Cell 3: build and train the Conformer

import os
import math
import argparse

from tensorflow_asr.utils import env_util

env_util.setup_environment()
import tensorflow as tf

# enable auto mixed precision and train on GPU 0
tf.keras.backend.clear_session()
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
strategy = env_util.setup_strategy([0])

from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets import asr_dataset
from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
from tensorflow_asr.models.transducer.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

# wrap the raw dict from cell 1 in a Config object and build the featurizers
config = Config(config)
speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)
text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)

# indefinite=True repeats the data so steps_per_epoch/validation_steps bound each epoch
train_dataset = asr_dataset.ASRSliceDataset(
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
    **vars(config.learning_config.train_dataset_config),
    indefinite=True
)
eval_dataset = asr_dataset.ASRSliceDataset(
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
    **vars(config.learning_config.eval_dataset_config),
    indefinite=True
)

# feed the precomputed max lengths from cell 2 to both datasets
train_dataset.load_metadata(metadata)
eval_dataset.load_metadata(metadata)
speech_featurizer.reset_length()
text_featurizer.reset_length()

# scale the per-replica batch size up to the global batch size
global_batch_size = config.learning_config.running_config.batch_size
global_batch_size *= strategy.num_replicas_in_sync

train_data_loader = train_dataset.create(global_batch_size)
eval_data_loader = eval_dataset.create(global_batch_size)

with strategy.scope():
    # build model
    conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
    conformer._build(speech_featurizer.shape)
    conformer.summary(line_length=100)

    # transformer warmup schedule, capped at 0.05 / sqrt(dmodel)
    optimizer = tf.keras.optimizers.Adam(
        TransformerSchedule(
            d_model=conformer.dmodel,
            warmup_steps=config.learning_config.optimizer_config.pop("warmup_steps", 10000),
            max_lr=(0.05 / math.sqrt(conformer.dmodel))
        ),
        **config.learning_config.optimizer_config
    )

    conformer.compile(
        optimizer=optimizer,
        experimental_steps_per_execution=10,
        global_batch_size=global_batch_size,
        blank=text_featurizer.blank
    )

# checkpointing, crash recovery, and TensorBoard logging
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
    tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
    tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
]

conformer.fit(
    train_data_loader,
    epochs=config.learning_config.running_config.num_epochs,
    validation_data=eval_data_loader,
    callbacks=callbacks,
    steps_per_epoch=train_dataset.total_steps,
    validation_steps=eval_dataset.total_steps
)
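The learning rate in cell 3 follows a transformer-style warmup schedule under Adam, capped at max_lr = 0.05 / sqrt(dmodel). Assuming TransformerSchedule implements the standard "Noam" formula (an assumption; see tensorflow_asr.optimizers.schedules for the authoritative version), it behaves like this sketch:

import math

def transformer_lr(step: int, d_model: int = 144, warmup_steps: int = 40000,
                   max_lr: float = 0.05 / math.sqrt(144)) -> float:
    # Standard transformer schedule: linear warmup, then 1/sqrt(step) decay,
    # clipped at max_lr. Assumed equivalent to TransformerSchedule, not copied from it.
    step = max(step, 1)
    lr = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return min(lr, max_lr)

print(transformer_lr(1), transformer_lr(40_000), transformer_lr(400_000))
# tiny at step 1, peaks near 4.2e-4 at the end of warmup, then decays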
Cell 4: empty code cell.

0 commit comments
