Skip to content

[core] respect local_files_only=True when using sharded checkpoints #12005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from packaging import version
from requests import HTTPError
from requests.exceptions import ConnectionError

from .. import __version__
from .constants import (
Expand Down Expand Up @@ -403,9 +404,26 @@ def _get_checkpoint_shard_files(

ignore_patterns = ["*.json", "*.md"]
# The `model_info` call must be guarded by the above condition.
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
local = False
try:
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
except ConnectionError as e:
if local_files_only:
temp_dir = snapshot_download(
repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
)
model_files_info = _get_filepaths_for_folder(temp_dir)
local = True
else:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if local:
shard_file_present = any(shard_file in k for k in model_files_info)
else:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
Expand Down Expand Up @@ -441,6 +459,16 @@ def _get_checkpoint_shard_files(
return cached_filenames, sharded_metadata


def _get_filepaths_for_folder(folder):
relative_paths = []
for root, dirs, files in os.walk(folder):
for fname in files:
abs_path = os.path.join(root, fname)
rel_path = os.path.relpath(abs_path, start=folder)
relative_paths.append(rel_path)
return relative_paths


def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
if filenames and folder:
raise ValueError("Both `filenames` and `folder` cannot be provided.")
Expand Down
Loading