Skip to content

Commit caee2e7

Browse files
authored
Merge pull request #251 from TensorSpeech/refactor
Refactor: add support tf2.8 + update imports + update examples + add helpers
2 parents c426e8f + 68ead53 commit caee2e7

40 files changed

+1225
-1491
lines changed

examples/conformer/README.md

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,11 @@ Go to [config.yml](./config.yml)
1010

1111
## Usage
1212

13-
Training, see `python examples/conformer/train_*.py --help`
13+
Training, see `python examples/conformer/train.py --help`
1414

15-
Testing, see `python examples/conformer/test_*.py --help`
15+
Testing, see `python examples/conformer/test.py --help`
1616

17-
TFLite Conversion, see `python examples/conformer/tflite_*.py --help`
17+
TFLite Conversion, see `python examples/conformer/inference/gen_tflite_model.py --help`
1818

1919
## Conformer Subwords - Results on LibriSpeech
2020

examples/conformer/config.yml

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ decoder_config:
3131
beam_width: 0
3232
norm_score: True
3333
corpus_files:
34-
- /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
34+
- H:/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
3535

3636
model_config:
3737
name: conformer
@@ -75,8 +75,8 @@ learning_config:
7575
num_masks: 1
7676
mask_factor: 27
7777
data_paths:
78-
- /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
79-
tfrecords_dir: /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
78+
- H:/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
79+
tfrecords_dir: null
8080
shuffle: True
8181
cache: True
8282
buffer_size: 100
@@ -85,8 +85,9 @@ learning_config:
8585

8686
eval_dataset_config:
8787
use_tf: True
88-
data_paths: null
89-
tfrecords_dir: /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
88+
data_paths:
89+
- H:/MLDL/Datasets/ASR/Raw/LibriSpeech/dev-clean/transcripts.tsv
90+
tfrecords_dir: null
9091
shuffle: False
9192
cache: True
9293
buffer_size: 100
@@ -95,7 +96,8 @@ learning_config:
9596

9697
test_dataset_config:
9798
use_tf: True
98-
data_paths: null
99+
data_paths:
100+
- H:/MLDL/Datasets/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
99101
tfrecords_dir: null
100102
shuffle: False
101103
cache: True
@@ -113,13 +115,13 @@ learning_config:
113115
batch_size: 2
114116
num_epochs: 50
115117
checkpoint:
116-
filepath: /mnt/Miscellanea/Models/local/conformer/checkpoints/{epoch:02d}.h5
118+
filepath: D:/Models/local/conformer/checkpoints/{epoch:02d}.h5
117119
save_best_only: False
118120
save_weights_only: True
119121
save_freq: epoch
120-
states_dir: /mnt/Miscellanea/Models/local/conformer/states
122+
states_dir: D:/Models/local/conformer/states
121123
tensorboard:
122-
log_dir: /mnt/Miscellanea/Models/local/conformer/tensorboard
124+
log_dir: D:/Models/local/conformer/tensorboard
123125
histogram_freq: 1
124126
write_graph: True
125127
write_images: True

examples/conformer/inference/gen_saved_model.py

Lines changed: 51 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import argparse
1615
import os
16+
import fire
1717

1818
from tensorflow_asr.utils import env_util
1919

@@ -22,71 +22,58 @@
2222

2323
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
2424

25-
tf.keras.backend.clear_session()
26-
27-
parser = argparse.ArgumentParser(prog="Conformer Testing")
28-
29-
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
30-
31-
parser.add_argument("--h5", type=str, default=None, help="Path to saved h5 weights")
32-
33-
parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
34-
35-
parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
36-
37-
parser.add_argument("--output_dir", type=str, default=None, help="Output directory for saved model")
38-
39-
args = parser.parse_args()
40-
41-
assert args.h5
42-
assert args.output_dir
4325

4426
from tensorflow_asr.configs.config import Config
45-
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
46-
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SentencePieceFeaturizer, SubwordFeaturizer
27+
from tensorflow_asr.helpers import featurizer_helpers
4728
from tensorflow_asr.models.transducer.conformer import Conformer
4829

