cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
5
|
+
from cache_dit.parallelism.parallel_backend import ParallelismBackend
|
|
6
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
7
|
+
from cache_dit.logger import init_logger
|
|
8
|
+
from ..utils import (
|
|
9
|
+
native_diffusers_parallelism_available,
|
|
10
|
+
ContextParallelConfig,
|
|
11
|
+
)
|
|
12
|
+
from .attention import maybe_resigter_native_attention_backend
|
|
13
|
+
from .cp_planners import *
|
|
14
|
+
|
|
15
|
+
# Re-register the native attention backend for context parallelism at import
# time. The previous `except ImportError as e: raise ImportError(e)` wrapper
# only rebuilt the exception (losing its concrete type, e.g.
# ModuleNotFoundError, and its attributes); an ImportError raised here now
# propagates unchanged with its original traceback.
maybe_resigter_native_attention_backend()

logger = init_logger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def maybe_enable_context_parallelism(
    transformer: torch.nn.Module,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Enable diffusers-native context parallelism on ``transformer`` if configured.

    Context parallelism is applied only when ``parallelism_config`` selects the
    ``NATIVE_DIFFUSER`` backend, native diffusers parallelism is available, and
    at least one of ``ulysses_size`` / ``ring_size`` is set. In every other
    case the transformer is returned untouched.

    Recognized ``parallelism_config.parallel_kwargs`` entries:
      - ``"attention_backend"``: passed to ``transformer.set_attention_backend``
        (when the transformer has that method). Defaults to ``"_native_cudnn"``.
      - ``"cp_plan"``: a user-provided context-parallel plan. When absent, a
        plan is obtained from ``ContextParallelismPlannerRegister`` and built
        by the planner's ``apply`` with all ``parallel_kwargs`` forwarded.

    Args:
        transformer: A diffusers ``ModelMixin`` model.
        parallelism_config: Parallelism settings, or ``None`` to do nothing.

    Returns:
        The same ``transformer`` object, modified in place when context
        parallelism is enabled.

    Raises:
        AssertionError: If ``transformer`` is not a ``ModelMixin`` or
            ``parallelism_config`` is not a ``ParallelismConfig``.
        ValueError: If context parallelism is requested but the transformer
            has no ``enable_parallelism`` method.
    """
    assert isinstance(transformer, ModelMixin), (
        "transformer must be an instance of diffusers' ModelMixin, "
        f"but got {type(transformer)}"
    )
    if parallelism_config is None:
        return transformer

    assert isinstance(parallelism_config, ParallelismConfig), (
        "parallelism_config must be an instance of ParallelismConfig"
        f" but got {type(parallelism_config)}"
    )

    if parallelism_config.backend != ParallelismBackend.NATIVE_DIFFUSER:
        return transformer

    if not native_diffusers_parallelism_available():
        # Previously a silent no-op; surface it so users know CP is off.
        logger.warning(
            "Native diffusers parallelism is not available in the installed "
            "diffusers, context parallelism will not be enabled."
        )
        return transformer

    if (
        parallelism_config.ulysses_size is None
        and parallelism_config.ring_size is None
    ):
        # Neither a Ulysses nor a ring degree was requested.
        return transformer

    cp_config = ContextParallelConfig(
        ulysses_degree=parallelism_config.ulysses_size,
        ring_degree=parallelism_config.ring_size,
    )

    # `parallel_kwargs` may be None; normalize it once so that every `.get`
    # below is safe (calling `.get` on None raised AttributeError before).
    parallel_kwargs = parallelism_config.parallel_kwargs or {}
    attention_backend = parallel_kwargs.get("attention_backend", None)

    if not hasattr(transformer, "enable_parallelism"):
        raise ValueError(
            f"{transformer.__class__.__name__} does not support context parallelism."
        )

    if hasattr(transformer, "set_attention_backend"):
        # native, _native_cudnn, flash, etc.
        if attention_backend is None:
            # Default to _native_cudnn for parallelism, see:
            # https://github.com/huggingface/diffusers/pull/12443
            transformer.set_attention_backend("_native_cudnn")
            logger.warning(
                "attention_backend is None, set default attention backend "
                "to _native_cudnn for parallelism because of the issue: "
                "https://github.com/huggingface/diffusers/pull/12443"
            )
        else:
            transformer.set_attention_backend(attention_backend)
            logger.info(
                "Found attention_backend from config, set attention "
                f"backend to: {attention_backend}"
            )

    # Prefer a user-provided cp_plan; otherwise ask the registered planner.
    cp_plan = parallel_kwargs.get("cp_plan", None)
    if cp_plan is not None:
        logger.info(f"Using custom context parallelism plan: {cp_plan}")
    else:
        planner_cls = ContextParallelismPlannerRegister.get_planner(transformer)
        cp_plan = planner_cls().apply(transformer=transformer, **parallel_kwargs)

    transformer.enable_parallelism(config=cp_config, cp_plan=cp_plan)
    _maybe_patch_native_parallel_config(transformer)

    return transformer
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _maybe_patch_native_parallel_config(
    transformer: torch.nn.Module,
) -> torch.nn.Module:
    """Copy ``transformer._parallel_config`` onto Nunchaku attention processors.

    Transformers whose class name does not start with ``"Nunchaku"`` are
    returned unchanged without importing anything. For Nunchaku Flux (V2) and
    QwenImage transformers, every attention module's ``processor`` that exposes
    a ``_parallel_config`` attribute gets that attribute overwritten with the
    transformer's ``_parallel_config``.

    NOTE(review): presumably needed because ``enable_parallelism`` does not
    propagate the config into Nunchaku's custom processors. TODO confirm.

    Args:
        transformer: The model context parallelism was just enabled on.

    Returns:
        The same ``transformer`` object.

    Raises:
        ImportError: If the ``nunchaku`` package (or one of the required
            classes) cannot be imported. The original error is chained.
        AssertionError: If a Nunchaku-named transformer is not a supported
            Flux / QwenImage transformer type.
    """
    if not transformer.__class__.__name__.startswith("Nunchaku"):
        return transformer

    from diffusers import FluxTransformer2DModel, QwenImageTransformer2DModel

    try:
        from nunchaku.models.transformers.transformer_flux_v2 import (
            NunchakuFluxTransformer2DModelV2,
            NunchakuFluxAttention,
            NunchakuFluxFA2Processor,
        )
        from nunchaku.models.transformers.transformer_qwenimage import (
            NunchakuQwenAttention,
            NunchakuQwenImageNaiveFA2Processor,
            NunchakuQwenImageTransformer2DModel,
        )
    except ImportError as err:
        # Chain the original error so the missing module/name stays visible.
        raise ImportError(
            "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
            "requires the 'nunchaku' package. Please install nunchaku before using "
            "the context parallelism for nunchaku 4-bits models."
        ) from err

    supported_models = (
        NunchakuFluxTransformer2DModelV2,
        FluxTransformer2DModel,
        NunchakuQwenImageTransformer2DModel,
        QwenImageTransformer2DModel,
    )
    assert isinstance(transformer, supported_models), (
        "transformer must be an instance of NunchakuFluxTransformer2DModelV2 "
        f"or NunchakuQwenImageTransformer2DModel, but got {type(transformer)}"
    )
    config = transformer._parallel_config

    attention_classes = (
        NunchakuFluxAttention,
        NunchakuFluxFA2Processor,
        NunchakuQwenAttention,
        NunchakuQwenImageNaiveFA2Processor,
    )
    for module in transformer.modules():
        if not isinstance(module, attention_classes):
            continue
        processor = getattr(module, "processor", None)
        # Only overwrite processors that already carry a parallel config slot.
        if processor is None or not hasattr(processor, "_parallel_config"):
            continue
        processor._parallel_config = config

    return transformer
|
cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
from diffusers.models.attention_dispatch import (
|
|
7
|
+
_AttentionBackendRegistry,
|
|
8
|
+
AttentionBackendName,
|
|
9
|
+
_check_device,
|
|
10
|
+
_check_shape,
|
|
11
|
+
TemplatedRingAttention,
|
|
12
|
+
TemplatedUlyssesAttention,
|
|
13
|
+
)
|
|
14
|
+
from diffusers.models._modeling_parallel import ParallelConfig
|
|
15
|
+
except ImportError:
|
|
16
|
+
raise ImportError(
|
|
17
|
+
"Context parallelism requires the 'diffusers>=0.36.dev0'."
|
|
18
|
+
"Please install latest version of diffusers from source: \n"
|
|
19
|
+
"pip3 install git+https://github.com/huggingface/diffusers.git"
|
|
20
|
+
)
|
|
21
|
+
from cache_dit.logger import init_logger
|
|
22
|
+
|
|
23
|
+
# Module-level logger shared by everything in this file.
logger = init_logger(__name__)

# The only name exported via `from ... import *`.
__all__ = ["_native_attention"]
|
|
29
|
+
|
|
30
|
+
def _parse_env_flag(name: str, default: str) -> bool:
    """Read environment variable ``name`` (falling back to ``default``) as a bool.

    Integer strings keep their historical meaning: ``"0"`` is False and any
    other integer is True. The words true/yes/on and false/no/off are also
    accepted, case-insensitively. Any other value raises ``ValueError`` with a
    message naming the variable, rather than the bare ``int()`` error that was
    raised before.
    """
    raw = os.getenv(name, default).strip()
    try:
        return bool(int(raw))
    except ValueError:
        pass
    lowered = raw.lower()
    if lowered in ("true", "yes", "on"):
        return True
    if lowered in ("false", "no", "off"):
        return False
    raise ValueError(
        f"Invalid value {raw!r} for environment variable {name}; expected an "
        "integer or one of true/false, yes/no, on/off."
    )


# Enable the custom native attention backend with context parallelism by
# default. Users can set the environment variable to 0 (or false/no/off) to
# disable this behavior. It defaults to enabled for better compatibility.
_CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH = _parse_env_flag(
    "CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH", "1"
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _is_native_attn_supported_context_parallel() -> bool:
    """Tell whether diffusers already flags NATIVE as context-parallel capable.

    The registry's ``_supports_context_parallel`` is first treated as a mapping
    keyed by ``AttentionBackendName``. If that lookup raises (for example,
    subscripting a set), it must instead be a set of backend-name strings, and
    membership of ``NATIVE.value`` is returned.
    """
    native = AttentionBackendName.NATIVE
    try:
        cp_registry = _AttentionBackendRegistry._supports_context_parallel
        return native in cp_registry and cp_registry[native]
    except Exception:
        cp_registry = _AttentionBackendRegistry._supports_context_parallel
        assert isinstance(cp_registry, set)
        return native.value in cp_registry
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
    logger.warning(
        "Re-registering NATIVE attention backend to enable context parallelism. "
        "This is a temporary workaround and should be removed after the native "
        "attention backend supports context parallelism natively. Please check: "
        "https://github.com/huggingface/diffusers/pull/12563 for more details. "
        "Or, you can disable this behavior by setting the environment variable "
        "`CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0`."
    )
    # Drop every registry entry for NATIVE so the decorator on
    # `_native_attention` below can register the replacement from scratch.
    _AttentionBackendRegistry._backends.pop(AttentionBackendName.NATIVE)
    _AttentionBackendRegistry._constraints.pop(AttentionBackendName.NATIVE)
    _AttentionBackendRegistry._supported_arg_names.pop(
        AttentionBackendName.NATIVE
    )
    if _is_native_attn_supported_context_parallel():
        # `_supports_context_parallel` is either a dict keyed by backend or a
        # set of backend-name strings, depending on the installed diffusers.
        if isinstance(
            _AttentionBackendRegistry._supports_context_parallel, dict
        ):
            _AttentionBackendRegistry._supports_context_parallel.pop(
                AttentionBackendName.NATIVE
            )
        else:
            _AttentionBackendRegistry._supports_context_parallel.remove(
                AttentionBackendName.NATIVE.value
            )

    # Re-define templated context parallel attention to support attn mask
    def _templated_context_parallel_attention_v2(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        *,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        """Dispatch attention to diffusers' ring or Ulysses CP implementation.

        An ``attn_mask`` is accepted only when ``forward_op`` is
        ``_native_attention_forward_op``. Ring attention is used when
        ``ring_degree > 1``; otherwise Ulysses attention is used when
        ``ulysses_degree > 1``.

        Raises:
            ValueError: For an ``attn_mask`` with a non-native forward op,
                causal attention, GQA, or when neither degree is > 1.
        """
        if attn_mask is not None:
            # NOTE(DefTruth): Check if forward_op is native attention forward op
            forward_op_name = forward_op.__name__
            if not forward_op_name == "_native_attention_forward_op":
                raise ValueError(
                    "Templated context parallel attention with attn_mask "
                    "is only supported for native attention backend, "
                    f"but got forward_op: {forward_op_name}."
                )
        if is_causal:
            raise ValueError(
                "Causal attention is not yet supported for templated attention."
            )
        if enable_gqa:
            raise ValueError(
                "GQA is not yet supported for templated attention."
            )

        # TODO: add support for unified attention with ring/ulysses degree both being > 1
        if _parallel_config.context_parallel_config.ring_degree > 1:
            return TemplatedRingAttention.apply(
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                return_lse,
                forward_op,
                backward_op,
                _parallel_config,
            )
        elif _parallel_config.context_parallel_config.ulysses_degree > 1:
            return TemplatedUlyssesAttention.apply(
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                return_lse,
                forward_op,
                backward_op,
                _parallel_config,
            )
        else:
            raise ValueError(
                "Reaching this branch of code is unexpected. Please report a bug."
            )

    # NOTE: Remove NATIVE attention backend constraints and re-register it.
    # This is a temporary workaround to enable context parallelism with the
    # native attention backend. It should be removed once the native
    # attention backend supports context parallelism natively.
    # Adapted from: https://github.com/huggingface/diffusers/pull/12563

    def _native_attention_forward_op(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        _save_ctx: bool = True,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        """Forward op: SDPA on inputs whose dims 1 and 2 are swapped vs. SDPA.

        The inputs are permuted with ``(0, 2, 1, 3)`` before SDPA and the
        output is permuted back, so it comes out in the inputs' layout. This
        is presumably (batch, seq, heads, head_dim); TODO confirm.
        ``_parallel_config`` is accepted but unused here.

        Raises:
            ValueError: If ``return_lse`` is True.
        """
        # Native attention does not return_lse
        if return_lse:
            raise ValueError(
                "Native attention does not support return_lse=True"
            )

        # Saved for the recomputation in `_native_attention_backward_op`.
        if _save_ctx:
            ctx.save_for_backward(query, key, value)
            ctx.attn_mask = attn_mask
            ctx.dropout_p = dropout_p
            ctx.is_causal = is_causal
            ctx.scale = scale
            ctx.enable_gqa = enable_gqa

        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
        out = out.permute(0, 2, 1, 3)

        return out

    def _native_attention_backward_op(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Backward of `_native_attention_forward_op` by recomputation.

        Re-runs SDPA on the query/key/value saved by the forward op and
        differentiates it with ``torch.autograd.grad``.

        Returns:
            Gradients w.r.t. query, key and value, in the same layout as the
            forward op's inputs.
        """
        query, key, value = ctx.saved_tensors

        # Autograd runs backward with grad mode disabled (unless
        # create_graph=True), so the recomputation must re-enable it.
        # Otherwise the recomputed output has no grad_fn and autograd.grad
        # fails. Detached copies give fresh leaves and avoid mutating the
        # saved tensors.
        with torch.enable_grad():
            query, key, value = (
                x.detach().requires_grad_(True) for x in (query, key, value)
            )
            query_t, key_t, value_t = (
                x.permute(0, 2, 1, 3) for x in (query, key, value)
            )
            # NOTE(review): with dropout_p > 0 this recomputation draws a new
            # dropout mask, so gradients will not match the forward pass.
            out_t = torch.nn.functional.scaled_dot_product_attention(
                query=query_t,
                key=key_t,
                value=value_t,
                attn_mask=ctx.attn_mask,
                dropout_p=ctx.dropout_p,
                is_causal=ctx.is_causal,
                scale=ctx.scale,
                enable_gqa=ctx.enable_gqa,
            )

        # `out_t` is still in SDPA layout, so the incoming gradient (in the
        # forward op's output layout) is permuted the same way. Previously
        # `out` was permuted back while `grad_out` was permuted forward,
        # which gave autograd.grad mismatched shapes.
        grad_out_t = grad_out.permute(0, 2, 1, 3)
        grad_query, grad_key, grad_value = torch.autograd.grad(
            outputs=out_t,
            inputs=[query, key, value],
            grad_outputs=grad_out_t,
            retain_graph=False,
        )

        return grad_query, grad_key, grad_value

    @_AttentionBackendRegistry.register(
        AttentionBackendName.NATIVE,
        constraints=[_check_device, _check_shape],
        supports_context_parallel=True,
    )
    def _native_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        _parallel_config: Optional["ParallelConfig"] = None,
    ) -> torch.Tensor:
        """NATIVE attention backend with optional context parallelism.

        Without ``_parallel_config``, SDPA runs directly (with the same
        permute-in/permute-out as the forward op). With it, the call is routed
        through `_templated_context_parallel_attention_v2` using the native
        forward/backward ops.

        Raises:
            ValueError: If ``return_lse`` is True.
        """
        if return_lse:
            raise ValueError(
                "Native attention backend does not support setting `return_lse=True`."
            )
        if _parallel_config is None:
            query, key, value = (
                x.permute(0, 2, 1, 3) for x in (query, key, value)
            )
            out = torch.nn.functional.scaled_dot_product_attention(
                query=query,
                key=key,
                value=value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
                enable_gqa=enable_gqa,
            )
            out = out.permute(0, 2, 1, 3)
        else:
            out = _templated_context_parallel_attention_v2(
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                return_lse,
                forward_op=_native_attention_forward_op,
                backward_op=_native_attention_backward_op,
                _parallel_config=_parallel_config,
            )
        return out

