```diff
@@ -58,6 +58,7 @@
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
     free_memory,
+    offload_models,
 )
 from diffusers.utils import (
     check_min_version,
```
```diff
@@ -1364,43 +1365,34 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
     if not train_dataset.custom_instance_prompts:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            instance_prompt_hidden_states_t5,
-            instance_prompt_hidden_states_llama3,
-            instance_pooled_prompt_embeds,
-            _,
-            _,
-            _,
-        ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                instance_prompt_hidden_states_t5,
+                instance_prompt_hidden_states_llama3,
+                instance_pooled_prompt_embeds,
+                _,
+                _,
+                _,
+            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)

     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
-            compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
-        )
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
+                compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+            )

     validation_embeddings = {}
     if args.validation_prompt is not None:
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
-        (
-            validation_embeddings["prompt_embeds_t5"],
-            validation_embeddings["prompt_embeds_llama3"],
-            validation_embeddings["pooled_prompt_embeds"],
-            validation_embeddings["negative_prompt_embeds_t5"],
-            validation_embeddings["negative_prompt_embeds_llama3"],
-            validation_embeddings["negative_pooled_prompt_embeds"],
-        ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
-        if args.offload:
-            text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+            (
+                validation_embeddings["prompt_embeds_t5"],
+                validation_embeddings["prompt_embeds_llama3"],
+                validation_embeddings["pooled_prompt_embeds"],
+                validation_embeddings["negative_prompt_embeds_t5"],
+                validation_embeddings["negative_prompt_embeds_llama3"],
+                validation_embeddings["negative_pooled_prompt_embeds"],
+            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
```
```diff
@@ -1581,12 +1573,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.cache_latents:
                     model_input = latents_cache[step].sample()
                 else:
-                    if args.offload:
-                        vae = vae.to(accelerator.device)
-                    pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                    with offload_models(vae, device=accelerator.device, offload=args.offload):
+                        pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                         model_input = vae.encode(pixel_values).latent_dist.sample()
-                    if args.offload:
-                        vae = vae.to("cpu")
+
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
```
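In the training loop the VAE gets the same treatment: when latents are not cached, it is resident on the accelerator only for the duration of the encode call (`vae.encode(...)` shows as an unchanged line above, presumably because only its indentation changed; it now sits inside the `with` block, where the VAE is still on-device). The call sites in this diff each pass a single model, but the plural name suggests the helper accepts several modules at once, as the sketch above assumes; if that assumption holds, models needed together can share one round trip:

```python
import torch
from torch import nn

vae = nn.Conv2d(3, 4, kernel_size=3)  # stand-in for the VAE
text_encoder = nn.Linear(16, 16)      # stand-in for a text encoder
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical multi-model usage of the sketch above: both modules move to
# `device` on entry and back to the CPU on exit.
with offload_models(vae, text_encoder, device=device, offload=True):
    latents = vae(torch.randn(1, 3, 8, 8, device=device))

print(next(vae.parameters()).device)  # -> cpu (offloaded after the block)
```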