49-
config = Config(args.config)
50-
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
51-
52-
if args.sentence_piece:
53-
logger.info("Use SentencePiece ...")
54-
text_featurizer = SentencePieceFeaturizer(config.decoder_config)
55-
elif args.subwords:
56-
logger.info("Use subwords ...")
57-
text_featurizer = SubwordFeaturizer(config.decoder_config)
58-
else:
59-
logger.info("Use characters ...")
60-
text_featurizer = CharFeaturizer(config.decoder_config)
61-
62-
tf.random.set_seed(0)
63-
64-
# build model
65-
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
66-
conformer.make(speech_featurizer.shape)
67-
conformer.load_weights(args.h5, by_name=True)
68-
conformer.summary(line_length=100)
69-
conformer.add_featurizers(speech_featurizer, text_featurizer)
70-
71-
72-
class ConformerModule(tf.Module):
73-
def __init__(self, model: Conformer, name=None):
74-
super().__init__(name=name)
75-
self.model = model
76-
self.num_rnns = config.model_config["prediction_num_rnns"]
77-
self.rnn_units = config.model_config["prediction_rnn_units"]
78-
self.rnn_nstates = 2 if config.model_config["prediction_rnn_type"] == "lstm" else 1
79-
80-
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
81-
def pred(self, signal):
82-
predicted = tf.constant(0, dtype=tf.int32)
83-
states = tf.zeros([self.num_rnns, self.rnn_nstates, 1, self.rnn_units], dtype=tf.float32)
84-
features = self.model.speech_featurizer.tf_extract(signal)
85-
encoded = self.model.encoder_inference(features)
86-
hypothesis = self.model._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=False)
87-
transcript = self.model.text_featurizer.indices2upoints(hypothesis.prediction)
88-
return transcript
89-
9030

91-
module = ConformerModule(model=conformer)
92-
tf.saved_model.save(module, export_dir=args.output_dir, signatures=module.pred.get_concrete_function())
31+
def main(
32+
config: str = DEFAULT_YAML,
33+
h5: str = None,
34+
sentence_piece: bool = False,
35+
subwords: bool = False,
36+
output_dir: str = None,
37+
):
38+
assert h5 and output_dir
39+
config = Config(config)
40+
tf.random.set_seed(0)
41+
tf.keras.backend.clear_session()
42+
43+
speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
44+
config=config,
45+
subwords=subwords,
46+
sentence_piece=sentence_piece,
47+
)
48+
49+
# build model
50+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
51+
conformer.make(speech_featurizer.shape)
52+
conformer.load_weights(h5, by_name=True)
53+
conformer.summary(line_length=100)
54+
conformer.add_featurizers(speech_featurizer, text_featurizer)
55+
56+
class ConformerModule(tf.Module):
57+
def __init__(self, model: Conformer, name=None):
58+
super().__init__(name=name)
59+
self.model = model
60+
self.num_rnns = config.model_config["prediction_num_rnns"]
61+
self.rnn_units = config.model_config["prediction_rnn_units"]
62+
self.rnn_nstates = 2 if config.model_config["prediction_rnn_type"] == "lstm" else 1
63+
64+
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
65+
def pred(self, signal):
66+
predicted = tf.constant(0, dtype=tf.int32)
67+
states = tf.zeros([self.num_rnns, self.rnn_nstates, 1, self.rnn_units], dtype=tf.float32)
68+
features = self.model.speech_featurizer.tf_extract(signal)
69+
encoded = self.model.encoder_inference(features)
70+
hypothesis = self.model._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=False)
71+
transcript = self.model.text_featurizer.indices2upoints(hypothesis.prediction)
72+
return transcript
73+
74+
module = ConformerModule(model=conformer)
75+
tf.saved_model.save(module, export_dir=output_dir, signatures=module.pred.get_concrete_function())
76+
77+
78+
if __name__ == "__main__":
79+
fire.Fire(main)

examples/conformer/inference/gen_tflite_model.py

Lines changed: 27 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -13,58 +13,45 @@
1313
# limitations under the License.
1414

1515
import os
16-
import argparse
17-
from tensorflow_asr.utils import env_util, file_util
16+
import fire
17+
from tensorflow_asr.utils import env_util
1818

1919
logger = env_util.setup_environment()
2020
import tensorflow as tf
2121

2222
from tensorflow_asr.configs.config import Config
23-
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
24-
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, CharFeaturizer
23+
from tensorflow_asr.helpers import exec_helpers, featurizer_helpers
2524
from tensorflow_asr.models.transducer.conformer import Conformer
2625

2726
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
2827

29-
tf.keras.backend.clear_session()
30-
tf.compat.v1.enable_control_flow_v2()
3128

32-
parser = argparse.ArgumentParser(prog="Conformer TFLite")
29+
def main(
30+
config: str = DEFAULT_YAML,
31+
h5: str = None,
32+
subwords: bool = False,
33+
sentence_piece: bool = False,
34+
output: str = None,
35+
):
36+
assert h5 and output
37+
tf.keras.backend.clear_session()
38+
tf.compat.v1.enable_control_flow_v2()
3339

34-
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
40+
config = Config(config)
41+
speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
42+
config=config,
43+
subwords=subwords,
44+
sentence_piece=sentence_piece,
45+
)
3546

36-
parser.add_argument("--h5", type=str, default=None, help="Path to saved model")
47+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
48+
conformer.make(speech_featurizer.shape)
49+
conformer.load_weights(h5, by_name=True)
50+
conformer.summary(line_length=100)
51+
conformer.add_featurizers(speech_featurizer, text_featurizer)
3752

38-
parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
53+
exec_helpers.convert_tflite(model=conformer, output=output)
3954

