Skip to content

Commit 48209da

Browse files
committed
support hf_quantizer in cache warmup.
1 parent 20e0740 commit 48209da

File tree

4 files changed

+48
-8
lines changed

4 files changed

+48
-8
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import importlib
1818
import inspect
19-
import math
2019
import os
2120
from array import array
2221
from collections import OrderedDict, defaultdict
@@ -559,27 +558,33 @@ def _expand_device_map(device_map, param_names):
559558

560559

561560
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
562-
def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
561+
def _caching_allocator_warmup(
562+
model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
563+
) -> None:
563564
"""
564565
This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
565566
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
566567
which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
567568
very large margin.
568569
"""
570+
factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
569571
# Remove disk and cpu devices, and cast to proper torch.device
570572
accelerator_device_map = {
571573
param: torch.device(device)
572574
for param, device in expanded_device_map.items()
573575
if str(device) not in ["cpu", "disk"]
574576
}
575-
parameter_count = defaultdict(lambda: 0)
577+
total_byte_count = defaultdict(lambda: 0)
576578
for param_name, device in accelerator_device_map.items():
577579
try:
578580
param = model.get_parameter(param_name)
579581
except AttributeError:
580582
param = model.get_buffer(param_name)
581-
parameter_count[device] += math.prod(param.shape)
583+
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
584+
param_byte_count = param.numel() * param.element_size()
585+
# TODO: account for TP when needed.
586+
total_byte_count[device] += param_byte_count
582587

583588
# This will kick off the caching allocator to avoid having to Malloc afterwards
584-
for device, param_count in parameter_count.items():
585-
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
589+
for device, byte_count in total_byte_count.items():
590+
_ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,9 +1551,9 @@ def _load_pretrained_model(
15511551
# When the actual device allocations happen, the allocator already has a pool of unused device memory
15521552
# that it can re-use for faster loading of the model.
15531553
# TODO: add support for warmup with hf_quantizer
1554-
if device_map is not None and hf_quantizer is None:
1554+
if device_map is not None:
15551555
expanded_device_map = _expand_device_map(device_map, expected_keys)
1556-
_caching_allocator_warmup(model, expanded_device_map, dtype)
1556+
_caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
15571557

15581558
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
15591559
state_dict_folder, state_dict_index = None, None

src/diffusers/quantizers/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,17 @@ def dequantize(self, model):
209209

210210
return model
211211

212+
def get_cuda_warm_up_factor(self):
213+
"""
214+
The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
215+
A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
216+
we allocate half the memory of the weights residing in the empty model, etc...
217+
"""
218+
# By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
219+
# really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
220+
# weight loading)
221+
return 4
222+
212223
def _dequantize(self, model):
213224
raise NotImplementedError(
214225
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import importlib
2121
import types
22+
from fnmatch import fnmatch
2223
from typing import TYPE_CHECKING, Any, Dict, List, Union
2324

2425
from packaging import version
@@ -278,6 +279,29 @@ def create_quantized_param(
278279
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
279280
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
280281

282+
def get_cuda_warm_up_factor(self):
283+
"""
284+
This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
285+
- A factor of 2 means we pre-allocate the full memory footprint of the model.
286+
- A factor of 4 means we pre-allocate half of that, and so on
287+
288+
However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
289+
the correct size for quantized weights (like int4 or int8) That's because TorchAO internally represents
290+
quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
291+
torch_dtype not the actual bit-width of the quantized data.
292+
293+
To correct for this:
294+
- Use a division factor of 8 for int4 weights
295+
- Use a division factor of 4 for int8 weights
296+
"""
297+
# Original mapping for non-AOBaseConfig types
298+
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "float8*": 4}
299+
quant_type = self.quantization_config.quant_type
300+
for pattern, target_dtype in map_to_target_dtype.items():
301+
if fnmatch(quant_type, pattern):
302+
return target_dtype
303+
raise ValueError(f"Unsupported quant_type: {quant_type!r}")
304+
281305
def _process_model_before_weight_loading(
282306
self,
283307
model: "ModelMixin",

0 commit comments

Comments
 (0)