Remove logger warnings for attention backends and hard error during runtime instead #11967

Merged
merged 5 commits on Jul 22, 2025
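In short, this PR drops the import-time logger warnings for missing optional attention backends and instead raises a hard RuntimeError at the point where the backend is actually requested. Below is a minimal sketch of how the new behavior surfaces; it is not part of the diff, assumes `flash-attn` is not installed, and uses the backend name "flash" from the AttentionBackendName enum shown in the changes:

from diffusers.models.attention_dispatch import attention_backend

try:
    # With this change the context manager validates requirements up front and
    # raises RuntimeError if flash-attn >= 2.6.3 is missing; previously the only
    # signal was a logger.warning emitted once when attention_dispatch was imported.
    with attention_backend("flash"):
        ...  # run the model or call dispatch_attention_fn here
except RuntimeError as err:
    print(err)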
99 changes: 81 additions & 18 deletions src/diffusers/models/attention_dispatch.py
@@ -38,26 +38,37 @@
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


logger = get_logger(__name__) # pylint: disable=invalid-name


if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"

_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)


if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
else:
logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
flash_attn_func = None
flash_attn_varlen_func = None


if is_flash_attn_3_available():
if _CAN_USE_FLASH_ATTN_3:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None


if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
@@ -67,9 +78,6 @@
sageattn_varlen,
)
else:
logger.warning(
"`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
)
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
@@ -78,39 +86,39 @@
sageattn_varlen = None


if is_torch_version(">=", "2.5.0"):
if _CAN_USE_FLEX_ATTN:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention


if is_torch_npu_available():
if _CAN_USE_NPU_ATTN:
from torch_npu import npu_fusion_attention
else:
npu_fusion_attention = None


if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
if _CAN_USE_XLA_ATTN:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
else:
xla_flash_attention = None


if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
if _CAN_USE_XFORMERS_ATTN:
import xformers.ops as xops
else:
logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
xops = None


logger = get_logger(__name__) # pylint: disable=invalid-name

# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
# - CP with sage attention, flex, xformers, other missing backends
# - Add support for normal and CP training with backends that don't support it yet


_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
@@ -179,13 +187,16 @@ def list_backends(cls):


@contextlib.contextmanager
def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
Context manager to set the active attention backend.
"""
if backend not in _AttentionBackendRegistry._backends:
raise ValueError(f"Backend {backend} is not registered.")

backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)

old_backend = _AttentionBackendRegistry._active_backend
_AttentionBackendRegistry._active_backend = backend

@@ -226,9 +237,10 @@ def dispatch_attention_fn(
"dropout_p": dropout_p,
"is_causal": is_causal,
"scale": scale,
"enable_gqa": enable_gqa,
**attention_kwargs,
}
if is_torch_version(">=", "2.5.0"):
kwargs["enable_gqa"] = enable_gqa

if _AttentionBackendRegistry._checks_enabled:
removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
@@ -305,6 +317,57 @@ def _check_shape(
# ===== Helper functions =====


def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
if not _CAN_USE_FLASH_ATTN:
raise RuntimeError(
f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
)

elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
if not _CAN_USE_FLASH_ATTN_3:
raise RuntimeError(
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)

elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
]:
if not _CAN_USE_SAGE_ATTN:
raise RuntimeError(
f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
)

elif backend == AttentionBackendName.FLEX:
if not _CAN_USE_FLEX_ATTN:
raise RuntimeError(
f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
)

elif backend == AttentionBackendName._NATIVE_NPU:
if not _CAN_USE_NPU_ATTN:
raise RuntimeError(
f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
)

elif backend == AttentionBackendName._NATIVE_XLA:
if not _CAN_USE_XLA_ATTN:
raise RuntimeError(
f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
)

elif backend == AttentionBackendName.XFORMERS:
if not _CAN_USE_XFORMERS_ATTN:
raise RuntimeError(
f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
)


@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
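Because the new _check_attention_backend_requirements helper raises RuntimeError for an unusable backend and simply returns otherwise, it can also serve as a quick environment probe. A hedged diagnostic sketch follows; it is illustration only, since the helper is private and not a supported API:

from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _check_attention_backend_requirements,
)

# Report which registered backend names the checker accepts or rejects in the
# current environment.
for name in AttentionBackendName:
    try:
        _check_attention_backend_requirements(name)
        print(f"{name.value}: usable")
    except RuntimeError as err:
        print(f"{name.value}: {err}")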
10 changes: 7 additions & 3 deletions src/diffusers/models/modeling_utils.py
@@ -622,19 +622,21 @@ def set_attention_backend(self, backend: str) -> None:
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements

# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention

logger.warning("Attention backends are an experimental feature and the API may be subject to change.")

backend = backend.lower()
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))

backend = AttentionBackendName(backend)
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
_check_attention_backend_requirements(backend)

attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
@@ -651,6 +653,8 @@ def reset_attention_backend(self) -> None:
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention

logger.warning("Attention backends are an experimental feature and the API may be subject to change.")

attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
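On the model side, set_attention_backend now runs the same requirement check before swapping attention processors, so a misconfigured backend fails at setup time rather than mid-inference. A sketch using a tiny randomly initialized UNet; the config values are arbitrary and exist only to create attention modules:

from diffusers import UNet2DConditionModel

# Small UNet with cross-attention blocks so set_attention_backend has modules to touch.
model = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=32,
)

try:
    # Raises RuntimeError if sageattention >= 2.1.1 is not installed; the method
    # also warns that attention backends are an experimental API.
    model.set_attention_backend("sage")
except RuntimeError as err:
    print(err)

model.reset_attention_backend()  # restore the default attention processors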