Skip to content

Commit 67347f6

Browse files
committed
Merge branch 'wan22-followup' of github.com:huggingface/diffusers into wan22-followup
2 parents 5e17dde + dd1328d commit 67347f6

File tree

7 files changed

+109
-53
lines changed

7 files changed

+109
-53
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,20 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515

16+
# /// script
17+
# dependencies = [
18+
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
19+
# "torch>=2.0.0",
20+
# "accelerate>=0.31.0",
21+
# "transformers>=4.41.2",
22+
# "ftfy",
23+
# "tensorboard",
24+
# "Jinja2",
25+
# "peft>=0.11.1",
26+
# "sentencepiece",
27+
# ]
28+
# ///
29+
1630
import argparse
1731
import copy
1832
import itertools

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,20 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515

16+
# /// script
17+
# dependencies = [
18+
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
19+
# "torch>=2.0.0",
20+
# "accelerate>=0.31.0",
21+
# "transformers>=4.41.2",
22+
# "ftfy",
23+
# "tensorboard",
24+
# "Jinja2",
25+
# "peft>=0.11.1",
26+
# "sentencepiece",
27+
# ]
28+
# ///
29+
1630
import argparse
1731
import gc
1832
import hashlib

examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,20 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515

16+
# /// script
17+
# dependencies = [
18+
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
19+
# "torch>=2.0.0",
20+
# "accelerate>=0.31.0",
21+
# "transformers>=4.41.2",
22+
# "ftfy",
23+
# "tensorboard",
24+
# "Jinja2",
25+
# "peft>=0.11.1",
26+
# "sentencepiece",
27+
# ]
28+
# ///
29+
1630
import argparse
1731
import gc
1832
import itertools

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,20 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515

16+
# /// script
17+
# dependencies = [
18+
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
19+
# "torch>=2.0.0",
20+
# "accelerate>=0.31.0",
21+
# "transformers>=4.41.2",
22+
# "ftfy",
23+
# "tensorboard",
24+
# "Jinja2",
25+
# "peft>=0.11.1",
26+
# "sentencepiece",
27+
# ]
28+
# ///
29+
1630
import argparse
1731
import copy
1832
import gc

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,20 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515

16+
# /// script
17+
# dependencies = [
18+
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
19+
# "torch>=2.0.0",
20+
# "accelerate>=0.31.0",
21+
# "transformers>=4.41.2",
22+
# "ftfy",
23+
# "tensorboard",
24+
# "Jinja2",
25+
# "peft>=0.11.1",
26+
# "sentencepiece",
27+
# ]
28+
# ///
29+
1630
import argparse
1731
import copy
1832
import itertools

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,20 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515

16+
# /// script
17+
# dependencies = [
18+
# "diffusers @ git+https://github.com/huggingface/diffusers.git",
19+
# "torch>=2.0.0",
20+
# "accelerate>=1.0.0",
21+
# "transformers>=4.47.0",
22+
# "ftfy",
23+
# "tensorboard",
24+
# "Jinja2",
25+
# "peft>=0.14.0",
26+
# "sentencepiece",
27+
# ]
28+
# ///
29+
1630
import argparse
1731
import copy
1832
import itertools

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 25 additions & 53 deletions
Original file line number · Diff line number · Diff line change
@@ -913,38 +913,21 @@ def patchify(x, patch_size):
913913
if patch_size == 1:
914914
return x
915915

916-
if x.dim() == 4:
917-
# x shape: [batch_size, channels, height, width]
918-
batch_size, channels, height, width = x.shape
919-
920-
# Ensure height and width are divisible by patch_size
921-
if height % patch_size != 0 or width % patch_size != 0:
922-
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
923-
924-
# Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
925-
x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
926-
927-
# Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
928-
x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
929-
x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
930-
931-
elif x.dim() == 5:
932-
# x shape: [batch_size, channels, frames, height, width]
933-
batch_size, channels, frames, height, width = x.shape
934-
935-
# Ensure height and width are divisible by patch_size
936-
if height % patch_size != 0 or width % patch_size != 0:
937-
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
916+
if x.dim() != 5:
917+
raise ValueError(f"Invalid input shape: {x.shape}")
918+
# x shape: [batch_size, channels, frames, height, width]
919+
batch_size, channels, frames, height, width = x.shape
938920

939-
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
940-
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
921+
# Ensure height and width are divisible by patch_size
922+
if height % patch_size != 0 or width % patch_size != 0:
923+
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
941924

942-
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
943-
x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
944-
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
925+
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
926+
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
945927

946-
else:
947-
raise ValueError(f"Invalid input shape: {x.shape}")
928+
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
929+
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
930+
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
948931

949932
return x
950933

@@ -953,29 +936,18 @@ def unpatchify(x, patch_size):
953936
if patch_size == 1:
954937
return x
955938

956-
if x.dim() == 4:
957-
# x shape: [b, (c * patch_size * patch_size), h, w]
958-
batch_size, c_patches, height, width = x.shape
959-
channels = c_patches // (patch_size * patch_size)
960-
961-
# Reshape to [b, c, patch_size, patch_size, h, w]
962-
x = x.view(batch_size, channels, patch_size, patch_size, height, width)
963-
964-
# Rearrange to [b, c, h * patch_size, w * patch_size]
965-
x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
966-
x = x.view(batch_size, channels, height * patch_size, width * patch_size)
967-
968-
elif x.dim() == 5:
969-
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
970-
batch_size, c_patches, frames, height, width = x.shape
971-
channels = c_patches // (patch_size * patch_size)
939+
if x.dim() != 5:
940+
raise ValueError(f"Invalid input shape: {x.shape}")
941+
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
942+
batch_size, c_patches, frames, height, width = x.shape
943+
channels = c_patches // (patch_size * patch_size)
972944

973-
# Reshape to [b, c, patch_size, patch_size, f, h, w]
974-
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
945+
# Reshape to [b, c, patch_size, patch_size, f, h, w]
946+
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
975947

976-
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
977-
x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
978-
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
948+
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
949+
x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
950+
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
979951

980952
return x
981953

@@ -1044,7 +1016,6 @@ def __init__(
10441016
patch_size: Optional[int] = None,
10451017
scale_factor_temporal: Optional[int] = 4,
10461018
scale_factor_spatial: Optional[int] = 8,
1047-
clip_output: bool = True,
10481019
) -> None:
10491020
super().__init__()
10501021

@@ -1244,10 +1215,11 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
12441215
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
12451216
out = torch.cat([out, out_], 2)
12461217

1247-
if self.config.clip_output:
1248-
out = torch.clamp(out, min=-1.0, max=1.0)
12491218
if self.config.patch_size is not None:
12501219
out = unpatchify(out, patch_size=self.config.patch_size)
1220+
1221+
out = torch.clamp(out, min=-1.0, max=1.0)
1222+
12511223
self.clear_cache()
12521224
if not return_dict:
12531225
return (out,)

0 commit comments

Comments (0)