
Wan 2.2 5B i2v results are poor quality compared to the official Wan HF Space #12034

@okaris

Describe the bug

diffusers result:


video links
https://github.com/user-attachments/assets/4dd6d342-9dfb-4946-a714-641f5a4cc98d

(This video might not play in the browser due to a diffusers.utils.export_to_video encoding issue, so I am adding two alternatives re-encoded with libx264; see the re-encoding sketch after the links below.)

extra:

https://github.com/user-attachments/assets/e990c9a8-0a55-4c89-9233-056aa6743103
https://github.com/user-attachments/assets/e7fd1a57-e87e-486c-8ecb-f1147649d600
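For anyone hitting the same playback problem: re-encoding the exported mp4 to H.264/yuv420p usually makes it browser-playable. A minimal sketch, assuming ffmpeg is available on PATH; the file names are placeholders:

import subprocess

def reencode_h264(src: str, dst: str) -> None:
    # Re-encode to H.264 with a pixel format most browsers/players accept
    subprocess.run(
        [
            "ffmpeg", "-y",         # overwrite the output if it already exists
            "-i", src,              # e.g. the file written by export_to_video
            "-c:v", "libx264",      # H.264 encoder
            "-pix_fmt", "yuv420p",  # widest player compatibility
            "-crf", "18",           # near-lossless quality
            dst,
        ],
        check=True,
    )

reencode_h264("diffusers_output.mp4", "diffusers_output_h264.mp4")  # hypothetical paths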

wan space result:

c60bc1aeb5b0f50b.mp4

Reproduction

import os
# Enable HF Hub fast transfer for faster model downloads
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
import numpy as np
import tempfile
from typing import Optional
from pydantic import Field
from PIL import Image
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video
from accelerate import Accelerator

from inferencesh import BaseApp, BaseAppInput, BaseAppOutput, File


class AppInput(BaseAppInput):
    image: File = Field(description="Input image for video generation")
    prompt: str = Field(description="Text prompt for video generation")
    negative_prompt: str = Field(
        default="oversaturated, overexposed, static, blurry details, subtitles, stylized, artwork, painting, still image, overall gray, worst quality, low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, poorly drawn face, malformed, disfigured, deformed limbs, fused fingers, static motionless frame, cluttered background, three legs, crowded background, walking backwards",
        description="Negative prompt to guide what to avoid in generation"
    )
    resolution: str = Field(default="720p", description="Resolution preset", enum=["480p", "720p"])
    max_area: Optional[int] = Field(default=None, description="Maximum area for image resizing (auto-set based on resolution if not specified)")
    num_frames: int = Field(default=121, description="Number of frames to generate")
    guidance_scale: float = Field(default=5.0, description="Classifier-free guidance scale")
    num_inference_steps: int = Field(default=50, description="Number of denoising steps")
    fps: int = Field(default=24, description="Frames per second for the output video")
    seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
    video_output_quality: int = Field(default=5, ge=1, le=10, description="Video output quality (1-10)")

class AppOutput(BaseAppOutput):
    file: File = Field(description="Generated video file")

class App(BaseApp):
    async def setup(self, metadata):
        """Initialize the Wan2.2-TI2V-5B Image-to-Video pipeline and resources here."""
        print("Setting up Wan2.2-TI2V-5B Image-to-Video pipeline...")
        
        # Store resolution defaults (using TI2V resolution standards)
        self.resolution_presets = {
            "480p": {"max_area": 480 * 832},
            "720p": {"max_area": 704 * 1280}  # Updated to TI2V's 704 height
        }
        
        # Initialize accelerator
        self.accelerator = Accelerator()
        
        # Set up device and dtype using accelerator
        self.device = self.accelerator.device
        self.dtype = torch.bfloat16 if self.device.type == "cuda" else torch.float32
        
        print(f"Using device: {self.device}")
        print(f"Using dtype: {self.dtype}")
                
        # Model ID for the TI2V 5B variant (using I2V pipeline with TI2V model)
        self.model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
        
        self.pipe = WanImageToVideoPipeline.from_pretrained(
            self.model_id, 
            torch_dtype=self.dtype
        )
        self.pipe.enable_model_cpu_offload()
       
    def resize_image_for_pipeline(self, image: Image.Image, max_area: int) -> tuple[Image.Image, int, int]:
        """Resize image according to pipeline requirements."""
        aspect_ratio = image.height / image.width
        mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1]
        print(f"Mod value: {mod_value}")
        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
        
        resized_image = image.resize((width, height))
        print(f"Resized image from {image.size} to {resized_image.size} (target area: {max_area})")
        
        return resized_image, width, height

    async def run(self, input_data: AppInput, metadata) -> AppOutput:
        """Generate video from image and text prompt."""
        print(f"Generating video with prompt: {input_data.prompt}")
        
        # Use resolution preset if max_area not specified
        preset = self.resolution_presets.get(input_data.resolution, self.resolution_presets["720p"])
        max_area = input_data.max_area if input_data.max_area is not None else preset["max_area"]
        print(f"Using resolution: {input_data.resolution}, max area: {max_area}")
        
        # Load and process input image
        image = Image.open(input_data.image.path).convert("RGB")
        print(f"Loaded image: {image.size}")
        
        # Resize image according to pipeline requirements
        resized_image, width, height = self.resize_image_for_pipeline(image, max_area)
        
        # Set seed if provided
        generator = None
        if input_data.seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(input_data.seed)
            print(f"Using seed: {input_data.seed}")
        
        # Generate video
        print("Starting video generation...")
        output = self.pipe(
            image=resized_image,
            prompt=input_data.prompt,
            negative_prompt=input_data.negative_prompt,
            height=height,
            width=width,
            num_frames=input_data.num_frames,
            guidance_scale=input_data.guidance_scale,
            num_inference_steps=input_data.num_inference_steps,
            generator=generator,
        ).frames[0]
        
        print("Video generation complete, exporting...")
        
        # Create temporary file for output
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            output_path = temp_file.name
        
        # Export video
        export_to_video(output, output_path, fps=input_data.fps, quality=input_data.video_output_quality)
        
        print(f"Video exported to: {output_path}")
        
        return AppOutput(file=File(path=output_path))

    async def unload(self):
        """Clean up resources here."""
        print("Cleaning up...")
        if hasattr(self, 'pipe'):
            del self.pipe
        
        # Clear GPU cache if using CUDA
        if hasattr(self, 'device') and self.device.type == "cuda":
            torch.cuda.empty_cache()
        
        print("Cleanup complete!") 

The sample script above was adapted from the provided default script.
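For reproduction without the inference.sh wrapper, here is a stripped-down sketch of the same call, assuming the same model id, the 720p sizing logic from the app above, and the inputs listed below (the image and output paths are placeholders):

import torch
import numpy as np
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
pipe = WanImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = load_image("input.png")  # placeholder for the input PNG linked below

# Same sizing logic as the app above: snap height/width to the model's spatial stride
max_area = 704 * 1280  # 720p preset
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = "morpheus from the matrix offering the choice, include morpheus, on one hand it says \"local\" on the other it says \"cloud\""
negative_prompt = "..."  # the same negative prompt string as in the inputs below

frames = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=49,
    guidance_scale=5.0,
    num_inference_steps=38,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]

export_to_video(frames, "output.mp4", fps=24)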

inputs:

{
  "image": "https://cloud.inference.sh/u/4mg21r6ta37mpaz6ktzwtt8krr/01k1g7k73eebnrmzmc6h0bghq6.png",
  "prompt": "morpheus from the matrix offering the choice, include morpheus, on one hand it says \"local\" on the other it says \"cloud\"",
  "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
  "resolution": "720p",
  "max_area": null,
  "num_frames": 49,
  "guidance_scale": 5.0,
  "num_inference_steps": 38,
  "fps": 24,
  "seed": 42,
  "cache_threshold": 0.0,
  "video_output_quality": 10
}

Compared against the same inputs on https://huggingface.co/spaces/Wan-AI/Wan-2.2-5B.

Logs

System Info

git+https://github.com/huggingface/diffusers.git@9d313fc718c8ace9a35f07dad9d5ce8018f8d216

  • 🤗 Diffusers version: 0.35.0.dev0
  • Platform: Linux-5.15.0-136-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.34.2
  • Transformers version: 4.51.3
  • Accelerate version: 1.6.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: 8× NVIDIA A100-SXM4-80GB, 81920 MiB each
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@yiyixuxu @DN6 @a-r-r-o-w
