cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
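Most of the churn above comes from renaming the internal `cache_dit/cache_factory/` package to `cache_dit/caching/`, plus the new `parallelism/` backends (context parallelism via diffusers, tensor parallelism via native PyTorch), the `summary.py` module, and the move of `quantize_ao.py` under `quantize/backends/torchao/`. As a minimal, hypothetical sketch of what the rename implies for downstream code that imported the old internal paths, assuming the public symbols (e.g. `ForwardPattern`) kept their names across the move:

# Hypothetical compatibility shim; the symbol name is assumed, not taken from this diff.
try:
    from cache_dit.caching.forward_pattern import ForwardPattern  # cache-dit >= 1.0
except ImportError:
    from cache_dit.cache_factory.forward_pattern import ForwardPattern  # cache-dit 0.3.x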
cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py (new file)
@@ -0,0 +1,304 @@
+import os
+import torch
+from typing import Optional
+
+try:
+    from diffusers.models.attention_dispatch import (
+        _AttentionBackendRegistry,
+        AttentionBackendName,
+        _check_device,
+        _check_shape,
+        TemplatedRingAttention,
+        TemplatedUlyssesAttention,
+    )
+    from diffusers.models._modeling_parallel import ParallelConfig
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires 'diffusers>=0.36.dev0'. "
+        "Please install the latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+    "_native_attention",
+]
+
+# Enable the custom native attention backend with context parallelism
+# by default. Users can set the environment variable to 0 to disable
+# this behavior. Defaults to enabled for better compatibility.
+_CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH = bool(
+    int(os.getenv("CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH", "1"))
+)
+
+
+def _is_native_attn_supported_context_parallel() -> bool:
+    try:
+        return (
+            AttentionBackendName.NATIVE
+            in _AttentionBackendRegistry._supports_context_parallel
+            and _AttentionBackendRegistry._supports_context_parallel[
+                AttentionBackendName.NATIVE
+            ]
+        )
+    except Exception:
+        assert isinstance(
+            _AttentionBackendRegistry._supports_context_parallel, set
+        )
+        return (
+            AttentionBackendName.NATIVE.value
+            in _AttentionBackendRegistry._supports_context_parallel
+        )
+
+
+if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
+    logger.warning(
+        "Re-registering NATIVE attention backend to enable context parallelism. "
+        "This is a temporary workaround and should be removed after the native "
+        "attention backend supports context parallelism natively. Please check: "
+        "https://github.com/huggingface/diffusers/pull/12563 for more details. "
+        "Or, you can disable this behavior by setting the environment variable "
+        "`CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0`."
+    )
+    _AttentionBackendRegistry._backends.pop(AttentionBackendName.NATIVE)
+    _AttentionBackendRegistry._constraints.pop(AttentionBackendName.NATIVE)
+    _AttentionBackendRegistry._supported_arg_names.pop(
+        AttentionBackendName.NATIVE
+    )
+    if _is_native_attn_supported_context_parallel():
+        if isinstance(
+            _AttentionBackendRegistry._supports_context_parallel, dict
+        ):
+            _AttentionBackendRegistry._supports_context_parallel.pop(
+                AttentionBackendName.NATIVE
+            )
+        else:
+            _AttentionBackendRegistry._supports_context_parallel.remove(
+                AttentionBackendName.NATIVE.value
+            )
+
+    # Re-define templated context parallel attention to support attn_mask
+    def _templated_context_parallel_attention_v2(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+        *,
+        forward_op,
+        backward_op,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ):
+        if attn_mask is not None:
+            # NOTE(DefTruth): Check if forward_op is the native attention forward op
+            forward_op_name = forward_op.__name__
+            if not forward_op_name == "_native_attention_forward_op":
+                raise ValueError(
+                    "Templated context parallel attention with attn_mask "
+                    "is only supported for the native attention backend, "
+                    f"but got forward_op: {forward_op_name}."
+                )
+        if is_causal:
+            raise ValueError(
+                "Causal attention is not yet supported for templated attention."
+            )
+        if enable_gqa:
+            raise ValueError(
+                "GQA is not yet supported for templated attention."
+            )
+
+        # TODO: add support for unified attention with ring/ulysses degree both being > 1
+        if _parallel_config.context_parallel_config.ring_degree > 1:
+            return TemplatedRingAttention.apply(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op,
+                backward_op,
+                _parallel_config,
+            )
+        elif _parallel_config.context_parallel_config.ulysses_degree > 1:
+            return TemplatedUlyssesAttention.apply(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op,
+                backward_op,
+                _parallel_config,
+            )
+        else:
+            raise ValueError(
+                "Reaching this branch of code is unexpected. Please report a bug."
+            )
+
+    # NOTE: Remove the NATIVE attention backend constraints and re-register it.
+    # This is a temporary workaround to enable context parallelism with the
+    # native attention backend. We should remove this workaround after
+    # the native attention backend supports context parallelism natively.
+    # Adapted from: https://github.com/huggingface/diffusers/pull/12563
+
+    def _native_attention_forward_op(
+        ctx: torch.autograd.function.FunctionCtx,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+        _save_ctx: bool = True,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ):
+        # Native attention does not support return_lse
+        if return_lse:
+            raise ValueError(
+                "Native attention does not support return_lse=True"
+            )
+
+        # used for backward pass
+        if _save_ctx:
+            ctx.save_for_backward(query, key, value)
+            ctx.attn_mask = attn_mask
+            ctx.dropout_p = dropout_p
+            ctx.is_causal = is_causal
+            ctx.scale = scale
+            ctx.enable_gqa = enable_gqa
+
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+            scale=scale,
+            enable_gqa=enable_gqa,
+        )
+        out = out.permute(0, 2, 1, 3)
+
+        return out
+
+    def _native_attention_backward_op(
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_out: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        query, key, value = ctx.saved_tensors
+
+        query.requires_grad_(True)
+        key.requires_grad_(True)
+        value.requires_grad_(True)
+
+        query_t, key_t, value_t = (
+            x.permute(0, 2, 1, 3) for x in (query, key, value)
+        )
+        out = torch.nn.functional.scaled_dot_product_attention(
+            query=query_t,
+            key=key_t,
+            value=value_t,
+            attn_mask=ctx.attn_mask,
+            dropout_p=ctx.dropout_p,
+            is_causal=ctx.is_causal,
+            scale=ctx.scale,
+            enable_gqa=ctx.enable_gqa,
+        )
+        out = out.permute(0, 2, 1, 3)
+
+        grad_out_t = grad_out.permute(0, 2, 1, 3)
+        grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
+            outputs=out,
+            inputs=[query_t, key_t, value_t],
+            grad_outputs=grad_out_t,
+            retain_graph=False,
+        )
+
+        grad_query = grad_query_t.permute(0, 2, 1, 3)
+        grad_key = grad_key_t.permute(0, 2, 1, 3)
+        grad_value = grad_value_t.permute(0, 2, 1, 3)
+
+        return grad_query, grad_key, grad_value
+
+    @_AttentionBackendRegistry.register(
+        AttentionBackendName.NATIVE,
+        constraints=[_check_device, _check_shape],
+        supports_context_parallel=True,
+    )
+    def _native_attention(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+        return_lse: bool = False,
+        _parallel_config: Optional["ParallelConfig"] = None,
+    ) -> torch.Tensor:
+        if return_lse:
+            raise ValueError(
+                "Native attention backend does not support setting `return_lse=True`."
+            )
+        if _parallel_config is None:
+            query, key, value = (
+                x.permute(0, 2, 1, 3) for x in (query, key, value)
+            )
+            out = torch.nn.functional.scaled_dot_product_attention(
+                query=query,
+                key=key,
+                value=value,
+                attn_mask=attn_mask,
+                dropout_p=dropout_p,
+                is_causal=is_causal,
+                scale=scale,
+                enable_gqa=enable_gqa,
+            )
+            out = out.permute(0, 2, 1, 3)
+        else:
+            out = _templated_context_parallel_attention_v2(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op=_native_attention_forward_op,
+                backward_op=_native_attention_backward_op,
+                _parallel_config=_parallel_config,
+            )
+        return out
+
+else:
+    from diffusers.models.attention_dispatch import (
+        _native_attention,
+    )  # noqa: F401
+
+    logger.info(
+        "Native attention backend already supports context parallelism."
+    )
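A small illustration (not part of the diff) of the layout convention the forward op above relies on: diffusers' attention dispatcher hands tensors over in (batch, seq, heads, head_dim) order, so `_native_attention_forward_op` permutes to the (batch, heads, seq, head_dim) layout expected by `torch.nn.functional.scaled_dot_product_attention` and permutes back afterwards.

# Illustration only: the permute convention used around torch SDPA above.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 256, 8, 64)       # (batch, seq, heads, head_dim)
out = F.scaled_dot_product_attention(
    q.permute(0, 2, 1, 3),                   # -> (batch, heads, seq, head_dim)
    k.permute(0, 2, 1, 3),
    v.permute(0, 2, 1, 3),
).permute(0, 2, 1, 3)                        # back to (batch, seq, heads, head_dim)
assert out.shape == q.shape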
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py (new file)
@@ -0,0 +1,95 @@
+import torch
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.transformer_chroma import (
+    ChromaTransformer2DModel,
+)
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires 'diffusers>=0.36.dev0'. "
+        "Please install the latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("Chroma")
+class ChromaContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        # NOTE: The native diffusers CP plan is still not supported
+        # for Chroma.
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            assert isinstance(
+                transformer, ChromaTransformer2DModel
+            ), "Transformer must be an instance of ChromaTransformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        # Otherwise, use the custom CP plan defined here, which may be
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Here is a transformer-level CP plan for Chroma, which will
+            # apply only one split hook (pre_forward) on the forward of
+            # the transformer, and gather the output after the transformer forward.
+            # Pattern of the transformer forward, split_output=False:
+            #     un-split input -> split input (inside the transformer)
+            # Pattern of the transformer_blocks, single_transformer_blocks:
+            #     split input (previous split output) -> to_qkv/...
+            #     -> all2all
+            #     -> attn (local heads, full seqlen)
+            #     -> all2all
+            #     -> split output
+            # The `hidden_states` and `encoder_hidden_states` will stay
+            # split after each block forward (namely, automatically split by
+            # the all2all comm op after attn) for all blocks.
+            # img_ids and txt_ids will only be split once at the very beginning,
+            # and stay split through the whole transformer forward. The all2all
+            # comm op only happens on the `out` tensor after local attn, not on
+            # img_ids and txt_ids.
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "img_ids": ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=False
+                ),
+                "txt_ids": ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=False
+                ),
+            },
+            # Then, the final proj_out will gather the split output:
+            #     split input (previous split output)
+            #     -> all gather
+            #     -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
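For intuition (again, not part of the diff), a `ContextParallelInput(split_dim=1, expected_dims=3)` entry like the ones above roughly amounts to sharding a (batch, seq, dim) activation along the sequence dimension, one shard per context-parallel rank, while `ContextParallelOutput(gather_dim=1)` is the matching all-gather at `proj_out`.

# Illustration only: sequence-dim sharding implied by split_dim=1 on a 3-D tensor.
import torch

world_size = 2                                   # assumed context-parallel degree
hidden_states = torch.randn(1, 4096, 3072)       # (batch, seq, dim), un-split
shards = torch.chunk(hidden_states, world_size, dim=1)
assert shards[0].shape == (1, 2048, 3072)        # each rank sees seq / world_size tokens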
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py (new file)
@@ -0,0 +1,202 @@
+import torch
+import functools
+from typing import Optional, Tuple
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.cogvideox_transformer_3d import (
+    CogVideoXAttnProcessor2_0,
+    CogVideoXTransformer3DModel,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires 'diffusers>=0.36.dev0'. "
+        "Please install the latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("CogVideoX")
+class CogVideoXContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+
+        # NOTE: The native diffusers CP plan is still not supported
+        # for CogVideoX.
+        self._cp_planner_preferred_native_diffusers = False
+
+        if (
+            transformer is not None
+            and self._cp_planner_preferred_native_diffusers
+        ):
+            assert isinstance(
+                transformer, CogVideoXTransformer3DModel
+            ), "Transformer must be an instance of CogVideoXTransformer3DModel"
+            if hasattr(transformer, "_cp_plan"):
+                if transformer._cp_plan is not None:
+                    return transformer._cp_plan
+
+        CogVideoXAttnProcessor2_0.__call__ = (
+            __patch_CogVideoXAttnProcessor2_0__call__
+        )
+        # Also need to patch the parallel config and attention backend
+        if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+            CogVideoXAttnProcessor2_0._parallel_config = None
+        if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+            CogVideoXAttnProcessor2_0._attention_backend = None
+
+        # Otherwise, use the custom CP plan defined here, which may be
+        # a little different from the native diffusers implementation
+        # for some models.
+        _cp_plan = {
+            # Pattern of transformer_blocks.0, split_output=False:
+            #     un-split input -> split -> to_qkv/...
+            #     -> all2all
+            #     -> attn (local heads, full seqlen)
+            #     -> all2all
+            #     -> split output
+            # Pattern of the remaining transformer_blocks, split_output=False:
+            #     split input (previous split output) -> to_qkv/...
+            #     -> all2all
+            #     -> attn (local heads, full seqlen)
+            #     -> all2all
+            #     -> split output
+            # The `encoder_hidden_states` will be changed after each block forward,
+            # so we need to split it at the first block and keep it split (namely,
+            # automatically split by the all2all op after attn) for the rest of the blocks.
+            # The `out` tensor of local attn will be split into `hidden_states` and
+            # `encoder_hidden_states` after each block forward, thus both of them
+            # will be automatically split by the all2all comm op after local attn.
+            "transformer_blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            # Pattern of image_rotary_emb, split at every block, because it
+            # is not automatically split by the all2all comm op and stays
+            # un-split after the block forward finishes:
+            #     un-split input -> split output
+            #     -> after block forward
+            #     -> un-split input
+            #     un-split input -> split output
+            #     ...
+            "transformer_blocks.*": {
+                "image_rotary_emb": [
+                    ContextParallelInput(
+                        split_dim=0, expected_dims=2, split_output=False
+                    ),
+                    ContextParallelInput(
+                        split_dim=0, expected_dims=2, split_output=False
+                    ),
+                ],
+            },
+            # transformer forward while using CP, since it is not split here.
+            # Then, the final proj_out will gather the split output:
+            #     split input (previous split output)
+            #     -> all gather
+            #     -> un-split output
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@functools.wraps(CogVideoXAttnProcessor2_0.__call__)
+def __patch_CogVideoXAttnProcessor2_0__call__(
+    self: CogVideoXAttnProcessor2_0,
+    attn: Attention,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    image_rotary_emb: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    text_seq_length = encoder_hidden_states.size(1)
+
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    batch_size, sequence_length, _ = hidden_states.shape
+
+    # NOTE(DefTruth): attention mask is always None in CogVideoX
+    if attention_mask is not None:
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size
+        )
+        attention_mask = attention_mask.view(
+            batch_size, attn.heads, -1, attention_mask.shape[-1]
+        )
+
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)
+
+    inner_dim = key.shape[-1]
+    head_dim = inner_dim // attn.heads
+
+    # NOTE(DefTruth): no transpose, keep the (batch, seq, heads, head_dim) layout
+    query = query.view(batch_size, -1, attn.heads, head_dim)
+    key = key.view(batch_size, -1, attn.heads, head_dim)
+    value = value.view(batch_size, -1, attn.heads, head_dim)
+
+    if attn.norm_q is not None:
+        query = attn.norm_q(query)
+    if attn.norm_k is not None:
+        key = attn.norm_k(key)
+
+    # Apply RoPE if needed
+    if image_rotary_emb is not None:
+        from diffusers.models.embeddings import apply_rotary_emb
+
+        query[:, text_seq_length:] = apply_rotary_emb(
+            query[:, text_seq_length:],
+            image_rotary_emb,
+            sequence_dim=1,
+        )
+        if not attn.is_cross_attention:
+            key[:, text_seq_length:] = apply_rotary_emb(
+                key[:, text_seq_length:],
+                image_rotary_emb,
+                sequence_dim=1,
+            )
+
+    # NOTE(DefTruth): Apply dispatch_attention_fn instead of sdpa directly
+    hidden_states = dispatch_attention_fn(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=0.0,
+        is_causal=False,
+        backend=getattr(self, "_attention_backend", None),
+        parallel_config=getattr(self, "_parallel_config", None),
+    )
+    hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+
+    # linear proj
+    hidden_states = attn.to_out[0](hidden_states)
+    # dropout
+    hidden_states = attn.to_out[1](hidden_states)
+
+    encoder_hidden_states, hidden_states = hidden_states.split(
+        [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+    )
+    return hidden_states, encoder_hidden_states
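For reference (not part of the diff), the token bookkeeping in the patched processor above: text and image tokens are concatenated before attention, RoPE is applied only to the image positions, and the joint output is split back into the two streams, which is why both stay sequence-sharded under context parallelism.

# Illustration only, with made-up sizes: the concat/split pattern used by the
# patched CogVideoX attention processor.
import torch

text_len, image_len, dim = 226, 1024, 3072       # illustrative sizes
encoder_hidden_states = torch.randn(1, text_len, dim)
hidden_states = torch.randn(1, image_len, dim)

joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# ... attention over `joint` ...
text_out, image_out = joint.split([text_len, joint.size(1) - text_len], dim=1)
assert text_out.shape[1] == text_len and image_out.shape[1] == image_len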