From 48209da4721d9c03fcc7dfbee3fb72ad45fdbabc Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 1 Aug 2025 15:21:39 +0530
Subject: [PATCH] support hf_quantizer in cache warmup.

---
 src/diffusers/models/model_loading_utils.py  | 17 ++++++++-----
 src/diffusers/models/modeling_utils.py       |  4 ++--
 src/diffusers/quantizers/base.py              | 11 +++++++++
 .../quantizers/torchao/torchao_quantizer.py   | 24 +++++++++++++++++++
 4 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 4e2d24b75011..4e4af0af9bdd 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -16,7 +16,6 @@
 
 import importlib
 import inspect
-import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
@@ -559,27 +558,33 @@ def _expand_device_map(device_map, param_names):
 
 
 # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
-def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
     """
     This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the
     model, which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time
     by a very large margin.
     """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
     # Remove disk and cpu devices, and cast to proper torch.device
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    parameter_count = defaultdict(lambda: 0)
+    total_byte_count = defaultdict(lambda: 0)
     for param_name, device in accelerator_device_map.items():
         try:
             param = model.get_parameter(param_name)
         except AttributeError:
             param = model.get_buffer(param_name)
-        parameter_count[device] += math.prod(param.shape)
+        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        # TODO: account for TP when needed.
+        total_byte_count[device] += param_byte_count
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, param_count in parameter_count.items():
-        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+    for device, byte_count in total_byte_count.items():
+        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 815f12a70774..8f175bd68dd8 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1551,9 +1551,9 @@ def _load_pretrained_model(
         # When the actual device allocations happen, the allocator already has a pool of unused device memory
         # that it can re-use for faster loading of the model.
         # TODO: add support for warmup with hf_quantizer
-        if device_map is not None and hf_quantizer is None:
+        if device_map is not None:
             expanded_device_map = _expand_device_map(device_map, expected_keys)
-            _caching_allocator_warmup(model, expanded_device_map, dtype)
+            _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
 
         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
         state_dict_folder, state_dict_index = None, None
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index 357d920d29c4..24fc724b4c88 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -209,6 +209,17 @@ def dequantize(self, model):
 
         return model
 
+    def get_cuda_warm_up_factor(self):
+        """
+        The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up CUDA.
+        A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+        we allocate half the memory of the weights residing in the empty model, and so on.
+        """
+        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+        # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+        # weight loading)
+        return 4
+
     def _dequantize(self, model):
         raise NotImplementedError(
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index c12513f061da..c7648b529d82 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -19,6 +19,7 @@
 
 import importlib
 import types
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from packaging import version
@@ -278,6 +279,29 @@ def create_quantized_param(
         module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on.
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype, not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        # Original mapping for non-AOBaseConfig types
+        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "float8*": 4}
+        quant_type = self.quantization_config.quant_type
+        for pattern, target_dtype in map_to_target_dtype.items():
+            if fnmatch(quant_type, pattern):
+                return target_dtype
+        raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
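For illustration only (not part of the patch above), here is a minimal standalone sketch of the warm-up arithmetic the patch implements. The parameter shapes, device choice, and factor value below are made-up examples: factor 2 corresponds to the unquantized default, while a torchao int4 checkpoint would report factor 8 via get_cuda_warm_up_factor().

from collections import defaultdict

import torch

# Hypothetical fp16 parameters that would end up on the same accelerator device.
params = {
    "a": torch.empty(1024, 1024, dtype=torch.float16),
    "b": torch.empty(4096, 1024, dtype=torch.float16),
}
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Sum the real byte footprint per device, as the patched _caching_allocator_warmup does.
total_byte_count = defaultdict(int)
for p in params.values():
    total_byte_count[device] += p.numel() * p.element_size()

# factor = 2 for unquantized fp16: byte_count // 2 fp16 elements == byte_count bytes warmed up.
# With factor = 8 (int4), only a quarter of the fp16 footprint would be pre-allocated.
factor = 2
for dev, byte_count in total_byte_count.items():
    _ = torch.empty(byte_count // factor, dtype=torch.float16, device=dev, requires_grad=False)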