40-
parser.add_argument("output", type=str, default=None, help="TFLite file path to be exported")
4155

42-
args = parser.parse_args()
43-
44-
assert args.h5 and args.output
45-
46-
config = Config(args.config)
47-
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
48-
49-
if args.subwords:
50-
text_featurizer = SubwordFeaturizer(config.decoder_config)
51-
else:
52-
text_featurizer = CharFeaturizer(config.decoder_config)
53-
54-
# build model
55-
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
56-
conformer.make(speech_featurizer.shape)
57-
conformer.load_weights(args.h5, by_name=True)
58-
conformer.summary(line_length=100)
59-
conformer.add_featurizers(speech_featurizer, text_featurizer)
60-
61-
concrete_func = conformer.make_tflite_function().get_concrete_function()
62-
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
63-
converter.experimental_new_converter = True
64-
converter.optimizations = [tf.lite.Optimize.DEFAULT]
65-
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
66-
tflite_model = converter.convert()
67-
68-
args.output = file_util.preprocess_paths(args.output)
69-
with open(args.output, "wb") as tflite_out:
70-
tflite_out.write(tflite_model)
56+
if __name__ == "__main__":
57+
fire.Fire(main)

examples/conformer/inference/run_saved_model.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import argparse
1615
import os
16+
import fire
1717

1818
from tensorflow_asr.utils import env_util
1919

@@ -22,21 +22,23 @@
2222

2323
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
2424

25-
tf.keras.backend.clear_session()
2625

27-
parser = argparse.ArgumentParser()
26+
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
2827

29-
parser.add_argument("--saved_model", type=str, default=None, help="The file path of saved model")
3028

31-
parser.add_argument("filename", type=str, default=None, help="Audio file path")
29+
def main(
30+
saved_model: str = None,
31+
filename: str = None,
32+
):
33+
tf.keras.backend.clear_session()
3234

33-
args = parser.parse_args()
35+
module = tf.saved_model.load(export_dir=saved_model)
3436

35-
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
37+
signal = read_raw_audio(filename)
38+
transcript = module.pred(signal)
3639

37-
module = tf.saved_model.load(export_dir=args.saved_model)
40+
print("Transcript: ", "".join([chr(u) for u in transcript]))
3841

39-
signal = read_raw_audio(args.filename)
40-
transcript = module.pred(signal)
4142

42-
print("Transcript: ", "".join([chr(u) for u in transcript]))
43+
if __name__ == "__main__":
44+
fire.Fire(main)

examples/conformer/inference/run_tflite_model.py

Lines changed: 23 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -12,39 +12,36 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import argparse
15+
import fire
1616
import tensorflow as tf
1717

1818
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
1919

20-
parser = argparse.ArgumentParser()
2120

22-
parser.add_argument("filename", metavar="FILENAME", help="Audio file to be played back")
21+
def main(
22+
filename: str,
23+
tflite: str = None,
24+
blank: int = 0,
25+
num_rnns: int = 1,
26+
nstates: int = 2,
27+
statesize: int = 320,
28+
):
29+
tflitemodel = tf.lite.Interpreter(model_path=tflite)
2330

24-
parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite")
31+
signal = read_raw_audio(filename)
2532

26-
parser.add_argument("--blank", type=int, default=0, help="Blank index")
33+
input_details = tflitemodel.get_input_details()
34+
output_details = tflitemodel.get_output_details()
35+
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
36+
tflitemodel.allocate_tensors()
37+
tflitemodel.set_tensor(input_details[0]["index"], signal)
38+
tflitemodel.set_tensor(input_details[1]["index"], tf.constant(blank, dtype=tf.int32))
39+
tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([num_rnns, nstates, 1, statesize], dtype=tf.float32))
40+
tflitemodel.invoke()
41+
hyp = tflitemodel.get_tensor(output_details[0]["index"])
2742

28-
parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network")
43+
print("".join([chr(u) for u in hyp]))
2944

30-
parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network")
3145

32-
parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network")
33-
34-
args = parser.parse_args()
35-
36-
tflitemodel = tf.lite.Interpreter(model_path=args.tflite)
37-
38-
signal = read_raw_audio(args.filename)
39-
40-
input_details = tflitemodel.get_input_details()
41-
output_details = tflitemodel.get_output_details()
42-
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
43-
tflitemodel.allocate_tensors()
44-
tflitemodel.set_tensor(input_details[0]["index"], signal)
45-
tflitemodel.set_tensor(input_details[1]["index"], tf.constant(args.blank, dtype=tf.int32))
46-
tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32))
47-
tflitemodel.invoke()
48-
hyp = tflitemodel.get_tensor(output_details[0]["index"])
49-
50-
print("".join([chr(u) for u in hyp]))
46+
if __name__ == "__main__":
47+
fire.Fire(main)

0 commit comments

Comments
 (0)