Skip to content

Commit 37d15f9

Browse files
committed
input1 tensor failed to print after invoke
1 parent c207e51 commit 37d15f9

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

tests/tflm/tflite_export/conftest.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ def simple_tflm_graph():
1818
input_tensors = [],
1919
output_tensors = []
2020
)
21-
weight_op.op_attr["value"] = np.array([1,2,3,4], dtype=np.int8)
22-
weight_op.op_attr["shape"] = [4,1]
21+
#weight_op.op_attr["value"] = np.array([1,2,3,4], dtype=np.int8)
22+
weight_op.op_attr["value"] = np.array([10,20,30,40], dtype=np.float32)
23+
weight_op.op_attr["shape"] = [1,4]
2324

2425
weight = TensorInfo(
2526
name = "weight",
2627
op_name = "weight_const",
27-
dtype = np.dtype("int8"),
28+
dtype = np.dtype("float32"),
2829
shape = weight_op.op_attr["shape"],
2930
ugraph = ugraph
3031
)
@@ -39,14 +40,14 @@ def simple_tflm_graph():
3940
input_tensors = [],
4041
output_tensors = []
4142
)
42-
mock_input_op.op_attr["value"] = np.array([[1.0],[2.0],[3.0],[4.0]], dtype=np.float32)
43-
mock_input_op.op_attr["shape"] = [1,4]
43+
mock_input_op.op_attr["value"] = np.array([[2],[4],[6],[8]], dtype=np.float32)
44+
mock_input_op.op_attr["shape"] = [4,1]
4445

4546
input1 = TensorInfo(
46-
name = "input",
47+
name = "input1",
4748
op_name = "mock_input_const",
4849
dtype = mock_input_op.op_attr["value"].dtype,
49-
shape = [1, 4],
50+
shape = mock_input_op.op_attr["shape"],
5051
ugraph = ugraph
5152
)
5253

@@ -61,13 +62,14 @@ def simple_tflm_graph():
6162
input_tensors = [],
6263
output_tensors = []
6364
)
64-
bias_op.op_attr["value"] = np.array([1], dtype=np.int8)
65+
#bias_op.op_attr["value"] = np.array([1], dtype=np.int8)
66+
bias_op.op_attr["value"] = np.array([7], dtype=np.float32)
6567
bias_op.op_attr["shape"] = [1]
6668

6769
bias = TensorInfo(
6870
name = "bias",
6971
op_name = "bias_const",
70-
dtype = np.dtype("int8"),
72+
dtype = np.dtype("float32"),
7173
shape = bias_op.op_attr["shape"],
7274
ugraph = ugraph
7375
)
@@ -87,8 +89,8 @@ def simple_tflm_graph():
8789
output = TensorInfo(
8890
name = "output",
8991
op_name = "FC1",
90-
dtype = np.dtype("float"),
91-
shape = [1, 1],
92+
dtype = np.dtype("float32"),
93+
shape = [1],
9294
ugraph = ugraph
9395
)
9496

@@ -103,4 +105,4 @@ def simple_tflm_graph():
103105
#ugraph = prune_graph(ugraph)
104106

105107
#return: ugraph, input tensors, output tensors
106-
return [ugraph, [], ["output"]]
108+
return [ugraph, [], ["input1", "weight", "bias", "output"]]

tests/tflm/tflite_export/test_write.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,21 @@ def test_tflite_fb_write(hybrid_quant_output):
8686
ugraph = exporter.transform(sample_ugraph)
8787
model_content = exporter.output()
8888

89-
print_tflite_graph(model_content)
89+
#print_tflite_graph(model_content)
9090

9191
# referece_model_content = open('/Users/neitan01/Documents/tflm/sinExample/sine_model.tflite', "rb").read()
9292
# print_tflite_graph(referece_model_content)
9393

9494
open("tflm_test_model.tflite", "wb").write(model_content)
9595
test_model = tf.lite.Interpreter('tflm_test_model.tflite')
9696
test_model.allocate_tensors()
97-
test_model_output_index = test_model.tensor(test_model.get_output_details()[0]["index"])
9897
test_model.invoke()
99-
output_content = test_model.get_tensor(test_model_output_index)[0]
10098

101-
print(output_content)
99+
print(test_model.get_tensor_details())
100+
print("1 :", test_model.get_tensor(test_model.get_output_details()[0]["index"]))
101+
print("2 :", test_model.get_tensor(test_model.get_output_details()[1]["index"]))
102+
print("3 :", test_model.get_tensor(test_model.get_output_details()[2]["index"]))
103+
print("out :", test_model.get_tensor(test_model.get_output_details()[3]["index"]))
102104

103105
test_pass = True
104106
assert test_pass, 'error message here'

utensor_cgen/transformer/tflite_exporter.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ def get_fullyconnected_builtin_option(fbuilder, op_info):
3838

3939
return obj, BuiltinOptions.FullyConnectedOptions
4040

41+
def tensor_type_lookup(numpy_dtype):
42+
TensorType = tflite.TensorType.TensorType
43+
lookup_map = dict()
44+
lookup_map[np.dtype('float32')] = TensorType.FLOAT32
45+
lookup_map[np.dtype('int8')] = TensorType.INT8
46+
47+
return lookup_map[numpy_dtype]
48+
49+
4150
class FlatbufferOpManager:
4251
op_list = list()
4352
code_name_lookup = {v: k for k, v in BuiltinOperator.__dict__.items()}
@@ -193,7 +202,7 @@ def __create_static_tensor(self, ugraph):
193202
#tensor object
194203
tflite.Tensor.TensorStart(self.fbuilder)
195204
tflite.Tensor.TensorAddShape(self.fbuilder, shape_vec)
196-
tflite.Tensor.TensorAddType(self.fbuilder, TensorType.INT8) #TODO: a conversion class here, out_dtype
205+
tflite.Tensor.TensorAddType(self.fbuilder, tensor_type_lookup(out_dtype))
197206
if export_tensor_name:
198207
tflite.Tensor.TensorAddName(self.fbuilder, tensor_name)
199208
tflite.Tensor.TensorAddQuantization(self.fbuilder, q_param)
@@ -306,7 +315,7 @@ def __create_variable_tensors(self, ugraph):
306315
tflite.Tensor.TensorStart(self.fbuilder)
307316
tflite.Tensor.TensorAddShape(self.fbuilder, shape_vec)
308317
#tflite.Tensor.TensorAddType(self.fbuilder, TensorType.INT8) #TODO: tensor type conversion here
309-
tflite.Tensor.TensorAddType(self.fbuilder, TensorType.FLOAT32)
318+
tflite.Tensor.TensorAddType(self.fbuilder, tensor_type_lookup(tensor_info.dtype))
310319
tflite.Tensor.TensorAddName(self.fbuilder, tensor_name)
311320
#tflite.Tensor.TensorAddQuantization(self.fbuilder, q_param)
312321
tflite.Tensor.TensorAddIsVariable(self.fbuilder, True)

0 commit comments

Comments
 (0)