Skip to content

Commit 55c2873

Browse files
committed
Support Keras 3 models in from_keras
1 parent 3dd7729 commit 55c2873

File tree

2 files changed

+109
-1
lines changed

2 files changed

+109
-1
lines changed

tf2onnx/convert.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tf2onnx import constants, logging, utils, optimizer
2121
from tf2onnx import tf_loader
2222
from tf2onnx.graph import ExternalTensorStorage
23-
from tf2onnx.tf_utils import compress_graph_def, get_tf_version
23+
from tf2onnx.tf_utils import compress_graph_def, get_tf_version, get_keras_version
2424

2525

2626

@@ -408,6 +408,106 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None,
408408

409409
return model_proto, external_tensor_storage
410410

411+
def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
412+
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None,
413+
target=None, large_model=False, output_path=None, optimizers=None):
414+
"""
415+
Convert a Keras 3 model to ONNX using tf2onnx.
416+
417+
Args:
418+
model: Keras 3 Functional or Sequential model
419+
name: Name for the converted model
420+
input_signature: Optional list of tf.TensorSpec
421+
opset: ONNX opset version
422+
custom_ops: Dictionary of custom ops
423+
custom_op_handlers: Dictionary of custom op handlers
424+
custom_rewriter: List of graph rewriters
425+
inputs_as_nchw: List of input names to convert to NCHW
426+
extra_opset: Additional opset imports
427+
shape_override: Dictionary to override input shapes
428+
target: Target platforms (for workarounds)
429+
large_model: Whether to use external tensor storage
430+
output_path: Optional path to write ONNX model to file
431+
432+
Returns:
433+
A tuple (model_proto, external_tensor_storage_dict)
434+
"""
435+
436+
437+
if not input_signature:
438+
439+
input_signature = [
440+
tf.TensorSpec(tensor.shape, tensor.dtype, name=tensor.name.split(":")[0])
441+
for tensor in model.inputs
442+
]
443+
444+
# Trace model
445+
function = tf.function(model)
446+
concrete_func = function.get_concrete_function(*input_signature)
447+
448+
# These inputs will be removed during freezing (includes resources, etc.)
449+
if hasattr(concrete_func.graph, '_captures'):
450+
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
451+
captured_inputs = [t_name.name for _, t_name in graph_captures.values()]
452+
else:
453+
graph_captures = concrete_func.graph.function_captures.by_val_internal
454+
captured_inputs = [t.name for t in graph_captures.values()]
455+
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
456+
if input_tensor.name not in captured_inputs]
457+
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
458+
if output_tensor.dtype != tf.dtypes.resource]
459+
460+
461+
tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
462+
reverse_lookup = {v: k for k, v in tensors_to_rename.items()}
463+
464+
465+
466+
valid_names = []
467+
for out in model.output_names:
468+
if out in reverse_lookup:
469+
valid_names.append(reverse_lookup[out])
470+
else:
471+
print(f"Warning: Output name '{out}' not found in reverse_lookup.")
472+
# Fallback: verwende TensorFlow-Ausgangsnamen direkt
473+
valid_names = [t.name for t in concrete_func.outputs if t.dtype != tf.dtypes.resource]
474+
break
475+
output_names = valid_names
476+
477+
478+
#if old_out_names is not None:
479+
#model.output_names = old_out_names
480+
481+
with tf.device("/cpu:0"):
482+
frozen_graph, initialized_tables = \
483+
tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model)
484+
485+
for node in frozen_graph.node:
486+
print(node.name, node.op)
487+
model_proto, external_tensor_storage = _convert_common(
488+
frozen_graph,
489+
name=model.name,
490+
continue_on_error=True,
491+
target=target,
492+
opset=opset,
493+
custom_ops=custom_ops,
494+
custom_op_handlers=custom_op_handlers,
495+
optimizers=optimizers,
496+
custom_rewriter=custom_rewriter,
497+
extra_opset=extra_opset,
498+
shape_override=shape_override,
499+
input_names=input_names,
500+
output_names=output_names,
501+
inputs_as_nchw=inputs_as_nchw,
502+
outputs_as_nchw=outputs_as_nchw,
503+
large_model=large_model,
504+
tensors_to_rename=tensors_to_rename,
505+
initialized_tables=initialized_tables,
506+
output_path=output_path)
507+
508+
#print(model_proto)
509+
510+
return model_proto, external_tensor_storage
411511

412512
def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
413513
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None,
@@ -438,6 +538,10 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
438538
if get_tf_version() < Version("2.0"):
439539
return _from_keras_tf1(model, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw,
440540
outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
541+
if get_keras_version() > Version("3.0"):
542+
return from_keras3(model, input_signature, opset, custom_ops, custom_op_handlers,
543+
custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override,
544+
target, large_model, output_path, optimizers)
441545

442546
old_out_names = _rename_duplicate_keras_model_names(model)
443547
from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel

tf2onnx/tf_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import tensorflow as tf
13+
import keras
1314

1415
from tensorflow.core.framework import types_pb2, tensor_pb2, graph_pb2
1516
from tensorflow.python.framework import tensor_util
@@ -124,6 +125,9 @@ def get_tf_node_attr(node, name):
124125
def get_tf_version():
125126
return Version(tf.__version__)
126127

128+
def get_keras_version():
129+
return Version(keras.__version__)
130+
127131
def compress_graph_def(graph_def):
128132
"""
129133
Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.

0 commit comments

Comments
 (0)