From af72eced377b07d5ce4a0dc126d54d3c29fc35dd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Jul 2025 11:06:11 +0530 Subject: [PATCH 1/6] checking. --- src/diffusers/models/model_loading_utils.py | 125 ++++++++++++++++++++ src/diffusers/models/modeling_utils.py | 124 +++++++------------ 2 files changed, 167 insertions(+), 82 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index ebc7d79aeb28..a01e409db8b4 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -19,6 +19,7 @@ import os from array import array from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Dict, List, Optional, Union from zipfile import is_zipfile @@ -304,6 +305,130 @@ def load_model_dict_into_meta( return offload_index, state_dict_index +def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): + """ + Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first + checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's + parameters. + + """ + if model_to_load.device.type == "meta": + return False + + if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: + return False + + # Some models explicitly do not support param buffer assignment + if not getattr(model_to_load, "_supports_param_buffer_assignment", True): + logger.debug( + f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" + ) + return False + + # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype + first_key = next(iter(model_to_load.state_dict().keys())) + if start_prefix + first_key in state_dict: + return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype + + return False + + +def load_shard_file(args): + ( + model, + model_state_dict, + shard_file, + device_map, + dtype, + hf_quantizer, + keep_in_fp32_modules, + dduf_entries, + loaded_keys, + unexpected_keys, + offload_index, + offload_folder, + state_dict_index, + state_dict_folder, + ignore_mismatched_sizes, + low_cpu_mem_usage, + ) = args + assign_to_params_buffers = None + state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = [] + if low_cpu_mem_usage: + offload_index, state_dict_index = load_model_dict_into_meta( + model, + state_dict, + device_map=device_map, + dtype=dtype, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_index=state_dict_index, + state_dict_folder=state_dict_folder, + ) + else: + if assign_to_params_buffers is None: + assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) + + error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) + return offload_index, state_dict_index, mismatched_keys, error_msgs + + +def load_shard_files_with_threadpool(args_list): + num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) + + # Do not spawn anymore workers than you need + num_workers = min(len(args_list), num_workers) + + logger.info(f"Loading 
model weights in parallel with {num_workers} workers...") + + error_msgs = [] + mismatched_keys = [] + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: + futures = [executor.submit(load_shard_file, arg) for arg in args_list] + for future in as_completed(futures): + result = future.result() + offload_index, state_dict_index, _mismatched_keys, _error_msgs = result + error_msgs += _error_msgs + mismatched_keys += _mismatched_keys + pbar.update(1) + + return offload_index, state_dict_index, mismatched_keys, error_msgs + + +def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, +): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + # If the checkpoint is sharded, we may not have the key here. + if checkpoint_key not in state_dict: + continue + + if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape: + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + def _load_state_dict_into_model( model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False ) -> List[str]: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 8e1ec5f55889..0ddfd5c858a9 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -66,8 +66,8 @@ _determine_device_map, _fetch_index_file, _fetch_index_file_legacy, - _load_state_dict_into_model, - load_model_dict_into_meta, + load_shard_file, + load_shard_files_with_threadpool, load_state_dict, ) @@ -200,34 +200,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: return last_tuple[1].dtype -def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""): - """ - Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first - checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's - parameters. 
- - """ - if model_to_load.device.type == "meta": - return False - - if len([key for key in state_dict if key.startswith(start_prefix)]) == 0: - return False - - # Some models explicitly do not support param buffer assignment - if not getattr(model_to_load, "_supports_param_buffer_assignment", True): - logger.debug( - f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower" - ) - return False - - # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype - first_key = next(iter(model_to_load.state_dict().keys())) - if start_prefix + first_key in state_dict: - return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype - - return False - - @contextmanager def no_init_weights(): """ @@ -926,6 +898,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) + # TODO: enable TRUE ENV VARs + is_parallel_loading_enabled = bool(os.environ.get("HF_ENABLE_PARALLEL_LOADING", 1)) + + if is_parallel_loading_enabled and not low_cpu_mem_usage: + raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.") + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 logger.warning( @@ -1261,6 +1239,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, dduf_entries=dduf_entries, + is_parallel_loading_enabled=is_parallel_loading_enabled, ) loading_info = { "missing_keys": missing_keys, @@ -1456,6 +1435,7 @@ def _load_pretrained_model( offload_state_dict: Optional[bool] = None, offload_folder: Optional[Union[str, os.PathLike]] = None, dduf_entries: Optional[Dict[str, DDUFEntry]] = None, + is_parallel_loading_enabled: Optional[bool] = False, ): model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) @@ -1470,8 +1450,6 @@ def _load_pretrained_model( unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] mismatched_keys = [] - - assign_to_params_buffers = None error_msgs = [] # Deal with offload @@ -1499,63 +1477,45 @@ def _load_pretrained_model( # load_state_dict will manage the case where we pass a dict instead of a file # if state dict is not None, it means that we don't need to read the files from resolved_model_file also resolved_model_file = [state_dict] + is_file = not isinstance(state_dict, dict) - if len(resolved_model_file) > 1: - resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards") - - for shard_file in resolved_model_file: - state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) - - def _find_mismatched_keys( - state_dict, - model_state_dict, - loaded_keys, - ignore_mismatched_sizes, - ): - mismatched_keys = [] - if ignore_mismatched_sizes: - for checkpoint_key in loaded_keys: - model_key = checkpoint_key - # If the checkpoint is sharded, we may not have the key here. 
- if checkpoint_key not in state_dict: - continue - - if ( - model_key in model_state_dict - and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape - ): - mismatched_keys.append( - (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) - ) - del state_dict[checkpoint_key] - return mismatched_keys - - mismatched_keys += _find_mismatched_keys( - state_dict, + # prepare the arguments. + args_list = [ + ( + model, model_state_dict, + shard_file, + device_map, + dtype, + hf_quantizer, + keep_in_fp32_modules, + dduf_entries, loaded_keys, + unexpected_keys, + offload_index, + offload_folder, + state_dict_index, + state_dict_folder, ignore_mismatched_sizes, + low_cpu_mem_usage, ) + for shard_file in resolved_model_file + ] - if low_cpu_mem_usage: - offload_index, state_dict_index = load_model_dict_into_meta( - model, - state_dict, - device_map=device_map, - dtype=dtype, - hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, - unexpected_keys=unexpected_keys, - offload_folder=offload_folder, - offload_index=offload_index, - state_dict_index=state_dict_index, - state_dict_folder=state_dict_folder, - ) - else: - if assign_to_params_buffers is None: - assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict) + if is_parallel_loading_enabled and is_file: + offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool( + args_list + ) + error_msgs += _error_msgs + mismatched_keys += _mismatched_keys + else: + if len(args_list) > 1: + args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") - error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers) + for args in args_list: + offload_index, state_dict_index, _error_msgs = load_shard_file(args) + error_msgs += _error_msgs + mismatched_keys += _mismatched_keys if offload_index is not None and len(offload_index) > 0: save_offload_index(offload_index, offload_folder) From d4e297620763079e45ffb3f29a3d540fa7b60d09 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Jul 2025 12:12:33 +0530 Subject: [PATCH 2/6] checking --- src/diffusers/loaders/single_file_model.py | 2 +- src/diffusers/loaders/single_file_utils.py | 2 +- src/diffusers/loaders/transformer_flux.py | 3 ++- src/diffusers/loaders/transformer_sd3.py | 3 ++- src/diffusers/loaders/unet.py | 3 ++- src/diffusers/models/model_loading_utils.py | 1 + src/diffusers/models/modeling_utils.py | 7 ++++--- 7 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 17ac81ca26f6..5c1bfda37fee 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -61,7 +61,7 @@ if is_accelerate_available(): from accelerate import dispatch_model, init_empty_weights - from ..models.modeling_utils import load_model_dict_into_meta + from ..models.model_loading_utils import load_model_dict_into_meta SINGLE_FILE_LOADABLE_CLASSES = { diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ee0786aa2d6a..a4bbb7d4fbb5 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -54,7 +54,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights - from ..models.modeling_utils import load_model_dict_into_meta + from ..models.model_loading_utils import load_model_dict_into_meta logger = 
logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index c7d81a8baebd..a8af3c61216f 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -17,7 +17,8 @@ ImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ..models.model_loading_utils import load_model_dict_into_meta +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..utils import ( is_accelerate_available, is_torch_version, diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py index c58d3280cfe1..e48f251c08d5 100644 --- a/src/diffusers/loaders/transformer_sd3.py +++ b/src/diffusers/loaders/transformer_sd3.py @@ -16,7 +16,8 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 from ..models.embeddings import IPAdapterTimeImageProjection -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ..models.model_loading_utils import load_model_dict_into_meta +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..utils import is_accelerate_available, is_torch_version, logging diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 3546497c195b..7ff2f0cee95b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -30,7 +30,8 @@ IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict +from ..models.model_loading_utils import load_model_dict_into_meta +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict from ..utils import ( USE_PEFT_BACKEND, _get_model_file, diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index a01e409db8b4..ff1ab94d6358 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -388,6 +388,7 @@ def load_shard_files_with_threadpool(args_list): # Do not spawn anymore workers than you need num_workers = min(len(args_list), num_workers) + print(f"{num_workers=}") logger.info(f"Loading model weights in parallel with {num_workers} workers...") diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0ddfd5c858a9..777e59c06b49 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -899,7 +899,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P disable_mmap = kwargs.pop("disable_mmap", False) # TODO: enable TRUE ENV VARs - is_parallel_loading_enabled = bool(os.environ.get("HF_ENABLE_PARALLEL_LOADING", 1)) + is_parallel_loading_enabled = bool(os.environ.get("HF_ENABLE_PARALLEL_LOADING", 0)) if is_parallel_loading_enabled and not low_cpu_mem_usage: raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.") @@ -1477,7 +1477,7 @@ def _load_pretrained_model( # load_state_dict will manage the case where we pass a dict instead of a file # if state dict is not None, it means that we don't need to read the files from resolved_model_file also resolved_model_file = [state_dict] - is_file = not isinstance(state_dict, dict) + is_file = resolved_model_file and state_dict is None # prepare the arguments. 
args_list = [ @@ -1502,6 +1502,7 @@ def _load_pretrained_model( for shard_file in resolved_model_file ] + print(f"{is_parallel_loading_enabled=}, {is_file=}") if is_parallel_loading_enabled and is_file: offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool( args_list @@ -1513,7 +1514,7 @@ def _load_pretrained_model( args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") for args in args_list: - offload_index, state_dict_index, _error_msgs = load_shard_file(args) + offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_file(args) error_msgs += _error_msgs mismatched_keys += _mismatched_keys From c9b680da2c841b53e0dfcc02806850597678097e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 10 Jul 2025 12:26:09 +0530 Subject: [PATCH 3/6] checking --- src/diffusers/models/model_loading_utils.py | 3 +-- src/diffusers/models/modeling_utils.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index ff1ab94d6358..7393a74d7cba 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -388,8 +388,7 @@ def load_shard_files_with_threadpool(args_list): # Do not spawn anymore workers than you need num_workers = min(len(args_list), num_workers) - print(f"{num_workers=}") - + logger.info(f"Loading model weights in parallel with {num_workers} workers...") error_msgs = [] diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 777e59c06b49..d3f28ce885e0 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1502,7 +1502,6 @@ def _load_pretrained_model( for shard_file in resolved_model_file ] - print(f"{is_parallel_loading_enabled=}, {is_file=}") if is_parallel_loading_enabled and is_file: offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool( args_list From 536df5a01d9cd32b4ab55a3c245f7d03f90f40cb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Jul 2025 13:25:09 +0530 Subject: [PATCH 4/6] up --- src/diffusers/models/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 963c8df1e1c4..154bd8479547 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -45,6 +45,7 @@ SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, + ENV_VARS_TRUE_VALUES, WEIGHTS_NAME, _add_variant, _get_checkpoint_shard_files, @@ -959,8 +960,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) - # TODO: enable TRUE ENV VARs - is_parallel_loading_enabled = bool(os.environ.get("HF_ENABLE_PARALLEL_LOADING", 0)) + is_parallel_loading_enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES if is_parallel_loading_enabled and not low_cpu_mem_usage: raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.") From 04cd5cc3ff614e79827515262842e52a11d6f710 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Jul 2025 13:36:38 +0530 Subject: [PATCH 5/6] up --- src/diffusers/models/modeling_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 154bd8479547..3d82e314642b 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -41,11 +41,11 @@ from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( CONFIG_NAME, + ENV_VARS_TRUE_VALUES, FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, - ENV_VARS_TRUE_VALUES, WEIGHTS_NAME, _add_variant, _get_checkpoint_shard_files, @@ -1547,7 +1547,6 @@ def _load_pretrained_model( # load_state_dict will manage the case where we pass a dict instead of a file # if state dict is not None, it means that we don't need to read the files from resolved_model_file also resolved_model_file = [state_dict] - is_file = resolved_model_file and state_dict is None # prepare the arguments. args_list = [ @@ -1572,7 +1571,7 @@ def _load_pretrained_model( for shard_file in resolved_model_file ] - if is_parallel_loading_enabled and is_file: + if is_parallel_loading_enabled: offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool( args_list ) From cb0b3ed3385b45a8974cbdbe2f709211e552274c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Jul 2025 13:50:53 +0530 Subject: [PATCH 6/6] up --- src/diffusers/loaders/transformer_flux.py | 3 ++- src/diffusers/models/modeling_utils.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py index ced81960fae5..ef7b921b7ddf 100644 --- a/src/diffusers/loaders/transformer_flux.py +++ b/src/diffusers/loaders/transformer_flux.py @@ -17,7 +17,8 @@ ImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta +from ..models.model_loading_utils import load_model_dict_into_meta +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..utils import is_accelerate_available, is_torch_version, logging from ..utils.torch_utils import empty_device_cache diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 3d82e314642b..69d26b7ada1d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -961,7 +961,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P disable_mmap = kwargs.pop("disable_mmap", False) is_parallel_loading_enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES - if is_parallel_loading_enabled and not low_cpu_mem_usage: raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
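
For reference, a minimal usage sketch of the feature this series adds: opting in to parallel shard loading through the `HF_ENABLE_PARALLEL_LOADING` environment variable (any value in `ENV_VARS_TRUE_VALUES`, e.g. "1" or "yes"), optionally tuning the worker count with `HF_PARALLEL_LOADING_WORKERS` (the series defaults to 8), and keeping `low_cpu_mem_usage=True`, since the series raises `NotImplementedError` otherwise. The checkpoint name below is illustrative and not part of this series.

    import os

    # Opt in before calling from_pretrained(); the flag is read inside
    # ModelMixin.from_pretrained, so it applies to every model component.
    os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
    # Optional: cap the thread pool used by load_shard_files_with_threadpool.
    os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

    import torch
    from diffusers import DiffusionPipeline

    # Parallel loading requires low_cpu_mem_usage=True (the default when
    # accelerate is available); passing False raises NotImplementedError.
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative sharded checkpoint
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )

Only checkpoints that are split across multiple shard files benefit; single-file checkpoints and dict-based state dicts go through the same code path but with a single work item.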