From 094fa76ca36f6d4a1fe52ed28f6227d63a1fe426 Mon Sep 17 00:00:00 2001
From: oguzhanercan
Date: Mon, 14 Jul 2025 14:16:25 +0300
Subject: [PATCH] Add ICG guidance, 30% faster inference

---
 .gitignore                                    |  6 +-
 README.md                                     |  1 +
 app.py                                        | 53 ++++++++++++++++-
 inference.py                                  | 12 ++++
 .../pipelines/omnigen2/pipeline_omnigen2.py   | 37 +++++++-----
 requirements.txt                              |  1 +
 6 files changed, 95 insertions(+), 15 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1ecee53..e64c32c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,9 @@
 # Icon must end with two \r
 Icon
 
-
+# Pretrained weights
+pretrained_models/
+venv
 
 # Thumbnails
 ._*
@@ -232,4 +234,4 @@ save_pipeline.py
 outputs_gradio/*
 test.py
 data_configs/test
-scripts/test
\ No newline at end of file
+scripts/test
diff --git a/README.md b/README.md
index 09cf8b9..86e6635 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@
 
 ## 🔥 News
 
+- **2025-07-14**: 30% faster inference with Independent Condition Guidance ([ICG](https://arxiv.org/abs/2407.02687)).
 - **2025-07-05**: Training datasets [X2I2](https://huggingface.co/datasets/OmniGen2/X2I2) are available.
 - **2025-07-03**: OmniGen2 now supports [TeaCache](https://github.com/ali-vilab/TeaCache) and [TaylorSeer](https://github.com/Shenyi-Z/TaylorSeer) for faster inference, see [Usage Tips](#-usage-tips) for details. Thanks @legitnull for great [TeaCache-PR](https://github.com/VectorSpaceLab/OmniGen2/pull/52) and [TaylorSeer-PR](https://github.com/VectorSpaceLab/OmniGen2/pull/76).
 - **2025-07-01**: OmniGen2 is supported by [ComfyUI official](https://comfyanonymous.github.io/ComfyUI_examples/omnigen), thanks !!
diff --git a/app.py b/app.py
index 6fbce42..0c30551 100644
--- a/app.py
+++ b/app.py
@@ -65,6 +65,9 @@ def run(
     max_input_image_side_length,
     max_pixels,
     seed_input,
+    use_teacache,
+    teacache_scale,
+    guidance_method: str,
     progress=gr.Progress(),
 ):
     input_images = [image_input_1, image_input_2, image_input_3]
@@ -92,6 +95,23 @@ def progress_callback(cur_step, timesteps):
         prediction_type="flow_prediction",
     )
 
+    if use_teacache:
+        print("Using TeaCache for inference speedup")
+        pipeline.transformer.enable_teacache = True
+        pipeline.transformer.teacache_rel_l1_thresh = teacache_scale
+    else:
+        pipeline.transformer.enable_teacache = False
+
+    # Map the UI dropdown value ("True"/"False") onto the guidance method
+    # name expected by the pipeline.
+    if guidance_method == "True":
+        print("Using ICG guidance method")
+        guidance_method = "ICG"
+    else:
+        print("Using CFG guidance method")
+        guidance_method = "CFG"
+
+
     results = pipeline(
         prompt=instruction,
         input_images=input_images,
@@ -108,6 +128,7 @@ def progress_callback(cur_step, timesteps):
         num_images_per_prompt=num_images_per_prompt,
         generator=generator,
         output_type="pil",
+        guidance_method=guidance_method,
         step_func=progress_callback,
     )
 
@@ -881,6 +902,10 @@ def run_for_examples(
     max_input_image_side_length,
     max_pixels,
     seed_input,
+    use_teacache,
+    teacache_scale,
+    guidance_method: str,
+    progress=gr.Progress(),
 ):
     return run(
         instruction,
@@ -900,6 +925,10 @@ def run_for_examples(
         max_input_image_side_length,
         max_pixels,
         seed_input,
+        use_teacache,
+        teacache_scale,
+        guidance_method,
+        progress=progress,
     )
 
 description = """
@@ -964,6 +993,22 @@ def main(args):
                 width_input = gr.Slider(
                     label="Width", minimum=256, maximum=2048, value=1024, step=128
                 )
+            with gr.Row(equal_height=True):
+                guidance_method = gr.Dropdown(
+                    label="Use ICG", choices=["True", "False"],
+                    value="True",
+                    info="Independent Condition Guidance for faster inference.",
+                )
+            with gr.Row(equal_height=True):
+                use_teacache = gr.Dropdown(
+                    label="Use TeaCache",
+                    choices=[True, False],
+                    value=True,
+                    info="TeaCache inference time optimization.",
+                )
+                teacache_scale = gr.Slider(
+                    label="TeaCache Scale", minimum=0, maximum=1, value=0.1, step=0.01,
+                )
+
             with gr.Row(equal_height=True):
                 text_guidance_scale_input = gr.Slider(
                     label="Text Guidance Scale",
@@ -1091,6 +1136,9 @@ def adjust_start_slider(end_val, start_val):
                 max_input_image_side_length,
                 max_pixels,
                 seed_input,
+                use_teacache,
+                teacache_scale,
+                guidance_method,
             ],
             outputs=output_image,
         )
@@ -1116,6 +1164,9 @@ def adjust_start_slider(end_val, start_val):
                 max_input_image_side_length,
                 max_pixels,
                 seed_input,
+                use_teacache,
+                teacache_scale,
+                guidance_method,
             ],
             outputs=output_image,
         )
@@ -1133,7 +1184,7 @@ def parse_args():
     parser.add_argument(
         "--model_path",
         type=str,
-        default="OmniGen2/OmniGen2",
+        default="./pretrained_models",
         help="Path or HuggingFace name of the model to load."
     )
     parser.add_argument(
diff --git a/inference.py b/inference.py
index bc155be..635a9b5 100644
--- a/inference.py
+++ b/inference.py
@@ -169,6 +169,17 @@ def parse_args() -> argparse.Namespace:
         action="store_true",
         help="Enable TaylorSeer Caching."
     )
+
+    parser.add_argument(
+        "--guidance_method",
+        type=str,
+        default="ICG",
+        choices=["ICG", "CFG"],
+        help="Guidance method: ICG (Independent Condition Guidance) "
+             "or CFG (Classifier-Free Guidance). ICG skips the unconditional "
+             "forward pass for faster sampling.",
+    )
+
     return parser.parse_args()
 
 def load_pipeline(args: argparse.Namespace, accelerator: Accelerator, weight_dtype: torch.dtype) -> OmniGen2Pipeline:
@@ -270,6 +281,7 @@ def run(args: argparse.Namespace,
         num_images_per_prompt=args.num_images_per_prompt,
         generator=generator,
         output_type="pil",
+        guidance_method=args.guidance_method,
     )
 
     return results
diff --git a/omnigen2/pipelines/omnigen2/pipeline_omnigen2.py b/omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
index eca129d..bb0f423 100644
--- a/omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
+++ b/omnigen2/pipelines/omnigen2/pipeline_omnigen2.py
@@ -490,6 +490,7 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        guidance_method: str = "ICG",
         return_dict: bool = True,
         verbose: bool = False,
         step_func=None,
@@ -594,7 +595,8 @@ def __call__(
             device=device,
             dtype=dtype,
             verbose=verbose,
+            guidance_method=guidance_method,
             step_func=step_func,
         )
 
         image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear')
@@ -623,7 +625,8 @@ def processing(
         device,
         dtype,
         verbose,
+        guidance_method,
         step_func=None
     ):
 
         batch_size = latents.shape[0]
@@ -692,18 +695,28 @@ def processing(
                 elif self.transformer.enable_teacache:
                     teacache_params_uncond.is_first_or_last_step = i == 0 or i == len(timesteps) - 1
                     self.transformer.teacache_params = teacache_params_uncond
-
-                model_pred_uncond = self.predict(
-                    t=t,
-                    latents=latents,
-                    prompt_embeds=negative_prompt_embeds,
-                    freqs_cis=freqs_cis,
-                    prompt_attention_mask=negative_prompt_attention_mask,
-                    ref_image_hidden_states=None,
-                )
-
-                model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
-                    text_guidance_scale * (model_pred - model_pred_ref)
+
+                if guidance_method == "CFG":
+                    # Classifier-free guidance needs a third, unconditional forward pass.
+                    model_pred_uncond = self.predict(
+                        t=t,
+                        latents=latents,
+                        prompt_embeds=negative_prompt_embeds,
+                        freqs_cis=freqs_cis,
+                        prompt_attention_mask=negative_prompt_attention_mask,
+                        ref_image_hidden_states=None,
+                    )
+
+                    model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \
+                        text_guidance_scale * (model_pred - model_pred_ref)
+                elif guidance_method == "ICG":
+                    # Independent Condition Guidance reuses the reference-conditioned
+                    # prediction instead of the unconditional one, skipping that third
+                    # forward pass; the text guidance scale is annealed over the schedule.
+                    snr = (i + 1e-5) / (self._num_timesteps - i + 1e-5)
+                    icg_text_scale = text_guidance_scale * snr / (1 + snr)
+                    model_pred = model_pred_ref + icg_text_scale * (model_pred - model_pred_ref)
+
                 elif text_guidance_scale > 1.0:
                     if enable_taylorseer:
                         self.transformer.cache_dic = model_pred_uncond_cache_dic
diff --git a/requirements.txt b/requirements.txt
index 902ae9b..a1d0ae4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,4 +16,5 @@ python-dotenv
 ninja
 ipykernel
 wheel
+gradio
 triton-windows; sys_platform == "win32"
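
For reviewers, a minimal standalone sketch of the guidance arithmetic this patch adds to processing(), assuming pred_cond, pred_ref, and pred_uncond are same-shaped model predictions; the function and argument names below are illustrative, not part of the patch:

    from typing import Optional
    import torch

    def combine_guidance(pred_cond: torch.Tensor,
                         pred_ref: torch.Tensor,
                         pred_uncond: Optional[torch.Tensor],
                         step: int, num_steps: int,
                         text_scale: float = 5.0,
                         image_scale: float = 2.0,
                         method: str = "ICG") -> torch.Tensor:
        if method == "CFG":
            # Classifier-free guidance: three forward passes per step
            # (unconditional, reference-conditioned, fully conditioned).
            return pred_uncond + image_scale * (pred_ref - pred_uncond) \
                + text_scale * (pred_cond - pred_ref)
        # ICG: take the reference-conditioned prediction as the baseline and
        # anneal the text scale over the schedule, so the unconditional pass
        # (one of three) is skipped -- the source of the claimed ~30% speedup.
        snr = (step + 1e-5) / (num_steps - step + 1e-5)
        icg_scale = text_scale * snr / (1 + snr)
        return pred_ref + icg_scale * (pred_cond - pred_ref)

With method="ICG" this mirrors the update in the text-plus-image guidance branch above: ICG trades the learned unconditional baseline for the reference-conditioned one, which is why no extra forward pass is needed.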