Commit 0b0f311

Merge branch 'main' into support-lighttx2v-lora

2 parents: cd730e7 + 0c71189
File tree: 2 files changed, +27 -54 lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 25 additions & 53 deletions
@@ -913,38 +913,21 @@ def patchify(x, patch_size):
     if patch_size == 1:
         return x
 
-    if x.dim() == 4:
-        # x shape: [batch_size, channels, height, width]
-        batch_size, channels, height, width = x.shape
-
-        # Ensure height and width are divisible by patch_size
-        if height % patch_size != 0 or width % patch_size != 0:
-            raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
-
-        # Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
-        x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
-
-        # Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
-        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
-        x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
-
-    elif x.dim() == 5:
-        # x shape: [batch_size, channels, frames, height, width]
-        batch_size, channels, frames, height, width = x.shape
-
-        # Ensure height and width are divisible by patch_size
-        if height % patch_size != 0 or width % patch_size != 0:
-            raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+    # x shape: [batch_size, channels, frames, height, width]
+    batch_size, channels, frames, height, width = x.shape
 
-        # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
-        x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
+    # Ensure height and width are divisible by patch_size
+    if height % patch_size != 0 or width % patch_size != 0:
+        raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
 
-        # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
-        x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
-        x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
+    # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
+    x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
 
-    else:
-        raise ValueError(f"Invalid input shape: {x.shape}")
+    # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
+    x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
+    x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
 
     return x
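
As a rough sanity check of what the rewritten 5D-only patchify path computes, here is a minimal, self-contained shape sketch; the batch, channel, frame, and spatial sizes are arbitrary example values, not anything mandated by the commit:

```python
import torch

# Arbitrary example sizes: [batch, channels, frames, height, width] with patch_size = 2.
batch_size, channels, frames, height, width = 1, 16, 4, 8, 8
patch_size = 2

x = torch.randn(batch_size, channels, frames, height, width)

# Same reshape/permute sequence as the new patchify body.
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)

print(x.shape)  # torch.Size([1, 64, 4, 4, 4])
```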

@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
     if patch_size == 1:
         return x
 
-    if x.dim() == 4:
-        # x shape: [b, (c * patch_size * patch_size), h, w]
-        batch_size, c_patches, height, width = x.shape
-        channels = c_patches // (patch_size * patch_size)
-
-        # Reshape to [b, c, patch_size, patch_size, h, w]
-        x = x.view(batch_size, channels, patch_size, patch_size, height, width)
-
-        # Rearrange to [b, c, h * patch_size, w * patch_size]
-        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
-        x = x.view(batch_size, channels, height * patch_size, width * patch_size)
-
-    elif x.dim() == 5:
-        # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
-        batch_size, c_patches, frames, height, width = x.shape
-        channels = c_patches // (patch_size * patch_size)
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+    # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
+    batch_size, c_patches, frames, height, width = x.shape
+    channels = c_patches // (patch_size * patch_size)
 
-        # Reshape to [b, c, patch_size, patch_size, f, h, w]
-        x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
+    # Reshape to [b, c, patch_size, patch_size, f, h, w]
+    x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
 
-        # Rearrange to [b, c, f, h * patch_size, w * patch_size]
-        x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
-        x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
+    # Rearrange to [b, c, f, h * patch_size, w * patch_size]
+    x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
+    x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
 
     return x
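
The updated permute orders in patchify and unpatchify are meant to stay inverse to each other. A minimal round-trip sketch (arbitrary example sizes) that should print True if the two new permutations really do invert one another:

```python
import torch

p = 2
x = torch.randn(1, 16, 4, 8, 8)  # [batch, channels, frames, height, width]

# Forward: same view/permute/view sequence as the new patchify body.
y = x.view(1, 16, 4, 8 // p, p, 8 // p, p).permute(0, 1, 6, 4, 2, 3, 5).contiguous()
y = y.view(1, 16 * p * p, 4, 8 // p, 8 // p)

# Inverse: same sequence as the new unpatchify body.
z = y.view(1, 16, p, p, 4, 8 // p, 8 // p).permute(0, 1, 4, 5, 3, 6, 2).contiguous()
z = z.view(1, 16, 4, 8, 8)

print(torch.equal(z, x))  # True
```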

@@ -1044,7 +1016,6 @@ def __init__(
         patch_size: Optional[int] = None,
         scale_factor_temporal: Optional[int] = 4,
         scale_factor_spatial: Optional[int] = 8,
-        clip_output: bool = True,
     ) -> None:
         super().__init__()

@@ -1244,10 +1215,11 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
                 out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)
 
-        if self.config.clip_output:
-            out = torch.clamp(out, min=-1.0, max=1.0)
         if self.config.patch_size is not None:
             out = unpatchify(out, patch_size=self.config.patch_size)
+
+        out = torch.clamp(out, min=-1.0, max=1.0)
+
         self.clear_cache()
         if not return_dict:
             return (out,)
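
With the clip_output flag removed, the decoded (and, when patch_size is set, unpatchified) output is now always clamped to [-1, 1]. A trivial sketch of the clamp behaviour:

```python
import torch

# Values outside [-1, 1] are saturated to the range boundaries.
out = torch.tensor([-1.7, -0.3, 0.0, 0.9, 2.4])
out = torch.clamp(out, min=-1.0, max=1.0)
print(out)  # -1.7 -> -1.0 and 2.4 -> 1.0; in-range values are unchanged
```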

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
@@ -1034,7 +1034,8 @@ def __call__(
 
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                if hasattr(self.scheduler, "scale_model_input"):
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # predict the noise residual
                 noise_pred = self.unet(
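
The hasattr guard lets this denoising loop run with schedulers that do not define scale_model_input. A minimal sketch of the pattern, using a stand-in scheduler class rather than a real diffusers scheduler:

```python
import torch

# Hypothetical stand-in for a scheduler without scale_model_input
# (illustration only, not a diffusers API).
class DummyScheduler:
    pass

scheduler = DummyScheduler()
latent_model_input = torch.randn(2, 4, 64, 64)
t = torch.tensor(999)

# Same guard as the pipeline change: only scale when the scheduler supports it.
if hasattr(scheduler, "scale_model_input"):
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# Otherwise the latents go to the UNet unscaled.
```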
