{
  "metadata": {
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.8-final"
    },
    "orig_nbformat": 2,
    "kernelspec": {
      "name": "python388jvsc74a57bd045f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f",
      "display_name": "Python 3.8.8 64-bit ('tfo': conda)"
    },
    "metadata": {
      "interpreter": {
        "hash": "45f983f364f7a4cc7101e6d6987a2125bf0c2b5c5c9855ff35103689f542d13f"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2,
+ "cells" : [
29
+ {
30
+ "cell_type" : " code" ,
31
+ "execution_count" : null ,
32
+ "metadata" : {},
33
+ "outputs" : [],
34
+ "source" : [
35
+ " config = {\n " ,
        "    \"speech_config\": {\n",
        "        \"sample_rate\": 16000,\n",
        "        \"frame_ms\": 25,\n",
        "        \"stride_ms\": 10,\n",
        "        \"num_feature_bins\": 80,\n",
        "        \"feature_type\": \"log_mel_spectrogram\",\n",
        "        \"preemphasis\": 0.97,\n",
        "        \"normalize_signal\": True,\n",
        "        \"normalize_feature\": True,\n",
        "        \"normalize_per_feature\": False,\n",
        "    },\n",
        "    \"decoder_config\": {\n",
        "        \"vocabulary\": None,\n",
        "        \"target_vocab_size\": 1000,\n",
        "        \"max_subword_length\": 10,\n",
        "        \"blank_at_zero\": True,\n",
        "        \"beam_width\": 0,\n",
        "        \"norm_score\": True,\n",
        "        \"corpus_files\": None,\n",
        "    },\n",
        "    \"model_config\": {\n",
        "        \"name\": \"conformer\",\n",
        "        \"encoder_subsampling\": {\n",
        "            \"type\": \"conv2d\",\n",
        "            \"filters\": 144,\n",
        "            \"kernel_size\": 3,\n",
        "            \"strides\": 2,\n",
        "        },\n",
        "        \"encoder_positional_encoding\": \"sinusoid_concat\",\n",
        "        \"encoder_dmodel\": 144,\n",
        "        \"encoder_num_blocks\": 16,\n",
        "        \"encoder_head_size\": 36,\n",
        "        \"encoder_num_heads\": 4,\n",
        "        \"encoder_mha_type\": \"relmha\",\n",
        "        \"encoder_kernel_size\": 32,\n",
        "        \"encoder_fc_factor\": 0.5,\n",
        "        \"encoder_dropout\": 0.1,\n",
        "        \"prediction_embed_dim\": 320,\n",
        "        \"prediction_embed_dropout\": 0,\n",
        "        \"prediction_num_rnns\": 1,\n",
        "        \"prediction_rnn_units\": 320,\n",
        "        \"prediction_rnn_type\": \"lstm\",\n",
        "        \"prediction_rnn_implementation\": 2,\n",
        "        \"prediction_layer_norm\": True,\n",
        "        \"prediction_projection_units\": 0,\n",
        "        \"joint_dim\": 320,\n",
        "        \"prejoint_linear\": True,\n",
        "        \"joint_activation\": \"tanh\",\n",
        "        \"joint_mode\": \"add\",\n",
        "    },\n",
        "    \"learning_config\": {\n",
        "        \"train_dataset_config\": {\n",
        "            \"use_tf\": True,\n",
        "            \"augmentation_config\": {\n",
        "                \"feature_augment\": {\n",
        "                    \"time_masking\": {\n",
        "                        \"num_masks\": 10,\n",
        "                        \"mask_factor\": 100,\n",
        "                        \"p_upperbound\": 0.05,\n",
        "                    },\n",
        "                    \"freq_masking\": {\"num_masks\": 1, \"mask_factor\": 27},\n",
        "                }\n",
        "            },\n",
        "            \"data_paths\": [\n",
        "                \"/mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv\"\n",
        "            ],\n",
        "            \"tfrecords_dir\": None,\n",
        "            \"shuffle\": True,\n",
        "            \"cache\": True,\n",
        "            \"buffer_size\": 100,\n",
        "            \"drop_remainder\": True,\n",
        "            \"stage\": \"train\",\n",
        "        },\n",
        "        \"eval_dataset_config\": {\n",
        "            \"use_tf\": True,\n",
        "            \"data_paths\": None,\n",
        "            \"tfrecords_dir\": None,\n",
        "            \"shuffle\": False,\n",
        "            \"cache\": True,\n",
        "            \"buffer_size\": 100,\n",
        "            \"drop_remainder\": True,\n",
        "            \"stage\": \"eval\",\n",
        "        },\n",
        "        \"test_dataset_config\": {\n",
        "            \"use_tf\": True,\n",
        "            \"data_paths\": None,\n",
        "            \"tfrecords_dir\": None,\n",
        "            \"shuffle\": False,\n",
        "            \"cache\": True,\n",
        "            \"buffer_size\": 100,\n",
        "            \"drop_remainder\": True,\n",
        "            \"stage\": \"test\",\n",
        "        },\n",
        "        \"optimizer_config\": {\n",
        "            \"warmup_steps\": 40000,\n",
        "            \"beta_1\": 0.9,\n",
        "            \"beta_2\": 0.98,\n",
        "            \"epsilon\": 1e-09,\n",
        "        },\n",
        "        \"running_config\": {\n",
        "            \"batch_size\": 2,\n",
        "            \"num_epochs\": 50,\n",
        "            \"checkpoint\": {\n",
        "                \"filepath\": \"/mnt/e/Models/local/conformer/checkpoints/{epoch:02d}.h5\",\n",
        "                \"save_best_only\": True,\n",
        "                \"save_weights_only\": True,\n",
        "                \"save_freq\": \"epoch\",\n",
        "            },\n",
        "            \"states_dir\": \"/mnt/e/Models/local/conformer/states\",\n",
        "            \"tensorboard\": {\n",
        "                \"log_dir\": \"/mnt/e/Models/local/conformer/tensorboard\",\n",
        "                \"histogram_freq\": 1,\n",
        "                \"write_graph\": True,\n",
        "                \"write_images\": True,\n",
        "                \"update_freq\": \"epoch\",\n",
        "                \"profile_batch\": 2,\n",
        "            },\n",
        "        },\n",
        "    },\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "metadata = {\n",
        "    \"train\": {\"max_input_length\": 2974, \"max_label_length\": 194, \"num_entries\": 281241},\n",
        "    \"eval\": {\"max_input_length\": 3516, \"max_label_length\": 186, \"num_entries\": 5567},\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import math\n",
        "import argparse\n",
        "from tensorflow_asr.utils import env_util\n",
        "\n",
        "env_util.setup_environment()\n",
        "import tensorflow as tf\n",
        "\n",
        "tf.keras.backend.clear_session()\n",
        "tf.config.optimizer.set_experimental_options({\"auto_mixed_precision\": True})\n",
        "strategy = env_util.setup_strategy([0])\n",
        "\n",
        "from tensorflow_asr.configs.config import Config\n",
        "from tensorflow_asr.datasets import asr_dataset\n",
        "from tensorflow_asr.featurizers import speech_featurizers, text_featurizers\n",
        "from tensorflow_asr.models.transducer.conformer import Conformer\n",
        "from tensorflow_asr.optimizers.schedules import TransformerSchedule\n",
        "\n",
        "config = Config(config)\n",
        "speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)\n",
        "\n",
        "text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)\n",
        "\n",
        "train_dataset = asr_dataset.ASRSliceDataset(\n",
        "    speech_featurizer=speech_featurizer,\n",
        "    text_featurizer=text_featurizer,\n",
        "    **vars(config.learning_config.train_dataset_config),\n",
        "    indefinite=True\n",
        ")\n",
        "eval_dataset = asr_dataset.ASRSliceDataset(\n",
        "    speech_featurizer=speech_featurizer,\n",
        "    text_featurizer=text_featurizer,\n",
        "    **vars(config.learning_config.eval_dataset_config),\n",
        "    indefinite=True\n",
        ")\n",
        "\n",
        "train_dataset.load_metadata(metadata)\n",
        "eval_dataset.load_metadata(metadata)\n",
        "speech_featurizer.reset_length()\n",
        "text_featurizer.reset_length()\n",
        "\n",
        "global_batch_size = config.learning_config.running_config.batch_size\n",
        "global_batch_size *= strategy.num_replicas_in_sync\n",
        "\n",
        "train_data_loader = train_dataset.create(global_batch_size)\n",
        "eval_data_loader = eval_dataset.create(global_batch_size)\n",
        "\n",
        "with strategy.scope():\n",
        "    # Build the model and create its variables from the featurizer's input shape\n",
        "    conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)\n",
        "    conformer._build(speech_featurizer.shape)\n",
        "    conformer.summary(line_length=100)\n",
        "\n",
        "    optimizer = tf.keras.optimizers.Adam(\n",
        "        TransformerSchedule(\n",
        "            d_model=conformer.dmodel,\n",
        "            warmup_steps=config.learning_config.optimizer_config.pop(\"warmup_steps\", 10000),\n",
        "            max_lr=(0.05 / math.sqrt(conformer.dmodel))\n",
        "        ),\n",
        "        **config.learning_config.optimizer_config\n",
        "    )\n",
        "\n",
        "    conformer.compile(\n",
        "        optimizer=optimizer,\n",
        "        experimental_steps_per_execution=10,\n",
        "        global_batch_size=global_batch_size,\n",
        "        blank=text_featurizer.blank\n",
        "    )\n",
        "\n",
        "    callbacks = [\n",
        "        tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),\n",
        "        tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),\n",
        "        tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)\n",
        "    ]\n",
        "\n",
        "    conformer.fit(\n",
        "        train_data_loader,\n",
        "        epochs=config.learning_config.running_config.num_epochs,\n",
        "        validation_data=eval_data_loader,\n",
        "        callbacks=callbacks,\n",
        "        steps_per_epoch=train_dataset.total_steps,\n",
        "        validation_steps=eval_dataset.total_steps\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ]
}