[WIP] Switch to onnx dynamo export #2219
Conversation
Hi! Were you able to test this? I only tried exporting a simple BERT model, and for some reason a simple dynamic shapes specification keeps failing with dynamo 🤔
All tests failed, but I'm not sure I ran them correctly. The dynamic shapes issue should be fixed for DynamicCache by this PR: huggingface/transformers#36652, which was just merged. That would leave two remaining issues:
Models using DynamicCache should be OK. I still need to support the other cache classes in transformers, so this simple change may fix a couple of models but not all of them. For a first PR, I'd like to fix one model. That would mean keeping the old exporter for the existing models and using the new exporter for that one model, assuming we can trigger it manually. I don't know if that's a scenario you would be ok with.
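A minimal sketch of what that manual opt-in could look like, assuming the switch is simply the dynamo flag of torch.onnx.export; the export_model helper and its use_dynamo parameter are hypothetical and not part of this PR:

import torch
import transformers

def export_model(model, args, output_path, use_dynamo=False):
    # Hypothetical helper: existing models keep the legacy TorchScript-based
    # exporter, while selected models opt into the new dynamo-based exporter.
    torch.onnx.export(model, args, output_path, dynamo=use_dynamo)

model = transformers.AutoModel.from_pretrained("bert-base-cased")
input_ids = torch.randint(0, 10, (4, 128), dtype=torch.int64)
attention_mask = torch.ones((4, 128), dtype=torch.int64)

# One model is exported with the new exporter; everything else stays on the old path.
export_model(model, (input_ids, attention_mask), "bert_dynamo.onnx", use_dynamo=True)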
After a bit of diffing I found why.
import onnx
import torch
import transformers
from torch.export import Dim
# Load the model
model = transformers.AutoModel.from_pretrained("bert-base-cased")
# Convert the model to ONNX format
input_ids = torch.randint(0, 10, (4, 128), dtype=torch.int64)
attention_mask = torch.ones((4, 128), dtype=torch.int64)
# old specification (dynamic_axes): fails with the dynamo exporter (kept for comparison, not passed below)
dynamic_axes = {
"input_ids": {
0: "batch_size",
1: "sequence_length",
},
"attention_mask": {
0: "batch_size",
1: "sequence_length",
},
}
# new specification (dynamic_shapes): stricter, works with the dynamo exporter
dynamic_shapes = {
"input_ids": {
0: Dim("batch_size", max=512),
1: Dim("sequence_length", max=512 - 1),
},
"attention_mask": {
0: Dim("batch_size", max=512),
1: Dim("sequence_length", max=512 - 1),
},
}
onnx_program = torch.onnx.export(
model,
(input_ids, attention_mask),
"torch_exported_model.onnx",
dynamic_shapes=dynamic_shapes,
export_params=True,
dynamo=True,
)
# Load the exported ONNX model back for inspection
onnx_model = onnx.load("torch_exported_model.onnx")
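A quick way to confirm the exported axes really are dynamic is to run the saved model on a different (batch, sequence) shape than the one used for export; this is a small sketch using onnxruntime and assumes the graph inputs keep the names input_ids and attention_mask:

import numpy as np
import onnxruntime as ort

# Feed a (2, 64) batch even though the model was exported with (4, 128);
# this only works if both axes were exported as dynamic dimensions.
session = ort.InferenceSession("torch_exported_model.onnx")
outputs = session.run(
    None,
    {
        "input_ids": np.random.randint(0, 10, (2, 64), dtype=np.int64),
        "attention_mask": np.ones((2, 64), dtype=np.int64),
    },
)
print([o.shape for o in outputs])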
We improved the backward compatibility: you don't have to change the dynamic_axes specification.

import onnx
import torch
import transformers
from torch.export import Dim
# Load the model
model = transformers.AutoModel.from_pretrained("bert-base-cased")
# Convert the model to ONNX format
input_ids = torch.randint(0, 10, (4, 128), dtype=torch.int64)
attention_mask = torch.ones((4, 128), dtype=torch.int64)
# old specification (dynamic_axes), now supported with dynamo=True
dynamic_axes = {
"input_ids": {
0: "batch_size",
1: "sequence_length",
},
"attention_mask": {
0: "batch_size",
1: "sequence_length",
},
}
onnx_program = torch.onnx.export(
model,
(input_ids, attention_mask),
"torch_exported_model.onnx",
dynamic_axes=dynamic_axes,
dynamo=True,
)
# Load the exported ONNX model back for inspection
onnx_model = onnx.load("torch_exported_model.onnx")

However, I noticed we forget to rename the dynamic dimensions (we'll fix that), but if you write
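For reference, one way to see which names the exporter actually assigned to the dynamic dimensions is to inspect the graph inputs of the saved model; this is just a small sketch using the standard onnx API and is not part of this PR:

import onnx

# Print the name (dim_param) or fixed size (dim_value) of every input
# dimension, to check whether "batch_size" / "sequence_length" survived.
onnx_model = onnx.load("torch_exported_model.onnx")
for graph_input in onnx_model.graph.input:
    dims = [
        d.dim_param if d.dim_param else d.dim_value
        for d in graph_input.type.tensor_type.shape.dim
    ]
    print(graph_input.name, dims)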
@xadupre can you please add onnxscript so we can see if the tests are passing?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
This PR investigates switching to the newest ONNX exporter (torch.onnx.export with dynamo=True).
Before submitting
Who can review?