[LoRA] support lightx2v lora in wan #12040
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I can confirm it works for Wan 2.1 (for Wan 2.2 this LoRA reports a lot of key issues). It runs, but I still have to check the results and compare them to what Comfy returns. It's also still not possible to apply the LoRA directly to the transformer, which matters for Wan 2.2 where there are two different transformers that need the same LoRA with different weights, though that is probably out of scope for this PR.
I will look for those. During your comparison, it might be better to ensure identical settings, because a small difference in hyperparameter values (schedulers, scales, etc.) can lead to wildly different results with LoRAs. So we usually rely on visual quality.
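As a small sketch of what keeping settings identical can look like in practice (the checkpoint, scheduler shift, seed, and test image below are illustrative assumptions, not values from a verified comparison):

import torch
import requests
from io import BytesIO
from PIL import Image
from diffusers import WanImageToVideoPipeline, UniPCMultistepScheduler

# Use the same checkpoint and scheduler configuration as the reference workflow.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)

url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
input_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")

# A fixed seed removes sampling noise as a source of differences, so the comparison
# isolates the effect of the LoRA and the remaining hyperparameters.
generator = torch.Generator("cuda").manual_seed(0)
frames = pipe(
    input_image,
    "animate",
    num_inference_steps=4,
    guidance_scale=1.0,
    generator=generator,
).frames[0]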
Yeah:

from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers

# convert to `diffusers` style state dict.
converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(downloaded_lightx2v_state_dict)

pipe.transformer.load_lora_adapter(converted_state_dict, adapter_name="transformer_one")
pipe.transformer_2.load_lora_adapter(converted_state_dict, adapter_name="transformer_two")

pipe.transformer.set_adapters("transformer_one", weight_value_one)
pipe.transformer_2.set_adapters("transformer_two", weight_value_two)
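Here `downloaded_lightx2v_state_dict`, `weight_value_one`, and `weight_value_two` are placeholders for the raw LoRA state dict and the per-transformer strengths. Loading the adapter on each transformer separately is what lets the two Wan 2.2 transformers share the same LoRA at different weights, which is the case described above.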
Update: I was looking into the unexpected-keys issue for the Wan 2.2 model. When trying to load the said LoRA state dict, it complains about some of the keys; I manually confirmed this by printing the modules. For loading the LoRA into the second transformer, we have to do:

import torch
from diffusers import WanImageToVideoPipeline
from huggingface_hub import hf_hub_download
import requests
from PIL import Image
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
from io import BytesIO
import safetensors.torch
# Load the Wan 2.2 image-to-video pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
"Wan-AI/Wan2.2-I2V-A14B-Diffusers",
torch_dtype=torch.bfloat16
)
lora_path = hf_hub_download(
repo_id="Kijai/WanVideo_comfy",
filename="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"
)
pipe.to("cuda")
pipe.load_lora_weights(lora_path)
# print(pipe.transformer.__class__.__name__)
# print(pipe.transformer.peft_config)
org_state_dict = safetensors.torch.load_file(lora_path)
converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(org_state_dict)
pipe.transformer_2.load_lora_adapter(converted_state_dict)
image_url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
response = requests.get(image_url)
input_image = Image.open(BytesIO(response.content)).convert("RGB")
frames = pipe(input_image, "animate", num_inference_steps=4, guidance_scale=1.0)

It works but throws the same unexpected-keys warning, which is again expected for the second DiT in Wan 2.2.
The last snippet is gold, @sayakpaul; it works perfectly. Hero of the day! By the way, I think the official documentation would benefit from an example of how to load LightX2V, given how popular it is on Reddit.
Just for reference, with the proper weights and the option to switch from Wan 2.1 to 2.2:

import torch
from diffusers import WanImageToVideoPipeline, DiffusionPipeline, LCMScheduler, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download
import requests
from PIL import Image
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers
from io import BytesIO
from diffusers.utils import export_to_video
import safetensors.torch
#Wan 2.1
#pipe = DiffusionPipeline.from_pretrained(
# "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
# torch_dtype=torch.bfloat16
#)
pipe = WanImageToVideoPipeline.from_pretrained(
"Wan-AI/Wan2.2-I2V-A14B-Diffusers",
torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
lora_path = hf_hub_download(
repo_id="Kijai/WanVideo_comfy",
filename="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"
)
pipe.to("cuda")
pipe.load_lora_weights(lora_path, adapter_name='lightx2v_t1')
pipe.set_adapters(["lightx2v_t1"], adapter_weights=[3.0])
if hasattr(pipe, "transformer_2") and pipe.transformer_2 is not None:
org_state_dict = safetensors.torch.load_file(lora_path)
converted_state_dict = _convert_non_diffusers_wan_lora_to_diffusers(org_state_dict)
pipe.transformer_2.load_lora_adapter(converted_state_dict, adapter_name="lightx2v")
pipe.transformer_2.set_adapters(["lightx2v"], weights=[1.5])
image_url = "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
response = requests.get(image_url)
input_image = Image.open(BytesIO(response.content)).convert("RGB")
frames = pipe(input_image, "animate", num_inference_steps=4, guidance_scale=1.0).frames[0]
export_to_video(frames, "output.mp4")
@luke14free glad that it worked out! Docs added in 7308bc1.
docs/source/en/api/pipelines/wan.md
Outdated
Following Wan 2.2 checkpoints are also supported:
Let's just continue growing the above list instead of adding another
docs/source/en/api/pipelines/wan.md
Outdated
@@ -327,6 +333,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip

- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution images.
## Using LightX2V LoRAs
Let's put this in the Notes section instead of creating a new subsection
Nice!
@luke14free How much `boundary_ratio` should I keep for the lightx2v LoRA when there are just 4 inference steps?
@mayankagrawal10198 most workflows I have seen use 50%-50% |
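For reference, one way a 50/50 split could be expressed with the Wan 2.2 pipelines, assuming `boundary_ratio` is exposed as a pipeline config argument (treat the exact keyword and override mechanism as assumptions rather than something verified in this PR):

import torch
from diffusers import WanImageToVideoPipeline

# Assumption: `boundary_ratio` is an init/config argument of the Wan 2.2 pipeline that
# controls at which fraction of the denoising trajectory it switches from
# `transformer` (high-noise expert) to `transformer_2` (low-noise expert).
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16,
    boundary_ratio=0.5,  # 50%-50% split between the two experts for the 4-step setup
)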
What does this PR do?
Fixes #12037
@luke14free, could you give this PR a try? I tested it with the following snippet:
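A representative snippet, mirroring the usage shown elsewhere in this thread (the checkpoint, LoRA file, image, and settings are assumptions rather than the exact values from the original test):

import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from huggingface_hub import hf_hub_download

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# With this PR, the lightx2v checkpoint loads directly through `load_lora_weights`.
lora_path = hf_hub_download(
    repo_id="Kijai/WanVideo_comfy",
    filename="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
)
pipe.load_lora_weights(lora_path)

image = load_image(
    "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png"
)
frames = pipe(image, "animate", num_inference_steps=4, guidance_scale=1.0).frames[0]
export_to_video(frames, "output.mp4")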