
support hf_quantizer in cache warmup. #12043

Open
wants to merge 2 commits into base: main
17 changes: 11 additions & 6 deletions src/diffusers/models/model_loading_utils.py
@@ -16,7 +16,6 @@
 import importlib
 import inspect
-import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
@@ -559,27 +558,33 @@ def _expand_device_map(device_map, param_names):
 
 
 # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
-def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
     """
     Warms up the caching allocator based on the size of the model tensors that will reside on each device. This
     replaces many small mallocs during weight loading, which are the actual loading-speed bottleneck, with one
     large malloc per device, cutting model loading time by a very large margin.
     """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
     # Remove disk and cpu devices, and cast to proper torch.device
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    parameter_count = defaultdict(lambda: 0)
+    total_byte_count = defaultdict(lambda: 0)
     for param_name, device in accelerator_device_map.items():
         try:
             param = model.get_parameter(param_name)
         except AttributeError:
             param = model.get_buffer(param_name)
-        parameter_count[device] += math.prod(param.shape)
+        # The dtype of different parameters may differ with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        # TODO: account for TP when needed.
+        total_byte_count[device] += param_byte_count
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, param_count in parameter_count.items():
-        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+    for device, byte_count in total_byte_count.items():
+        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
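To make the warmup arithmetic concrete, here is a minimal standalone sketch (the helper name is illustrative, not part of the PR). `torch.empty(n, dtype=...)` allocates n elements, so dividing the byte count by the factor and then allocating in a 2-byte dtype yields the intended fraction of the model's byte footprint:

import torch

def warmup_bytes(total_byte_count: int, factor: int, dtype: torch.dtype) -> int:
    # torch.empty allocates `n` elements, so the byte cost is n * element_size.
    n_elements = total_byte_count // factor
    return n_elements * torch.empty(0, dtype=dtype).element_size()

one_gib = 1024**3  # weights occupying 1 GiB in torch_dtype
assert warmup_bytes(one_gib, factor=2, dtype=torch.float16) == one_gib       # unquantized: full footprint
assert warmup_bytes(one_gib, factor=4, dtype=torch.float16) == one_gib // 2  # e.g. int8: half
assert warmup_bytes(one_gib, factor=8, dtype=torch.float16) == one_gib // 4  # e.g. int4: a quarter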
4 changes: 2 additions & 2 deletions src/diffusers/models/modeling_utils.py
@@ -1551,9 +1551,9 @@ def _load_pretrained_model(
         # When the actual device allocations happen, the allocator already has a pool of unused device memory
         # that it can re-use for faster loading of the model.
         # TODO: add support for warmup with hf_quantizer
-        if device_map is not None and hf_quantizer is None:
+        if device_map is not None:
             expanded_device_map = _expand_device_map(device_map, expected_keys)
-            _caching_allocator_warmup(model, expanded_device_map, dtype)
+            _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
 
         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
         state_dict_folder, state_dict_index = None, None
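With this change, the warmup also runs for quantized loads instead of being skipped. A hypothetical end-to-end invocation (the checkpoint and arguments are illustrative; exact `from_pretrained` options depend on your diffusers version):

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# device_map must be set for the warmup to run; hf_quantizer is now
# forwarded to _caching_allocator_warmup instead of disabling it.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),  # matches "int8_*" below, factor 4
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)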
11 changes: 11 additions & 0 deletions src/diffusers/quantizers/base.py
@@ -209,6 +209,17 @@ def dequantize(self, model):
 
         return model
 
+    def get_cuda_warm_up_factor(self):
+        """
+        The factor used in `caching_allocator_warmup` to get the number of bytes to pre-allocate for the CUDA
+        warmup. Since the warmup allocates in the (typically 2-byte) torch_dtype, a factor of 2 pre-allocates the
+        full byte footprint of the empty model, a factor of 4 pre-allocates half of it, and so on.
+        """
+        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+        # really pre-processed, i.e. we do not have the info that the weights are going to be 8-bit before the
+        # actual weight loading)
+        return 4
+
     def _dequantize(self, model):
         raise NotImplementedError(
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
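As a quick sanity check of the default (a standalone sketch, names illustrative): suppose an empty model reports a 4096 x 4096 weight in fp16 while the quantizer will load it as 8-bit.

import torch

numel = 4096 * 4096
reported_bytes = numel * 2     # numel * element_size() with torch_dtype = fp16
actual_int8_bytes = numel * 1  # what the 8-bit quantized weight really needs

# Default factor 4: allocate reported_bytes // 4 fp16 elements, i.e. half the
# fp16 footprint in bytes, which is exactly the int8 footprint.
warmup = torch.empty(reported_bytes // 4, dtype=torch.float16)
assert warmup.numel() * warmup.element_size() == actual_int8_bytes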
24 changes: 24 additions & 0 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -19,6 +19,7 @@
 
 import importlib
 import types
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from packaging import version
@@ -278,6 +279,29 @@ def create_quantized_param(
         module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in `caching_allocator_warmup` to determine how many bytes to pre-allocate for the
+        CUDA warmup:
+        - a factor of 2 means we pre-allocate the full memory footprint of the model;
+        - a factor of 4 means we pre-allocate half of that, and so on.
+
+        However, with TorchAO, computing memory usage as `param.numel() * param.element_size()` does not give the
+        correct size for quantized weights (such as int4 or int8). TorchAO internally represents quantized tensors
+        with subtensors and metadata, and the reported `element_size()` still corresponds to the torch_dtype, not
+        the actual bit-width of the quantized data.
+
+        To correct for this:
+        - use a division factor of 8 for int4 weights;
+        - use a division factor of 4 for int8 weights.
+        """
+        # Original mapping for non-AOBaseConfig types
+        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "float8*": 4}
+        quant_type = self.quantization_config.quant_type
+        for pattern, target_dtype in map_to_target_dtype.items():
+            if fnmatch(quant_type, pattern):
+                return target_dtype
+        raise ValueError(f"Unsupported quant_type: {quant_type!r}")
 
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
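A quick illustration of the pattern matching (a standalone sketch, not part of the PR):

from fnmatch import fnmatch

map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "float8*": 4}

def factor_for(quant_type: str) -> int:
    # First glob pattern that matches the torchao quant_type wins.
    for pattern, factor in map_to_target_dtype.items():
        if fnmatch(quant_type, pattern):
            return factor
    raise ValueError(f"Unsupported quant_type: {quant_type!r}")

assert factor_for("int8_weight_only") == 4  # 8-bit: warm up half the torch_dtype footprint
assert factor_for("int4_weight_only") == 8  # 4-bit: a quarter
assert factor_for("float8wo") == 4          # fp8: same factor as int8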