diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index cf85488b7aa0..bc43ba83cf69 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -45,6 +45,7 @@ ) from packaging import version from requests import HTTPError +from requests.exceptions import ConnectionError from .. import __version__ from .constants import ( @@ -403,9 +404,26 @@ def _get_checkpoint_shard_files( ignore_patterns = ["*.json", "*.md"] # `model_info` call must guarded with the above condition. - model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) + local = False + try: + model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) + except ConnectionError as e: + if local_files_only: + temp_dir = snapshot_download( + repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only + ) + model_files_info = _get_filepaths_for_folder(temp_dir) + local = True + else: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e for shard_file in original_shard_filenames: - shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if local: + shard_file_present = any(shard_file in k for k in model_files_info) + else: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) if not shard_file_present: raise EnvironmentError( f"{shards_path} does not appear to have a file named {shard_file} which is " @@ -441,6 +459,16 @@ def _get_checkpoint_shard_files( return cached_filenames, sharded_metadata +def _get_filepaths_for_folder(folder): + relative_paths = [] + for root, dirs, files in os.walk(folder): + for fname in files: + abs_path = os.path.join(root, fname) + rel_path = os.path.relpath(abs_path, start=folder) + relative_paths.append(rel_path) + return relative_paths + + def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None): if filenames and folder: raise ValueError("Both `filenames` and `folder` cannot be provided.")