|
20 | 20 | from tf2onnx import constants, logging, utils, optimizer
|
21 | 21 | from tf2onnx import tf_loader
|
22 | 22 | from tf2onnx.graph import ExternalTensorStorage
|
23 |
| -from tf2onnx.tf_utils import compress_graph_def, get_tf_version |
| 23 | +from tf2onnx.tf_utils import compress_graph_def, get_tf_version, get_keras_version |
24 | 24 |
|
25 | 25 |
|
26 | 26 |
|
@@ -409,6 +409,100 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None,
|
409 | 409 |
|
410 | 410 | return model_proto, external_tensor_storage
|
411 | 411 |
|
| 412 | +def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, |
| 413 | + custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None, |
| 414 | + target=None, large_model=False, output_path=None, optimizers=None): |
| 415 | + """ |
| 416 | + Convert a Keras 3 model to ONNX using tf2onnx. |
| 417 | +
|
| 418 | + Args: |
| 419 | + model: Keras 3 Functional or Sequential model |
| 420 | + name: Name for the converted model |
| 421 | + input_signature: Optional list of tf.TensorSpec |
| 422 | + opset: ONNX opset version |
| 423 | + custom_ops: Dictionary of custom ops |
| 424 | + custom_op_handlers: Dictionary of custom op handlers |
| 425 | + custom_rewriter: List of graph rewriters |
| 426 | + inputs_as_nchw: List of input names to convert to NCHW |
| 427 | + extra_opset: Additional opset imports |
| 428 | + shape_override: Dictionary to override input shapes |
| 429 | + target: Target platforms (for workarounds) |
| 430 | + large_model: Whether to use external tensor storage |
| 431 | + output_path: Optional path to write ONNX model to file |
| 432 | +
|
| 433 | + Returns: |
| 434 | + A tuple (model_proto, external_tensor_storage_dict) |
| 435 | + """ |
| 436 | + |
| 437 | + |
| 438 | + if not input_signature: |
| 439 | + |
| 440 | + input_signature = [ |
| 441 | + tf.TensorSpec(tensor.shape, tensor.dtype, name=tensor.name.split(":")[0]) |
| 442 | + for tensor in model.inputs |
| 443 | + ] |
| 444 | + |
| 445 | + # Trace model |
| 446 | + function = tf.function(model) |
| 447 | + concrete_func = function.get_concrete_function(*input_signature) |
| 448 | + |
| 449 | + # These inputs will be removed during freezing (includes resources, etc.) |
| 450 | + if hasattr(concrete_func.graph, '_captures'): |
| 451 | + graph_captures = concrete_func.graph._captures # pylint: disable=protected-access |
| 452 | + captured_inputs = [t_name.name for _, t_name in graph_captures.values()] |
| 453 | + else: |
| 454 | + graph_captures = concrete_func.graph.function_captures.by_val_internal |
| 455 | + captured_inputs = [t.name for t in graph_captures.values()] |
| 456 | + input_names = [input_tensor.name for input_tensor in concrete_func.inputs |
| 457 | + if input_tensor.name not in captured_inputs] |
| 458 | + output_names = [output_tensor.name for output_tensor in concrete_func.outputs |
| 459 | + if output_tensor.dtype != tf.dtypes.resource] |
| 460 | + |
| 461 | + |
| 462 | + tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names) |
| 463 | + reverse_lookup = {v: k for k, v in tensors_to_rename.items()} |
| 464 | + |
| 465 | + |
| 466 | + |
| 467 | + valid_names = [] |
| 468 | + for out in [t.name for t in model.outputs]: |
| 469 | + if out in reverse_lookup: |
| 470 | + valid_names.append(reverse_lookup[out]) |
| 471 | + else: |
| 472 | + print(f"Warning: Output name '{out}' not found in reverse_lookup.") |
| 473 | + # Fallback: verwende TensorFlow-Ausgangsnamen direkt |
| 474 | + valid_names = [t.name for t in concrete_func.outputs if t.dtype != tf.dtypes.resource] |
| 475 | + break |
| 476 | + output_names = valid_names |
| 477 | + |
| 478 | + |
| 479 | + #if old_out_names is not None: |
| 480 | + #model.output_names = old_out_names |
| 481 | + |
| 482 | + with tf.device("/cpu:0"): |
| 483 | + frozen_graph, initialized_tables = \ |
| 484 | + tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model) |
| 485 | + model_proto, external_tensor_storage = _convert_common( |
| 486 | + frozen_graph, |
| 487 | + name=model.name, |
| 488 | + continue_on_error=True, |
| 489 | + target=target, |
| 490 | + opset=opset, |
| 491 | + custom_ops=custom_ops, |
| 492 | + custom_op_handlers=custom_op_handlers, |
| 493 | + optimizers=optimizers, |
| 494 | + custom_rewriter=custom_rewriter, |
| 495 | + extra_opset=extra_opset, |
| 496 | + shape_override=shape_override, |
| 497 | + input_names=input_names, |
| 498 | + output_names=output_names, |
| 499 | + inputs_as_nchw=inputs_as_nchw, |
| 500 | + outputs_as_nchw=outputs_as_nchw, |
| 501 | + large_model=large_model, |
| 502 | + tensors_to_rename=tensors_to_rename, |
| 503 | + initialized_tables=initialized_tables, |
| 504 | + output_path=output_path) |
| 505 | + return model_proto, external_tensor_storage |
412 | 506 |
|
413 | 507 | def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
|
414 | 508 | custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None,
|
@@ -439,6 +533,10 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
|
439 | 533 | if get_tf_version() < Version("2.0"):
|
440 | 534 | return _from_keras_tf1(model, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw,
|
441 | 535 | outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
|
| 536 | + if get_keras_version() > Version("3.0"): |
| 537 | + return from_keras3(model, input_signature, opset, custom_ops, custom_op_handlers, |
| 538 | + custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override, |
| 539 | + target, large_model, output_path, optimizers) |
442 | 540 |
|
443 | 541 | old_out_names = _rename_duplicate_keras_model_names(model)
|
444 | 542 | from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel
|
|
0 commit comments