else:
    # Workaround disabled: re-export diffusers' own NATIVE implementation.
    from diffusers.models.attention_dispatch import (  # noqa: F401
        _native_attention,
    )

    logger.info(
        "Native attention backend already supports context parallelism."
    )
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
4
|
+
from diffusers.models.transformers.transformer_chroma import (
|
|
5
|
+
ChromaTransformer2DModel,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from diffusers.models._modeling_parallel import (
|
|
10
|
+
ContextParallelInput,
|
|
11
|
+
ContextParallelOutput,
|
|
12
|
+
ContextParallelModelPlan,
|
|
13
|
+
)
|
|
14
|
+
except ImportError:
|
|
15
|
+
raise ImportError(
|
|
16
|
+
"Context parallelism requires the 'diffusers>=0.36.dev0'."
|
|
17
|
+
"Please install latest version of diffusers from source: \n"
|
|
18
|
+
"pip3 install git+https://github.com/huggingface/diffusers.git"
|
|
19
|
+
)
|
|
20
|
+
from .cp_plan_registers import (
|
|
21
|
+
ContextParallelismPlanner,
|
|
22
|
+
ContextParallelismPlannerRegister,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from cache_dit.logger import init_logger
|
|
26
|
+
|
|
27
|
+
logger = init_logger(__name__)


@ContextParallelismPlannerRegister.register("Chroma")
class ChromaContextParallelismPlanner(ContextParallelismPlanner):
    """Context-parallelism planner registered under the ``"Chroma"`` key.

    It builds a transformer-level plan. The listed inputs are split once when
    the transformer's forward is entered, and ``proj_out`` gathers the output.
    """

    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        """Return the context-parallel plan for a Chroma transformer.

        Args:
            transformer: The Chroma transformer. It is only inspected when the
                native-diffusers plan is preferred, which is currently never
                the case.
            **kwargs: Accepted for planner-interface compatibility and ignored.

        Returns:
            A ``ContextParallelModelPlan`` dict.
        """
        # Diffusers does not yet provide a native CP plan for Chroma, so the
        # native-plan branch below is disabled and the custom plan is used.
        self._cp_planner_preferred_native_diffusers = False

        use_native_plan = (
            transformer is not None
            and self._cp_planner_preferred_native_diffusers
        )
        if use_native_plan:
            assert isinstance(
                transformer, ChromaTransformer2DModel
            ), "Transformer must be an instance of ChromaTransformer2DModel"
            native_plan = getattr(transformer, "_cp_plan", None)
            if native_plan is not None:
                return native_plan

        def _split_once(dim: int, ndim: int) -> ContextParallelInput:
            # split_output=False: the input is split before the transformer's
            # forward runs (a pre-forward hook), not after it.
            return ContextParallelInput(
                split_dim=dim, expected_dims=ndim, split_output=False
            )

        # Per the original author's notes: inside each attention block the
        # already-split activations go through to_qkv -> all2all -> local-head
        # / full-seqlen attention -> all2all, so they stay split between
        # blocks. img_ids / txt_ids are split once here and never communicated.
        entry_splits = {
            "hidden_states": _split_once(1, 3),
            "encoder_hidden_states": _split_once(1, 3),
            "img_ids": _split_once(0, 2),
            "txt_ids": _split_once(0, 2),
        }
        # The final projection all-gathers the split sequence back together.
        exit_gather = ContextParallelOutput(gather_dim=1, expected_dims=3)

        return {
            "": entry_splits,
            "proj_out": exit_gather,
        }
|