[WIP] Switch to onnx dynamo export #2219

Open · xadupre wants to merge 2 commits into main
Conversation

@xadupre commented Mar 25, 2025

What does this PR do?

This PR investigates switching to the new dynamo-based ONNX exporter.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@IlyasMoutawwakil (Member)

Hi! Were you able to test this? I only tried exporting a simple BERT model, and for some reason simple dynamic shape specifications keep failing with dynamo 🤔

@xadupre (Author) commented Apr 2, 2025

All tests failed, but I'm not sure I ran them correctly. The dynamic shapes should be fixed for DynamicCache by huggingface/transformers#36652, which was just merged. That would leave two remaining issues:

  • converting dynamic axes into dynamic shapes. That should be doable automatically; we have something in place, but it does not handle all scenarios yet (see the sketch after this list).
  • when a dimension is dynamic, the example tensor cannot have a size of 0 or 1 along that dimension. This is a limitation coming from torch.export.export; we'll try to address that as well.
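
For illustration, a minimal sketch of that conversion for the common case of flat tensor inputs (the helper name is hypothetical; the real logic would also need to cover nested inputs and caches):

from torch.export import Dim

def dynamic_axes_to_dynamic_shapes(dynamic_axes):
    # Map every named axis, e.g. {"input_ids": {0: "batch_size", 1: "sequence_length"}},
    # to the recently introduced Dim.AUTO so no min/max constraints are required:
    # {"input_ids": {0: Dim.AUTO, 1: Dim.AUTO}}
    return {
        input_name: {axis: Dim.AUTO for axis in axes}
        for input_name, axes in dynamic_axes.items()
    }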

Models using DynamicCache should be OK. I still need to support the other cache classes in transformers as well, so this simple change may fix a couple of models but not all of them.

For a first PR, I'd like to fix one model. That would mean keeping the old exporter for the existing models and using the new exporter for that one model, assuming we can trigger it manually. I don't know if that's a scenario you would be OK with.
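
As a rough illustration of that opt-in (the flag name is made up; this is not the PR's actual mechanism):

import torch

def export_to_onnx(model, args, output_path, use_dynamo_exporter=False, **kwargs):
    if use_dynamo_exporter:
        # new torch.export-based exporter, enabled only for the selected model
        return torch.onnx.export(model, args, output_path, dynamo=True, **kwargs)
    # legacy exporter, kept as the default for all other models
    return torch.onnx.export(model, args, output_path, dynamo=False, **kwargs)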

@IlyasMoutawwakil (Member) commented Apr 2, 2025

After a bit of diffing I found why.
My assumption was that dynamo=True is backward compatible with dynamo=False.
I've been using the same dynamic_axes we have in optimum, of the form dynamic_axes = {input_name1: {idx1: axis_name1, ...}, ...}. These work great with dynamo=False, but once I enable dynamo=True, the old dynamic_axes are used as the new dynamic_shapes (passed to torch.export), which are much stricter.
In the case of BERT, for example, we have no min/max specification for batch_size and sequence_length, which results in export errors because the guards batch_size != 9223372036854775807 and sequence_length != 512 are created:

  - Not all values of sequence_length = L['args'][0][0].size()[1] in the specified range sequence_length <= 512 satisfy the generated guard L['args'][0][0].size()[1] != 512.
  - Not all values of batch_size = L['args'][0][0].size()[0] in the specified range satisfy the generated guard L['args'][0][0].size()[0] != 9223372036854775807.
import onnx
import torch
import transformers
from torch.export import Dim

# Load the model
model = transformers.AutoModel.from_pretrained("bert-base-cased")

# Convert the model to ONNX format
input_ids = torch.randint(0, 10, (4, 128), dtype=torch.int64)
attention_mask = torch.ones((4, 128), dtype=torch.int64)

# old style, fails with dynamo
dynamic_axes = {
    "input_ids": {
        0: "batch_size",
        1: "sequence_length",
    },
    "attention_mask": {
        0: "batch_size",
        1: "sequence_length",
    },
}

# new style, stricter, works with dynamo
dynamic_shapes = {
    "input_ids": {
        0: Dim("batch_size", max=512),
        1: Dim("sequence_length", max=512 - 1),
    },
    "attention_mask": {
        0: Dim("batch_size", max=512),
        1: Dim("sequence_length", max=512 - 1),
    },
}

onnx_program = torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "torch_exported_model.onnx",
    dynamic_shapes=dynamic_shapes,
    export_params=True,
    dynamo=True,
)

# Load the exported ONNX model back
onnx_model = onnx.load("torch_exported_model.onnx")

@xadupre (Author) commented Apr 2, 2025

We improved the backward compatibility: you don't have to change dynamic_axes, and you can keep them as they are if there is no cache. Internally, we replace every string with the recently introduced torch.export.Dim.AUTO, so there is no constraint to add.

import onnx
import torch
import transformers
from torch.export import Dim

# Load the model
model = transformers.AutoModel.from_pretrained("bert-base-cased")

# Convert the model to ONNX format
input_ids = torch.randint(0, 10, (4, 128), dtype=torch.int64)
attention_mask = torch.ones((4, 128), dtype=torch.int64)

# old dynamic_axes, now accepted as-is with dynamo=True
dynamic_axes = {
    "input_ids": {
        0: "batch_size",
        1: "sequence_length",
    },
    "attention_mask": {
        0: "batch_size",
        1: "sequence_length",
    },
}

onnx_program = torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "torch_exported_model.onnx",
    dynamic_axes=dynamic_axes,
    dynamo=True,
)

# Load the exported ONNX model back
onnx_model = onnx.load("torch_exported_model.onnx")

However, I noticed we forgot to rename the dynamic dimensions (we'll fix that), but if you write dynamic_shapes=dynamic_axes and remove dynamic_axes, you will get the expected dynamic names. I did that with the nightly build; it should be available soon with 2.7. Strings don't work with torch.export.export; this is something we added to make it easier to switch. I'm currently making some changes in transformers to handle caches as well.
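
Concretely, reusing model, input_ids, attention_mask and dynamic_axes from the snippet above, the suggested call would look like this (a sketch; per the comment above it needs a recent nightly or 2.7):

onnx_program = torch.onnx.export(
    model,
    (input_ids, attention_mask),
    "torch_exported_model.onnx",
    dynamic_shapes=dynamic_axes,  # pass the string-based axes as dynamic_shapes
    dynamo=True,
)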

@IlyasMoutawwakil (Member)

@xadupre, can you please add onnxscript to see if the tests are passing?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
