Skip to content

Commit 6f4c5b3

Browse files
committed
Merge branch 'keras3-support' of https://github.com/nfirle/tensorflow-onnx into ci
2 parents 61b8495 + 671f22e commit 6f4c5b3

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

tf2onnx/convert.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tf2onnx import constants, logging, utils, optimizer
2121
from tf2onnx import tf_loader
2222
from tf2onnx.graph import ExternalTensorStorage
23-
from tf2onnx.tf_utils import compress_graph_def, get_tf_version
23+
from tf2onnx.tf_utils import compress_graph_def, get_tf_version, get_keras_version
2424

2525

2626

@@ -409,6 +409,100 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None,
409409

410410
return model_proto, external_tensor_storage
411411

412+
def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
413+
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None,
414+
target=None, large_model=False, output_path=None, optimizers=None):
415+
"""
416+
Convert a Keras 3 model to ONNX using tf2onnx.
417+
418+
Args:
419+
model: Keras 3 Functional or Sequential model
420+
name: Name for the converted model
421+
input_signature: Optional list of tf.TensorSpec
422+
opset: ONNX opset version
423+
custom_ops: Dictionary of custom ops
424+
custom_op_handlers: Dictionary of custom op handlers
425+
custom_rewriter: List of graph rewriters
426+
inputs_as_nchw: List of input names to convert to NCHW
427+
extra_opset: Additional opset imports
428+
shape_override: Dictionary to override input shapes
429+
target: Target platforms (for workarounds)
430+
large_model: Whether to use external tensor storage
431+
output_path: Optional path to write ONNX model to file
432+
433+
Returns:
434+
A tuple (model_proto, external_tensor_storage_dict)
435+
"""
436+
437+
438+
if not input_signature:
439+
440+
input_signature = [
441+
tf.TensorSpec(tensor.shape, tensor.dtype, name=tensor.name.split(":")[0])
442+
for tensor in model.inputs
443+
]
444+
445+
# Trace model
446+
function = tf.function(model)
447+
concrete_func = function.get_concrete_function(*input_signature)
448+
449+
# These inputs will be removed during freezing (includes resources, etc.)
450+
if hasattr(concrete_func.graph, '_captures'):
451+
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
452+
captured_inputs = [t_name.name for _, t_name in graph_captures.values()]
453+
else:
454+
graph_captures = concrete_func.graph.function_captures.by_val_internal
455+
captured_inputs = [t.name for t in graph_captures.values()]
456+
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
457+
if input_tensor.name not in captured_inputs]
458+
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
459+
if output_tensor.dtype != tf.dtypes.resource]
460+
461+
462+
tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
463+
reverse_lookup = {v: k for k, v in tensors_to_rename.items()}
464+
465+
466+
467+
valid_names = []
468+
for out in [t.name for t in model.outputs]:
469+
if out in reverse_lookup:
470+
valid_names.append(reverse_lookup[out])
471+
else:
472+
print(f"Warning: Output name '{out}' not found in reverse_lookup.")
473+
# Fallback: verwende TensorFlow-Ausgangsnamen direkt
474+
valid_names = [t.name for t in concrete_func.outputs if t.dtype != tf.dtypes.resource]
475+
break
476+
output_names = valid_names
477+
478+
479+
#if old_out_names is not None:
480+
#model.output_names = old_out_names
481+
482+
with tf.device("/cpu:0"):
483+
frozen_graph, initialized_tables = \
484+
tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model)
485+
model_proto, external_tensor_storage = _convert_common(
486+
frozen_graph,
487+
name=model.name,
488+
continue_on_error=True,
489+
target=target,
490+
opset=opset,
491+
custom_ops=custom_ops,
492+
custom_op_handlers=custom_op_handlers,
493+
optimizers=optimizers,
494+
custom_rewriter=custom_rewriter,
495+
extra_opset=extra_opset,
496+
shape_override=shape_override,
497+
input_names=input_names,
498+
output_names=output_names,
499+
inputs_as_nchw=inputs_as_nchw,
500+
outputs_as_nchw=outputs_as_nchw,
501+
large_model=large_model,
502+
tensors_to_rename=tensors_to_rename,
503+
initialized_tables=initialized_tables,
504+
output_path=output_path)
505+
return model_proto, external_tensor_storage
412506

413507
def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
414508
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None,
@@ -439,6 +533,10 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
439533
if get_tf_version() < Version("2.0"):
440534
return _from_keras_tf1(model, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw,
441535
outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
536+
if get_keras_version() > Version("3.0"):
537+
return from_keras3(model, input_signature, opset, custom_ops, custom_op_handlers,
538+
custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override,
539+
target, large_model, output_path, optimizers)
442540

443541
old_out_names = _rename_duplicate_keras_model_names(model)
444542
from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel

tf2onnx/tf_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import tensorflow as tf
13+
import keras
1314

1415
from tensorflow.core.framework import types_pb2, tensor_pb2, graph_pb2
1516
from tensorflow.python.framework import tensor_util
@@ -124,6 +125,9 @@ def get_tf_node_attr(node, name):
124125
def get_tf_version():
125126
return Version(tf.__version__)
126127

128+
def get_keras_version():
129+
return Version(keras.__version__)
130+
127131
def compress_graph_def(graph_def):
128132
"""
129133
Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.

0 commit comments

Comments
 (0)