Describe the bug
LoRA fusing works fine on non-GGUF versions of Wan2.2, but fails when used with the GGUF-quantized transformer, which is what most consumers have to use due to memory constraints (constraints that also make fusing quite important).
Reproduction
import torch
import requests
import safetensors.torch
from io import BytesIO
from PIL import Image
from huggingface_hub import hf_hub_download
from diffusers import (
    WanImageToVideoPipeline,
    WanTransformer3DModel,
    UniPCMultistepScheduler,
    GGUFQuantizationConfig,
)
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
from diffusers.utils import export_to_video
# GGUF Model URLs from QuantStack repository
# Using Q2_K variant as requested (smallest one)
REPO_ID = "QuantStack/Wan2.2-I2V-A14B-GGUF"
HIGH_NOISE_GGUF = "HighNoise/Wan2.2-I2V-A14B-HighNoise-Q2_K.gguf"
LOW_NOISE_GGUF = "LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
# Download GGUF models
high_noise_path = hf_hub_download(repo_id=REPO_ID, filename=HIGH_NOISE_GGUF)
low_noise_path = hf_hub_download(repo_id=REPO_ID, filename=LOW_NOISE_GGUF)
print("Loading quantized transformers...")
# Load quantized transformers
transformer_high_noise = WanTransformer3DModel.from_single_file(
    high_noise_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config="Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
transformer_low_noise = WanTransformer3DModel.from_single_file(
    low_noise_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config="Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    subfolder="transformer_2",
    torch_dtype=torch.bfloat16,
)
print("Creating pipeline with quantized transformers...")
# Create pipeline with quantized transformers
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    transformer=transformer_high_noise,  # High noise goes to main transformer
    transformer_2=transformer_low_noise,  # Low noise goes to transformer_2
    torch_dtype=torch.bfloat16,
)
print(pipe.scheduler)
# Keep the UniPC scheduler and set flow_shift=8.0
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
print("Downloading LoRA weights...")
# Download LoRA weights
lora_path = hf_hub_download(
    repo_id="Kijai/WanVideo_comfy",
    filename="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
)
print("Moving pipeline to GPU...")
pipe.to("cuda")
# Wan 2.2 uses two transformers, so the LoRA is loaded into each
org_state_dict = safetensors.torch.load_file(lora_path)
converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(org_state_dict)
# Load LoRA for main transformer
pipe.load_lora_weights(lora_path, adapter_name='lightx2v_t1')
pipe.set_adapters(["lightx2v_t1"], adapter_weights=[3.0])
# Load LoRA for transformer_2
pipe.transformer_2.load_lora_adapter(converted_state_dict, adapter_name="lightx2v")
pipe.transformer_2.set_adapters(["lightx2v"], weights=[1.5])
# THIS IS THE PROBLEMATIC LINE, WITHOUT IT WORKS
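# (fuse_lora() merges each LoRA delta into the base layer weights in place;
# with a GGUF checkpoint those base weights appear to still be stored in
# packed quantized form, hence the shape mismatch in the logs below.)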
pipe.fuse_lora()
print("Loading input image...")
# Load input image
image_url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
response = requests.get(image_url)
input_image = Image.open(BytesIO(response.content)).convert("RGB")
print("Generating video...")
# The step-distilled Lightx2v LoRA allows for fewer inference steps while maintaining quality
frames = pipe(input_image, "animate", num_inference_steps=4, guidance_scale=1.0).frames[0]
print("Exporting video...")
export_to_video(frames, "test_lora_gguf_q2.mp4")
print("Done! Video saved as test_lora_gguf_q2.mp4")
Logs
Traceback (most recent call last):
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/fast-wan-gguf.py", line 75, in <module>
    pipe.fuse_lora()
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py", line 5430, in fuse_lora
    super().fuse_lora(
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/diffusers/loaders/lora_base.py", line 611, in fuse_lora
    model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/diffusers/loaders/peft.py", line 660, in fuse_lora
    self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1044, in apply
    module.apply(fn)
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1044, in apply
    module.apply(fn)
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1044, in apply
    module.apply(fn)
  [Previous line repeated 1 more time]
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1045, in apply
    fn(self)
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/diffusers/loaders/peft.py", line 682, in _fuse_lora_apply
    module.merge(**merge_kwargs)
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/peft/tuners/lora/layer.py", line 677, in merge
    base_layer.weight.data += delta_weight
  File "/home/luca/video/wan2-2-i2v-a14b-bkp/.venv/lib/python3.12/site-packages/diffusers/quantizers/gguf/utils.py", line 428, in __torch_function__
    result = super().__torch_function__(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (1680) must match the size of tensor b (5120) at non-singleton dimension 1
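For what it's worth, the mismatched sizes look consistent with Q2_K packing of the storage tensor. A back-of-the-envelope check, assuming ggml's standard Q2_K layout (not a confirmed diagnosis):

# Sanity check of the reported sizes, assuming ggml's Q2_K layout:
# 256 weights per superblock, 84 bytes per superblock
# (64 bytes of 2-bit quants + 16 bytes of scales + 2 fp16 for d/dmin).
hidden_dim = 5120                   # logical (dequantized) row width
packed = (hidden_dim // 256) * 84   # 20 superblocks * 84 bytes
assert packed == 1680               # matches "tensor a (1680)" in the error
# So peft's merge appears to add the 5120-wide LoRA delta directly onto the
# 1680-byte-wide packed storage tensor instead of the dequantized weight.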
System Info
- 🤗 Diffusers version: 0.35.0.dev0
- Platform: Linux-5.15.0-136-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.7.1+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.3
- Transformers version: 4.55.0.dev0
- Accelerate version: 1.8.1
- PEFT version: 0.16.0
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: 8 x NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
Sorry to bother you again, @sayakpaul @a-r-r-o-w.