Skip to content

Independent Condition Guidance added, 30% faster inference for diffusion process #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
# Icon must end with two \r
Icon


# Pretrained weights
pretrained_models/
venv
# Thumbnails
._*

Expand Down Expand Up @@ -232,4 +234,4 @@ save_pipeline.py
outputs_gradio/*
test.py
data_configs/test
scripts/test
scripts/test
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
</h4>

## 🔥 News
- **2025-07-14**: 30% faster inference with Independent Condition Guidance ([ICG](https://arxiv.org/abs/2407.02687)).
- **2025-07-05**: Training datasets [X2I2](https://huggingface.co/datasets/OmniGen2/X2I2) are available.
- **2025-07-03**: OmniGen2 now supports [TeaCache](https://github.com/ali-vilab/TeaCache) and [TaylorSeer](https://github.com/Shenyi-Z/TaylorSeer) for faster inference, see [Usage Tips](#-usage-tips) for details. Thanks @legitnull for great [TeaCache-PR](https://github.com/VectorSpaceLab/OmniGen2/pull/52) and [TaylorSeer-PR](https://github.com/VectorSpaceLab/OmniGen2/pull/76).
- **2025-07-01**: OmniGen2 is supported by [ComfyUI official](https://comfyanonymous.github.io/ComfyUI_examples/omnigen), thanks !!
Expand Down
55 changes: 54 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def run(
max_input_image_side_length,
max_pixels,
seed_input,
use_teacache,
teacache_scale,
guidance_method:str,
progress=gr.Progress(),
):
input_images = [image_input_1, image_input_2, image_input_3]
Expand Down Expand Up @@ -92,6 +95,23 @@ def progress_callback(cur_step, timesteps):
prediction_type="flow_prediction",
)

if use_teacache:
print("Using teacache for inference speedup")
pipeline.transformer.enable_teacache = True
pipeline.transformer.teacache_rel_l1_thresh = teacache_scale

else:
pipeline.transformer.enable_teacache = False
print(guidance_method)

if guidance_method == "True":
print("Using ICG guidance method")
guidance_method = "ICG"
else :
print("Using CFG guidance method")
guidance_method = "CFG"


results = pipeline(
prompt=instruction,
input_images=input_images,
Expand All @@ -108,6 +128,7 @@ def progress_callback(cur_step, timesteps):
num_images_per_prompt=num_images_per_prompt,
generator=generator,
output_type="pil",
guidance_method=guidance_method,
step_func=progress_callback,
)

Expand Down Expand Up @@ -881,6 +902,10 @@ def run_for_examples(
max_input_image_side_length,
max_pixels,
seed_input,
use_teacache,
teacache_scale,
guidance_method:str,
progress=gr.Progress(),
):
return run(
instruction,
Expand All @@ -900,6 +925,10 @@ def run_for_examples(
max_input_image_side_length,
max_pixels,
seed_input,
use_teacache,
teacache_scale,
guidance_method,
progress=gr.Progress(),
)

description = """
Expand Down Expand Up @@ -964,6 +993,22 @@ def main(args):
width_input = gr.Slider(
label="Width", minimum=256, maximum=2048, value=1024, step=128
)
with gr.Row(equal_height=True):
guidance_method = gr.Dropdown(
label="use_icg", choices=["True", "False"],
value = "True",
info="ICG inference time optimization.",
)
with gr.Row(equal_height=True):
use_teacache = gr.Dropdown(
label="Teacache usage", choices=[True, False],
value=True,
info="Teacache inference time optimization.",
)
teacache_scale = gr.Slider(
label="Teacache Scale", minimum=0, maximum=1, value=0.1, step=0.01,
)

with gr.Row(equal_height=True):
text_guidance_scale_input = gr.Slider(
label="Text Guidance Scale",
Expand Down Expand Up @@ -1091,6 +1136,10 @@ def adjust_start_slider(end_val, start_val):
max_input_image_side_length,
max_pixels,
seed_input,
use_teacache,
teacache_scale,
guidance_method,

],
outputs=output_image,
)
Expand All @@ -1116,6 +1165,10 @@ def adjust_start_slider(end_val, start_val):
max_input_image_side_length,
max_pixels,
seed_input,
use_teacache,
teacache_scale,
guidance_method,

],
outputs=output_image,
)
Expand All @@ -1133,7 +1186,7 @@ def parse_args():
parser.add_argument(
"--model_path",
type=str,
default="OmniGen2/OmniGen2",
default="./pretrained_models",
help="Path or HuggingFace name of the model to load."
)
parser.add_argument(
Expand Down
12 changes: 12 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable TaylorSeer Caching."
)

parser.add_argument(
"--guidance_method",
type=str,
default="ICG",
choices=["ICG", "CFG"],
help="Guidance method to use for the model. ICG: Independent Condition Guidance, CFG: Classifier-Free Guidance."

)


return parser.parse_args()

def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
Expand Down Expand Up @@ -270,6 +281,7 @@ def run(args: argparse.Namespace,
num_images_per_prompt=args.num_images_per_prompt,
generator=generator,
output_type="pil",
guidance_method=args.guidance_method,
)
return results

Expand Down
49 changes: 35 additions & 14 deletions omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import time
import math

from PIL import Image
Expand Down Expand Up @@ -490,6 +490,7 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
guidance_method: str = "ICG",
return_dict: bool = True,
verbose: bool = False,
step_func=None,
Expand Down Expand Up @@ -594,7 +595,9 @@ def __call__(
device=device,
dtype=dtype,
verbose=verbose,
guidance_method=guidance_method,
step_func=step_func,

)

image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
Expand Down Expand Up @@ -623,7 +626,9 @@ def processing(
device,
dtype,
verbose,
guidance_method,
step_func=None

):
batch_size = latents.shape[0]

Expand All @@ -648,9 +653,12 @@ def processing(
teacache_params = TeaCacheParams()
teacache_params_uncond = TeaCacheParams()
teacache_params_ref = TeaCacheParams()

with self.progress_bar(total=num_inference_steps) as progress_bar:


for i, t in enumerate(timesteps):

if enable_taylorseer:
self.transformer.cache_dic = model_pred_cache_dic
self.transformer.current = model_pred_current
Expand All @@ -670,6 +678,7 @@ def processing(
image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0

if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:

if enable_taylorseer:
self.transformer.cache_dic = model_pred_ref_cache_dic
self.transformer.current = model_pred_ref_current
Expand All @@ -692,18 +701,30 @@ def processing(
elif self.transformer.enable_teacache:
teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
self.transformer.teacache_params = teacache_params_uncond

model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)

model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
text_guidance_scale * (model_pred - model_pred_ref)

if guidance_method == "CFG":

model_pred_uncond = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
freqs_cis=freqs_cis,
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=None,
)


model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
text_guidance_scale * (model_pred - model_pred_ref)

elif guidance_method == "ICG":
snr = (i + 1e-5) / (self._num_timesteps - i + 1e-5)
icg_text_scale = text_guidance_scale * snr / (1 + snr)
model_pred = model_pred_ref + icg_text_scale * (model_pred - model_pred_ref)

end_for = time.time()


elif text_guidance_scale > 1.0:
if enable_taylorseer:
self.transformer.cache_dic = model_pred_uncond_cache_dic
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ python-dotenv
ninja
ipykernel
wheel
gradio
triton-windows; sys_platform == "win32"