|
1 | 1 | #-*- coding:utf8 -*-
|
| 2 | +import importlib |
2 | 3 | import os
|
3 | 4 | import re
|
4 | 5 | import sys
|
5 |
| -import importlib |
6 | 6 | from pathlib import Path
|
7 | 7 |
|
8 | 8 | import click
|
@@ -131,24 +131,26 @@ def list_trans_methods(verbose):
|
131 | 131 | help="list of output nodes")
|
132 | 132 | @click.argument('model_file', required=True, metavar='MODEL.{pb,pkl}')
|
133 | 133 | def show_graph(model_file, **kwargs):
|
| 134 | + import pickle |
| 135 | + from utensor_cgen.frontend import FrontendSelector |
134 | 136 | _, ext = os.path.splitext(model_file)
|
135 | 137 | output_nodes = kwargs.pop('output_nodes')
|
136 |
| - if ext == '.pb' or ext == '.pbtxt': |
137 |
| - _show_pb_file(model_file, output_nodes=output_nodes, **kwargs) |
138 |
| - elif ext == '.pkl': |
139 |
| - import pickle |
| 138 | + |
| 139 | + if ext == '.pkl': |
140 | 140 | with open(model_file, 'rb') as fid:
|
141 | 141 | ugraph = pickle.load(fid)
|
142 | 142 | _show_ugraph(ugraph, **kwargs)
|
143 |
| - else: |
144 |
| - msg = click.style('unknown file extension: {}'.format(ext), fg='red', bold=True) |
145 |
| - click.echo(msg, file=sys.stderr) |
146 |
| - |
147 |
| -def _show_pb_file(pb_file, output_nodes, **kwargs): |
148 |
| - import tensorflow as tf |
149 |
| - from utensor_cgen.frontend.tensorflow import GraphDefParser |
150 |
| - ugraph = GraphDefParser.parse(pb_file, output_nodes=output_nodes) |
151 |
| - _show_ugraph(ugraph, **kwargs) |
| 143 | + return 0 |
| 144 | + |
| 145 | + try: |
| 146 | + parser = FrontendSelector.select_parser(ext) |
| 147 | + ugraph = parser.parse(model_file, output_nodes) |
| 148 | + _show_ugraph(ugraph, **kwargs) |
| 149 | + return 0 |
| 150 | + except RuntimeError as err: |
| 151 | + msg = err.args[0] |
| 152 | + click.secho(msg, fg='red', bold=True) |
| 153 | + return 1 |
152 | 154 |
|
153 | 155 | def _show_ugraph(ugraph, oneline=False, ignore_unknown_op=False):
|
154 | 156 | import textwrap
|
|
0 commit comments