From 8d431dc967a4118168af74aae9c41f2a68764851 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 28 Jul 2025 13:27:20 +0530
Subject: [PATCH 1/4] tighten compilation tests for quantization

---
 tests/quantization/bnb/test_4bit.py            |  1 +
 tests/quantization/test_torch_compile_utils.py | 10 ++++++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 8e2a8515c662..08c0fee43b80 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -886,6 +886,7 @@ def quantization_config(self):
             components_to_quantize=["transformer", "text_encoder_2"],
         )
 
+    @require_bitsandbytes_version_greater("0.46.1")
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super().test_torch_compile()
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
index c742927646b6..91ed173fc69b 100644
--- a/tests/quantization/test_torch_compile_utils.py
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -56,12 +56,18 @@ def _test_torch_compile(self, torch_dtype=torch.bfloat16):
         pipe.transformer.compile(fullgraph=True)
 
         # small resolutions to ensure speedy execution.
-        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
     def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
-        pipe.transformer.compile()
+        # regional compilation is better for offloading.
+        # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
+        if getattr(pipe.transformer, "_repeated_blocks"):
+            pipe.transformer.compile_repeated_blocks(fullgraph=True)
+        else:
+            pipe.transformer.compile()
 
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
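A note on the tightening above: torch.compile normally recompiles silently whenever an input breaks one of the cached graph's guards, so a regression can hide inside a passing test; `torch._dynamo.config.patch(error_on_recompile=True)` turns any recompilation inside the `with` block into a hard error. A minimal, self-contained sketch of that mechanism (the `scale` function is illustrative only, not part of the diffusers test suite):

import torch

@torch.compile
def scale(x: torch.Tensor, factor: torch.Tensor) -> torch.Tensor:
    return x * factor

x = torch.randn(4)
factor = torch.tensor(2.0)

scale(x, factor)  # first call: Dynamo compiles and caches a guarded graph

with torch._dynamo.config.patch(error_on_recompile=True):
    scale(x, factor)  # same guards, served from the cache: no error
    # scale(x.double(), factor)  # a dtype change would break a guard and,
    # with error_on_recompile set, raise instead of silently recompiling

The offload branch instead uses `compile_repeated_blocks`, which compiles each repeated transformer block once rather than the whole model, reducing compile time when layers are moved on and off the accelerator (see the blog post linked in the diff).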
From 69920eff3e3efd3732fde4d4822bedc618ee4f9f Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 28 Jul 2025 15:16:53 +0530
Subject: [PATCH 2/4] feat: model_info but local.

---
 src/diffusers/utils/hub_utils.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 8aaee5b75d93..2fa3f975bce1 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -404,9 +404,21 @@ def _get_checkpoint_shard_files(
     ignore_patterns = ["*.json", "*.md"]
 
     # `model_info` call must be guarded with the above condition.
-    model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+    local = False
+    try:
+        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+    except HTTPError:
+        if local_files_only:
+            temp_dir = snapshot_download(
+                repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
+            )
+            model_files_info = _get_filepaths_for_folder(temp_dir)
+            local = True
     for shard_file in original_shard_filenames:
-        shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+        if local:
+            shard_file_present = any(shard_file in k for k in model_files_info)
+        else:
+            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
         if not shard_file_present:
             raise EnvironmentError(
                 f"{shards_path} does not appear to have a file named {shard_file} which is "
@@ -442,6 +454,16 @@ def _get_checkpoint_shard_files(
     return cached_filenames, sharded_metadata
 
 
+def _get_filepaths_for_folder(folder):
+    relative_paths = []
+    for root, dirs, files in os.walk(folder):
+        for fname in files:
+            abs_path = os.path.join(root, fname)
+            rel_path = os.path.relpath(abs_path, start=folder)
+            relative_paths.append(rel_path)
+    return relative_paths
+
+
 def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
     if filenames and folder:
         raise ValueError("Both `filenames` and `folder` cannot be provided.")

From d5c1772dc361322f7b97705e7ec2686d12c9f58c Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 28 Jul 2025 20:17:24 +0530
Subject: [PATCH 3/4] up

---
 src/diffusers/utils/hub_utils.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 2fa3f975bce1..cd263876132f 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -45,6 +45,7 @@
 )
 from packaging import version
 from requests import HTTPError
+from requests.exceptions import ConnectionError
 
 from .. import __version__
 from .constants import (
@@ -407,13 +408,18 @@ def _get_checkpoint_shard_files(
     local = False
     try:
         model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
-    except HTTPError:
+    except ConnectionError as e:
         if local_files_only:
             temp_dir = snapshot_download(
                 repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
            )
             model_files_info = _get_filepaths_for_folder(temp_dir)
             local = True
+        else:
+            raise EnvironmentError(
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
+                " again after checking your internet connection."
+            ) from e
     for shard_file in original_shard_filenames:
         if local:
             shard_file_present = any(shard_file in k for k in model_files_info)
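Taken together, patches 2 and 3 give `_get_checkpoint_shard_files` an offline path: when `model_info` cannot reach the Hub and `local_files_only=True` is set, the shard list is rebuilt by walking the cached snapshot on disk, and a clearer `EnvironmentError` is raised otherwise. Patch 3 also swaps the exception type because a failed connection surfaces as `requests.exceptions.ConnectionError`, not `HTTPError`. A condensed sketch of the pattern, assuming `huggingface_hub` is installed; `get_repo_filenames` is a hypothetical stand-in, not a diffusers API:

import os

from huggingface_hub import model_info, snapshot_download
from requests.exceptions import ConnectionError


def get_repo_filenames(repo_id, cache_dir=None, local_files_only=False):
    try:
        # Online path: one metadata call lists every file in the repo.
        return [s.rfilename for s in model_info(repo_id).siblings]
    except ConnectionError:
        if not local_files_only:
            raise
        # Offline path: with local_files_only=True, snapshot_download() only
        # resolves the already-cached folder (no network); walking it mirrors
        # what _get_filepaths_for_folder does in the patch.
        folder = snapshot_download(repo_id, cache_dir=cache_dir, local_files_only=True)
        return [
            os.path.relpath(os.path.join(root, fname), start=folder)
            for root, _, files in os.walk(folder)
            for fname in files
        ]

Note that the fallback returns bare relative paths, which is why the shard check switches from `k.rfilename` to plain string membership when `local` is set.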
From f38a64443fc81dd481ff6ee3bb4b7690082d8591 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 28 Jul 2025 20:19:38 +0530
Subject: [PATCH 4/4] Revert "tighten compilation tests for quantization"

This reverts commit 8d431dc967a4118168af74aae9c41f2a68764851.
---
 tests/quantization/bnb/test_4bit.py            |  1 -
 tests/quantization/test_torch_compile_utils.py | 10 ++--------
 2 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 08c0fee43b80..8e2a8515c662 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -886,7 +886,6 @@ def quantization_config(self):
             components_to_quantize=["transformer", "text_encoder_2"],
         )
 
-    @require_bitsandbytes_version_greater("0.46.1")
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super().test_torch_compile()
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
index 91ed173fc69b..c742927646b6 100644
--- a/tests/quantization/test_torch_compile_utils.py
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -56,18 +56,12 @@ def _test_torch_compile(self, torch_dtype=torch.bfloat16):
         pipe.transformer.compile(fullgraph=True)
 
         # small resolutions to ensure speedy execution.
-        with torch._dynamo.config.patch(error_on_recompile=True):
-            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
     def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
-        # regional compilation is better for offloading.
-        # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
-        if getattr(pipe.transformer, "_repeated_blocks"):
-            pipe.transformer.compile_repeated_blocks(fullgraph=True)
-        else:
-            pipe.transformer.compile()
+        pipe.transformer.compile()
 
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
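Net of the revert, the series leaves only the hub_utils fallback in place. A hypothetical offline smoke test of that behavior, assuming the example checkpoint was fully downloaded in a previous online session (the model id is an example only):

import os

# Force huggingface_hub to fail network calls; model_info should then raise a
# ConnectionError subclass, which the new fallback catches. The variable must
# be set before diffusers (and huggingface_hub) are imported.
os.environ["HF_HUB_OFFLINE"] = "1"

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example repo with sharded weights; must already be cached
    torch_dtype=torch.bfloat16,
    local_files_only=True,
)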