
Commit 9c13f86

[training] add an offload utility that can be used as a context manager. (#11775)
* add an offload utility that can be used as a context manager.
* update

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent 5c52097 commit 9c13f86
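For context, the new helper replaces the manual device bookkeeping the training script used to do around each encode call. A minimal before/after sketch using the same names as the DreamBooth diff below (`embeds` is just a placeholder for the returned tuple):

# before: move the pipeline by hand and remember to move it back
if args.offload:
    text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
embeds = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
if args.offload:
    text_encoding_pipeline = text_encoding_pipeline.to("cpu")

# after: one context manager performs both moves, gated on the same flag
from diffusers.training_utils import offload_models

with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
    embeds = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)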

File tree

3 files changed: +62, -36 lines changed

.github/workflows/pr_tests_gpu.yml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ on:
       - "src/diffusers/loaders/peft.py"
       - "tests/pipelines/test_pipelines_common.py"
       - "tests/models/test_modeling_common.py"
+      - "examples/**/*.py"
   workflow_dispatch:
 
 concurrency:

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 26 additions & 36 deletions
@@ -58,6 +58,7 @@
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
     free_memory,
+    offload_models,
 )
 from diffusers.utils import (
     check_min_version,
@@ -1364,43 +1365,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            instance_prompt_hidden_states_t5,
-            instance_prompt_hidden_states_llama3,
-            instance_pooled_prompt_embeds,
-            _,
-            _,
-            _,
-        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                instance_prompt_hidden_states_t5,
+                instance_prompt_hidden_states_llama3,
+                instance_pooled_prompt_embeds,
+                _,
+                _,
+                _,
+            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
-            compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
-        )
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
+                compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+            )
 
     validation_embeddings = {}
     if args.validation_prompt is not None:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            validation_embeddings["prompt_embeds_t5"],
-            validation_embeddings["prompt_embeds_llama3"],
-            validation_embeddings["pooled_prompt_embeds"],
-            validation_embeddings["negative_prompt_embeds_t5"],
-            validation_embeddings["negative_prompt_embeds_llama3"],
-            validation_embeddings["negative_pooled_prompt_embeds"],
-        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                validation_embeddings["prompt_embeds_t5"],
+                validation_embeddings["prompt_embeds_llama3"],
+                validation_embeddings["pooled_prompt_embeds"],
+                validation_embeddings["negative_prompt_embeds_t5"],
+                validation_embeddings["negative_prompt_embeds_llama3"],
+                validation_embeddings["negative_pooled_prompt_embeds"],
+            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1581,12 +1573,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
                 else:
-                    if args.offload:
-                        vae = vae.to(accelerator.device)
-                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    with offload_models(vae, device=accelerator.device, offload=args.offload):
+                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                     model_input = vae.encode(pixel_values).latent_dist.sample()
-                    if args.offload:
-                        vae = vae.to("cpu")
+
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)

src/diffusers/training_utils.py

Lines changed: 35 additions & 0 deletions
@@ -5,12 +5,14 @@
 import random
 import re
 import warnings
+from contextlib import contextmanager
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from .models import UNet2DConditionModel
+from .pipelines import DiffusionPipeline
 from .schedulers import SchedulerMixin
 from .utils import (
     convert_state_dict_to_diffusers,
@@ -318,6 +320,39 @@ def free_memory():
         torch.xpu.empty_cache()
 
 
+@contextmanager
+def offload_models(
+    *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
+):
+    """
+    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
+    device on exit.
+
+    Args:
+        device (`str` or `torch.Device`): Device to move the `modules` to.
+        offload (`bool`): Flag to enable offloading.
+    """
+    if offload:
+        is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
+        # record where each module was
+        if is_model:
+            original_devices = [next(m.parameters()).device for m in modules]
+        else:
+            assert len(modules) == 1
+            original_devices = modules[0].device
+        # move to target device
+        for m in modules:
+            m.to(device)
+
+    try:
+        yield
+    finally:
+        if offload:
+            # move back to original devices
+            for m, orig_dev in zip(modules, original_devices):
+                m.to(orig_dev)
+
+
 def parse_buckets_string(buckets_str):
     """Parses a string defining buckets into a list of (height, width) tuples."""
     if not buckets_str:
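A short, hypothetical usage sketch of the helper outside the training script; the `encoder` module below is purely illustrative and the example assumes a CUDA device is available:

import torch

from diffusers.training_utils import offload_models

encoder = torch.nn.Linear(16, 16)  # stand-in for a text encoder that normally lives on CPU

# offload=True: the module is moved to `device` for the duration of the block,
# then restored to its original (CPU) placement on exit.
with offload_models(encoder, device="cuda", offload=True):
    out = encoder(torch.randn(1, 16, device="cuda"))

# offload=False: device placement is left untouched, so this runs on CPU.
with offload_models(encoder, device="cuda", offload=False):
    out = encoder(torch.randn(1, 16))

In the DreamBooth script above, the same pattern wraps the full text-encoding pipeline and the bare VAE, driven by the script's `args.offload` flag.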
