Skip to content

Commit d80f35a

Browse files
committed
minor fix
1 parent 36c3013 commit d80f35a

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

utensor_cgen/cli.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#-*- coding:utf8 -*-
2+
import importlib
23
import os
34
import re
45
import sys
5-
import importlib
66
from pathlib import Path
77

88
import click
@@ -131,24 +131,26 @@ def list_trans_methods(verbose):
131131
help="list of output nodes")
132132
@click.argument('model_file', required=True, metavar='MODEL.{pb,pkl}')
133133
def show_graph(model_file, **kwargs):
134+
import pickle
135+
from utensor_cgen.frontend import FrontendSelector
134136
_, ext = os.path.splitext(model_file)
135137
output_nodes = kwargs.pop('output_nodes')
136-
if ext == '.pb' or ext == '.pbtxt':
137-
_show_pb_file(model_file, output_nodes=output_nodes, **kwargs)
138-
elif ext == '.pkl':
139-
import pickle
138+
139+
if ext == '.pkl':
140140
with open(model_file, 'rb') as fid:
141141
ugraph = pickle.load(fid)
142142
_show_ugraph(ugraph, **kwargs)
143-
else:
144-
msg = click.style('unknown file extension: {}'.format(ext), fg='red', bold=True)
145-
click.echo(msg, file=sys.stderr)
146-
147-
def _show_pb_file(pb_file, output_nodes, **kwargs):
148-
import tensorflow as tf
149-
from utensor_cgen.frontend.tensorflow import GraphDefParser
150-
ugraph = GraphDefParser.parse(pb_file, output_nodes=output_nodes)
151-
_show_ugraph(ugraph, **kwargs)
143+
return 0
144+
145+
try:
146+
parser = FrontendSelector.select_parser(ext)
147+
ugraph = parser.parse(model_file, output_nodes)
148+
_show_ugraph(ugraph, **kwargs)
149+
return 0
150+
except RuntimeError as err:
151+
msg = err.args[0]
152+
click.secho(msg, fg='red', bold=True)
153+
return 1
152154

153155
def _show_ugraph(ugraph, oneline=False, ignore_unknown_op=False):
154156
import textwrap

utensor_cgen/frontend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def select_parser(cls, file_ext):
3434
cls._setup()
3535
parser_cls = cls._parser_map.get(file_ext, None)
3636
if parser_cls is None:
37-
raise RuntimeError("unknown model file ext found: %s" % file_ext)
37+
raise RuntimeError("unknown model file extension: %s" % file_ext)
3838
return parser_cls
3939

4040
@classmethod

0 commit comments

Comments
 (0)