cache-dit 1.0.3-py3-none-any.whl → 1.0.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
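Note on the layout change: `cache_dit/cache_factory/` moves to `cache_dit/caching/` and `cache_dit/custom_ops/` moves to `cache_dit/kernels/`, so downstream code importing the old paths needs updating. A minimal migration sketch, assuming the submodule names inside the package are otherwise unchanged (the re-exported symbols themselves are not visible in this diff):

```python
# Before (cache-dit 1.0.3) -- old package layout:
#   from cache_dit.cache_factory import cache_interface
# After (cache-dit 1.0.14) -- renamed layout, same submodule name assumed:
from cache_dit.caching import cache_interface  # hypothetical import target
```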
cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py (new file)
@@ -0,0 +1,164 @@
+ import torch
+ from typing import Optional
+
+ from diffusers.models.modeling_utils import ModelMixin
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from cache_dit.logger import init_logger
+ from ..utils import (
+     native_diffusers_parallelism_available,
+     ContextParallelConfig,
+ )
+ from .attention import maybe_resigter_native_attention_backend
+ from .cp_planners import *
+
+ try:
+     maybe_resigter_native_attention_backend()
+ except ImportError as e:
+     raise ImportError(e)
+
+ logger = init_logger(__name__)
+
+
+ def maybe_enable_context_parallelism(
+     transformer: torch.nn.Module,
+     parallelism_config: Optional[ParallelismConfig],
+ ) -> torch.nn.Module:
+     assert isinstance(transformer, ModelMixin), (
+         "transformer must be an instance of diffusers' ModelMixin, "
+         f"but got {type(transformer)}"
+     )
+     if parallelism_config is None:
+         return transformer
+
+     assert isinstance(parallelism_config, ParallelismConfig), (
+         "parallelism_config must be an instance of ParallelismConfig"
+         f" but got {type(parallelism_config)}"
+     )
+
+     if (
+         parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
+         and native_diffusers_parallelism_available()
+     ):
+         cp_config = None
+         if (
+             parallelism_config.ulysses_size is not None
+             or parallelism_config.ring_size is not None
+         ):
+             cp_config = ContextParallelConfig(
+                 ulysses_degree=parallelism_config.ulysses_size,
+                 ring_degree=parallelism_config.ring_size,
+             )
+         if cp_config is not None:
+             attention_backend = parallelism_config.parallel_kwargs.get(
+                 "attention_backend", None
+             )
+             if hasattr(transformer, "enable_parallelism"):
+                 if hasattr(transformer, "set_attention_backend"):
+                     # native, _native_cudnn, flash, etc.
+                     if attention_backend is None:
+                         # Now only _native_cudnn is supported for parallelism
+                         # issue: https://github.com/huggingface/diffusers/pull/12443
+                         transformer.set_attention_backend("_native_cudnn")
+                         logger.warning(
+                             "attention_backend is None, set default attention backend "
+                             "to _native_cudnn for parallelism because of the issue: "
+                             "https://github.com/huggingface/diffusers/pull/12443"
+                         )
+                     else:
+                         transformer.set_attention_backend(attention_backend)
+                         logger.info(
+                             "Found attention_backend from config, set attention "
+                             f"backend to: {attention_backend}"
+                         )
+                 # Prefer custom cp_plan if provided
+                 cp_plan = parallelism_config.parallel_kwargs.get(
+                     "cp_plan", None
+                 )
+                 if cp_plan is not None:
+                     logger.info(
+                         f"Using custom context parallelism plan: {cp_plan}"
+                     )
+                 else:
+                     # Try get context parallelism plan from register if not provided
+                     extra_parallel_kwargs = {}
+                     if parallelism_config.parallel_kwargs is not None:
+                         extra_parallel_kwargs = (
+                             parallelism_config.parallel_kwargs
+                         )
+                     cp_plan = ContextParallelismPlannerRegister.get_planner(
+                         transformer
+                     )().apply(transformer=transformer, **extra_parallel_kwargs)
+
+                 transformer.enable_parallelism(
+                     config=cp_config, cp_plan=cp_plan
+                 )
+                 _maybe_patch_native_parallel_config(transformer)
+             else:
+                 raise ValueError(
+                     f"{transformer.__class__.__name__} does not support context parallelism."
+                 )
+
+     return transformer
+
+
+ def _maybe_patch_native_parallel_config(
+     transformer: torch.nn.Module,
+ ) -> torch.nn.Module:
+
+     cls_name = transformer.__class__.__name__
+     if not cls_name.startswith("Nunchaku"):
+         return transformer
+
+     from diffusers import FluxTransformer2DModel, QwenImageTransformer2DModel
+
+     try:
+         from nunchaku.models.transformers.transformer_flux_v2 import (
+             NunchakuFluxTransformer2DModelV2,
+             NunchakuFluxAttention,
+             NunchakuFluxFA2Processor,
+         )
+         from nunchaku.models.transformers.transformer_qwenimage import (
+             NunchakuQwenAttention,
+             NunchakuQwenImageNaiveFA2Processor,
+             NunchakuQwenImageTransformer2DModel,
+         )
+     except ImportError:
+         raise ImportError(
+             "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
+             "requires the 'nunchaku' package. Please install nunchaku before using "
+             "the context parallelism for nunchaku 4-bits models."
+         )
+     assert isinstance(
+         transformer,
+         (
+             NunchakuFluxTransformer2DModelV2,
+             FluxTransformer2DModel,
+         ),
+     ) or isinstance(
+         transformer,
+         (
+             NunchakuQwenImageTransformer2DModel,
+             QwenImageTransformer2DModel,
+         ),
+     ), (
+         "transformer must be an instance of NunchakuFluxTransformer2DModelV2 "
+         f"or NunchakuQwenImageTransformer2DModel, but got {type(transformer)}"
+     )
+     config = transformer._parallel_config
+
+     attention_classes = (
+         NunchakuFluxAttention,
+         NunchakuFluxFA2Processor,
+         NunchakuQwenAttention,
+         NunchakuQwenImageNaiveFA2Processor,
+     )
+     for module in transformer.modules():
+         if not isinstance(module, attention_classes):
+             continue
+         processor = getattr(module, "processor", None)
+         if processor is None or not hasattr(processor, "_parallel_config"):
+             continue
+         processor._parallel_config = config
+
+     return transformer
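For context, a rough usage sketch of the helper above. The `ParallelismConfig` constructor keywords are inferred from the attributes this module reads (`backend`, `ulysses_size`, `ring_size`, `parallel_kwargs`) and are an assumption, not a documented signature:

```python
import torch.distributed as dist
from diffusers import FluxTransformer2DModel

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_diffusers.context_parallelism import (
    maybe_enable_context_parallelism,
)

# Assumed keyword arguments, mirroring the attributes read in this file.
# Requires an initialized torch.distributed process group.
config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    ulysses_size=dist.get_world_size(),  # all-to-all (Ulysses) degree
    ring_size=None,                      # ring-attention degree
    parallel_kwargs={"attention_backend": "_native_cudnn"},
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
# Attaches a ContextParallelConfig plus a cp_plan (custom, or resolved from
# the planner registry) and calls transformer.enable_parallelism(...).
transformer = maybe_enable_context_parallelism(transformer, config)
```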
cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py (new file)
@@ -0,0 +1,4 @@
+ def maybe_resigter_native_attention_backend():
+     """Maybe re-register native attention backend to enable context parallelism."""
+     # Import custom attention backend ensuring registration
+     from ._attention_dispatch import _native_attention
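The helper works purely through an import side effect: the actual re-registration runs at import time of `_attention_dispatch.py` (next hunk). A minimal sketch of calling it directly, mirroring the try/except guard in the `__init__.py` above:

```python
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.attention import (
    maybe_resigter_native_attention_backend,
)

# Safe to call more than once: Python caches the module, so the
# re-registration in _attention_dispatch.py only runs on first import.
maybe_resigter_native_attention_backend()
```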
cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py (new file)
@@ -0,0 +1,304 @@
+ import os
+ import torch
+ from typing import Optional
+
+ try:
+     from diffusers.models.attention_dispatch import (
+         _AttentionBackendRegistry,
+         AttentionBackendName,
+         _check_device,
+         _check_shape,
+         TemplatedRingAttention,
+         TemplatedUlyssesAttention,
+     )
+     from diffusers.models._modeling_parallel import ParallelConfig
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires the 'diffusers>=0.36.dev0'."
+         "Please install latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ __all__ = [
+     "_native_attention",
+ ]
+
+ # Enable custom native attention backend with context parallelism
+ # by default. Users can set the environment variable to 0 to disable
+ # this behavior. Default to enabled for better compatibility.
+ _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH = bool(
+     int(os.getenv("CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH", "1"))
+ )
+
+
+ def _is_native_attn_supported_context_parallel() -> bool:
+     try:
+         return (
+             AttentionBackendName.NATIVE
+             in _AttentionBackendRegistry._supports_context_parallel
+             and _AttentionBackendRegistry._supports_context_parallel[
+                 AttentionBackendName.NATIVE
+             ]
+         )
+     except Exception:
+         assert isinstance(
+             _AttentionBackendRegistry._supports_context_parallel, set
+         )
+         return (
+             AttentionBackendName.NATIVE.value
+             in _AttentionBackendRegistry._supports_context_parallel
+         )
+
+
+ if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
+     logger.warning(
+         "Re-registering NATIVE attention backend to enable context parallelism. "
+         "This is a temporary workaround and should be removed after the native "
+         "attention backend supports context parallelism natively. Please check: "
+         "https://github.com/huggingface/diffusers/pull/12563 for more details. "
+         "Or, you can disable this behavior by setting the environment variable "
+         "`CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0`."
+     )
+     _AttentionBackendRegistry._backends.pop(AttentionBackendName.NATIVE)
+     _AttentionBackendRegistry._constraints.pop(AttentionBackendName.NATIVE)
+     _AttentionBackendRegistry._supported_arg_names.pop(
+         AttentionBackendName.NATIVE
+     )
+     if _is_native_attn_supported_context_parallel():
+         if isinstance(
+             _AttentionBackendRegistry._supports_context_parallel, dict
+         ):
+             _AttentionBackendRegistry._supports_context_parallel.pop(
+                 AttentionBackendName.NATIVE
+             )
+         else:
+             _AttentionBackendRegistry._supports_context_parallel.remove(
+                 AttentionBackendName.NATIVE.value
+             )
+
+     # Re-define templated context parallel attention to support attn mask
+     def _templated_context_parallel_attention_v2(
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         dropout_p: float = 0.0,
+         is_causal: bool = False,
+         scale: Optional[float] = None,
+         enable_gqa: bool = False,
+         return_lse: bool = False,
+         *,
+         forward_op,
+         backward_op,
+         _parallel_config: Optional["ParallelConfig"] = None,
+     ):
+         if attn_mask is not None:
+             # NOTE(DefTruth): Check if forward_op is native attention forward op
+             forward_op_name = forward_op.__name__
+             if not forward_op_name == "_native_attention_forward_op":
+                 raise ValueError(
+                     "Templated context parallel attention with attn_mask "
+                     "is only supported for native attention backend, "
+                     f"but got forward_op: {forward_op_name}."
+                 )
+         if is_causal:
+             raise ValueError(
+                 "Causal attention is not yet supported for templated attention."
+             )
+         if enable_gqa:
+             raise ValueError(
+                 "GQA is not yet supported for templated attention."
+             )
+
+         # TODO: add support for unified attention with ring/ulysses degree both being > 1
+         if _parallel_config.context_parallel_config.ring_degree > 1:
+             return TemplatedRingAttention.apply(
+                 query,
+                 key,
+                 value,
+                 attn_mask,
+                 dropout_p,
+                 is_causal,
+                 scale,
+                 enable_gqa,
+                 return_lse,
+                 forward_op,
+                 backward_op,
+                 _parallel_config,
+             )
+         elif _parallel_config.context_parallel_config.ulysses_degree > 1:
+             return TemplatedUlyssesAttention.apply(
+                 query,
+                 key,
+                 value,
+                 attn_mask,
+                 dropout_p,
+                 is_causal,
+                 scale,
+                 enable_gqa,
+                 return_lse,
+                 forward_op,
+                 backward_op,
+                 _parallel_config,
+             )
+         else:
+             raise ValueError(
+                 "Reaching this branch of code is unexpected. Please report a bug."
+             )
+
+     # NOTE:Remove NATIVE attention backend constraints and re-register it.
+     # Here is a temporary workaround to enable context parallelism with
+     # native attention backend. We should remove this workaround after
+     # the native attention backend supports context parallelism natively.
+     # Adapted from: https://github.com/huggingface/diffusers/pull/12563
+
+     def _native_attention_forward_op(
+         ctx: torch.autograd.function.FunctionCtx,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         dropout_p: float = 0.0,
+         is_causal: bool = False,
+         scale: Optional[float] = None,
+         enable_gqa: bool = False,
+         return_lse: bool = False,
+         _save_ctx: bool = True,
+         _parallel_config: Optional["ParallelConfig"] = None,
+     ):
+         # Native attention does not return_lse
+         if return_lse:
+             raise ValueError(
+                 "Native attention does not support return_lse=True"
+             )
+
+         # used for backward pass
+         if _save_ctx:
+             ctx.save_for_backward(query, key, value)
+             ctx.attn_mask = attn_mask
+             ctx.dropout_p = dropout_p
+             ctx.is_causal = is_causal
+             ctx.scale = scale
+             ctx.enable_gqa = enable_gqa
+
+         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+         out = torch.nn.functional.scaled_dot_product_attention(
+             query=query,
+             key=key,
+             value=value,
+             attn_mask=attn_mask,
+             dropout_p=dropout_p,
+             is_causal=is_causal,
+             scale=scale,
+             enable_gqa=enable_gqa,
+         )
+         out = out.permute(0, 2, 1, 3)
+
+         return out
+
+     def _native_attention_backward_op(
+         ctx: torch.autograd.function.FunctionCtx,
+         grad_out: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         query, key, value = ctx.saved_tensors
+
+         query.requires_grad_(True)
+         key.requires_grad_(True)
+         value.requires_grad_(True)
+
+         query_t, key_t, value_t = (
+             x.permute(0, 2, 1, 3) for x in (query, key, value)
+         )
+         out = torch.nn.functional.scaled_dot_product_attention(
+             query=query_t,
+             key=key_t,
+             value=value_t,
+             attn_mask=ctx.attn_mask,
+             dropout_p=ctx.dropout_p,
+             is_causal=ctx.is_causal,
+             scale=ctx.scale,
+             enable_gqa=ctx.enable_gqa,
+         )
+         out = out.permute(0, 2, 1, 3)
+
+         grad_out_t = grad_out.permute(0, 2, 1, 3)
+         grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
+             outputs=out,
+             inputs=[query_t, key_t, value_t],
+             grad_outputs=grad_out_t,
+             retain_graph=False,
+         )
+
+         grad_query = grad_query_t.permute(0, 2, 1, 3)
+         grad_key = grad_key_t.permute(0, 2, 1, 3)
+         grad_value = grad_value_t.permute(0, 2, 1, 3)
+
+         return grad_query, grad_key, grad_value
+
+     @_AttentionBackendRegistry.register(
+         AttentionBackendName.NATIVE,
+         constraints=[_check_device, _check_shape],
+         supports_context_parallel=True,
+     )
+     def _native_attention(
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         dropout_p: float = 0.0,
+         is_causal: bool = False,
+         scale: Optional[float] = None,
+         enable_gqa: bool = False,
+         return_lse: bool = False,
+         _parallel_config: Optional["ParallelConfig"] = None,
+     ) -> torch.Tensor:
+         if return_lse:
+             raise ValueError(
+                 "Native attention backend does not support setting `return_lse=True`."
+             )
+         if _parallel_config is None:
+             query, key, value = (
+                 x.permute(0, 2, 1, 3) for x in (query, key, value)
+             )
+             out = torch.nn.functional.scaled_dot_product_attention(
+                 query=query,
+                 key=key,
+                 value=value,
+                 attn_mask=attn_mask,
+                 dropout_p=dropout_p,
+                 is_causal=is_causal,
+                 scale=scale,
+                 enable_gqa=enable_gqa,
+             )
+             out = out.permute(0, 2, 1, 3)
+         else:
+             out = _templated_context_parallel_attention_v2(
+                 query,
+                 key,
+                 value,
+                 attn_mask,
+                 dropout_p,
+                 is_causal,
+                 scale,
+                 enable_gqa,
+                 return_lse,
+                 forward_op=_native_attention_forward_op,
+                 backward_op=_native_attention_backward_op,
+                 _parallel_config=_parallel_config,
+             )
+         return out
+
+ else:
+     from diffusers.models.attention_dispatch import (
+         _native_attention,
+     )  # noqa: F401
+
+     logger.info(
+         "Native attention backend already supports context parallelism."
+     )
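A small sanity-check sketch of the re-registered backend above, run without a parallel config so it reduces to plain SDPA. The `(batch, seq_len, heads, head_dim)` layout is inferred from the `permute(0, 2, 1, 3)` calls in this file:

```python
import torch
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.attention._attention_dispatch import (
    _native_attention,
)

B, S, H, D = 2, 128, 8, 64
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
v = torch.randn(B, S, H, D)
mask = torch.ones(B, H, S, S, dtype=torch.bool)  # broadcastable SDPA mask

# With _parallel_config=None this is scaled_dot_product_attention with a
# layout permute on either side; attn_mask is accepted, which is the point
# of the workaround above.
out = _native_attention(q, k, v, attn_mask=mask, _parallel_config=None)
assert out.shape == (B, S, H, D)
```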
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py (new file)
@@ -0,0 +1,95 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.transformer_chroma import (
+     ChromaTransformer2DModel,
+ )
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires the 'diffusers>=0.36.dev0'."
+         "Please install latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("Chroma")
+ class ChromaContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: Diffusers native CP plan still not supported
+         # for Chroma now.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, ChromaTransformer2DModel
+             ), "Transformer must be an instance of ChromaTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Otherwise, use the custom CP plan defined here, this maybe
+         # a little different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Here is a Transformer level CP plan for Chroma, which will
+             # only apply the only 1 split hook (pre_forward) on the forward
+             # of Transformer, and gather the output after Transformer forward.
+             # Pattern of transformer forward, split_output=False:
+             # un-split input -> splited input (inside transformer)
+             # Pattern of the transformer_blocks, single_transformer_blocks:
+             # splited input (previous splited output) -> to_qkv/...
+             # -> all2all
+             # -> attn (local head, full seqlen)
+             # -> all2all
+             # -> splited output
+             # The `hidden_states` and `encoder_hidden_states` will still keep
+             # itself splited after block forward (namely, automatic split by
+             # the all2all comm op after attn) for the all blocks.
+             # img_ids and txt_ids will only be splited once at the very beginning,
+             # and keep splited through the whole transformer forward. The all2all
+             # comm op only happens on the `out` tensor after local attn not on
+             # img_ids and txt_ids.
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "img_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+                 "txt_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+             },
+             # Then, the final proj_out will gather the splited output.
+             # splited input (previous splited output)
+             # -> all gather
+             # -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
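A brief sketch of how this planner is picked up, following the registry call pattern used in `maybe_enable_context_parallelism` above. The planner is registered under the "Chroma" key, presumably matched against the transformer class name; the resulting plan dict is then passed to `transformer.enable_parallelism(config=..., cp_plan=...)`:

```python
from diffusers.models.transformers.transformer_chroma import ChromaTransformer2DModel
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (
    ContextParallelismPlannerRegister,
)

def build_chroma_cp_plan(transformer: ChromaTransformer2DModel):
    # Same lookup maybe_enable_context_parallelism() performs: resolve the
    # planner registered for this transformer, then build the plan dict above.
    planner_cls = ContextParallelismPlannerRegister.get_planner(transformer)
    return planner_cls().apply(transformer=transformer)

# The returned mapping splits hidden_states/encoder_hidden_states on dim 1
# (and img_ids/txt_ids on dim 0) at the transformer entry, and gathers the
# sequence dimension back at proj_out.
```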