Skip to content

Commit 8d431dc

Browse files
committed
tighten compilation tests for quantization
1 parent 2841504 commit 8d431dc

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

tests/quantization/bnb/test_4bit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,7 @@ def quantization_config(self):
886886
components_to_quantize=["transformer", "text_encoder_2"],
887887
)
888888

889+
@require_bitsandbytes_version_greater("0.46.1")
889890
def test_torch_compile(self):
890891
torch._dynamo.config.capture_dynamic_output_shape_ops = True
891892
super().test_torch_compile()

tests/quantization/test_torch_compile_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,18 @@ def _test_torch_compile(self, torch_dtype=torch.bfloat16):
5656
pipe.transformer.compile(fullgraph=True)
5757

5858
# small resolutions to ensure speedy execution.
59-
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
59+
with torch._dynamo.config.patch(error_on_recompile=True):
60+
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
6061

6162
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
6263
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
6364
pipe.enable_model_cpu_offload()
64-
pipe.transformer.compile()
65+
# regional compilation is better for offloading.
66+
# see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
67+
if getattr(pipe.transformer, "_repeated_blocks"):
68+
pipe.transformer.compile_repeated_blocks(fullgraph=True)
69+
else:
70+
pipe.transformer.compile()
6571

6672
# small resolutions to ensure speedy execution.
6773
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

0 commit comments

Comments (0)