cache-dit 0.3.2-py3-none-any.whl → 1.0.14-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
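
The most visible structural change in this release is the rename of the `cache_dit/cache_factory` package to `cache_dit/caching` (e.g. entries 4, 5, and 8 above), alongside the new `parallelism/` and `kernels/` trees. Below is a hedged sketch of how downstream code could tolerate both layouts; the shim is illustrative and not part of cache-dit, and `block_adapters` is chosen only because the file list shows that subpackage exists under both names.

    # Illustrative compatibility shim (not part of cache-dit): prefer the
    # 1.0.x module layout and fall back to the 0.3.x layout if it is absent.
    try:
        from cache_dit.caching import block_adapters        # cache-dit >= 1.0.x
    except ImportError:
        from cache_dit.cache_factory import block_adapters  # cache-dit <= 0.3.2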
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py
@@ -0,0 +1,299 @@
+ import torch
+ import functools
+ from typing import Optional, Tuple
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.transformer_cogview3plus import (
+     CogView3PlusTransformer2DModel,
+     CogVideoXAttnProcessor2_0,
+ )
+ from diffusers.models.transformers.transformer_cogview4 import (
+     CogView4Transformer2DModel,
+     CogView4AttnProcessor,
+ )
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source:\n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+ from .cp_plan_cogvideox import __patch_CogVideoXAttnProcessor2_0__call__
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("CogView3Plus")
+ class CogView3PlusContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The native diffusers CP plan is not supported
+         # for CogView3Plus yet.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, CogView3PlusTransformer2DModel
+             ), "Transformer must be an instance of CogView3PlusTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # CogView3Plus and CogVideoX share the same attention processor.
+         CogVideoXAttnProcessor2_0.__call__ = (
+             __patch_CogVideoXAttnProcessor2_0__call__
+         )
+         # Also need to patch the parallel config and attention backend.
+         if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+             CogVideoXAttnProcessor2_0._parallel_config = None
+         if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+             CogVideoXAttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here, which may differ
+         # slightly from the native diffusers implementation for some models.
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local heads, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # Pattern of the rest of the transformer_blocks, split_output=False:
+             #   split input (previous split output) -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local heads, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # The `encoder_hidden_states` changes after each block forward,
+             # so we split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest
+             # of the blocks. The `out` tensor of the local attn is split into
+             # `hidden_states` and `encoder_hidden_states` after each block
+             # forward, so both are automatically split by the all2all comm op
+             # after the local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # transformer forward while using CP, since it is not split here.
+             # Then, the final proj_out will gather the split output:
+             #   split input (previous split output)
+             #                  -> all gather
+             #                  -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @ContextParallelismPlannerRegister.register("CogView4")
+ class CogView4ContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The native diffusers CP plan is not supported
+         # for CogView4 yet.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, CogView4Transformer2DModel
+             ), "Transformer must be an instance of CogView4Transformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         CogView4AttnProcessor.__call__ = __patch_CogView4AttnProcessor__call__
+         # Also need to patch the parallel config and attention backend.
+         if not hasattr(CogView4AttnProcessor, "_parallel_config"):
+             CogView4AttnProcessor._parallel_config = None
+         if not hasattr(CogView4AttnProcessor, "_attention_backend"):
+             CogView4AttnProcessor._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here, which may differ
+         # slightly from the native diffusers implementation for some models.
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local heads, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # Pattern of the rest of the transformer_blocks, split_output=False:
+             #   split input (previous split output) -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local heads, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # The `encoder_hidden_states` changes after each block forward,
+             # so we split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest
+             # of the blocks. The `out` tensor of the local attn is split into
+             # `hidden_states` and `encoder_hidden_states` after each block
+             # forward, so both are automatically split by the all2all comm op
+             # after the local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Pattern of image_rotary_emb: split at every block, because it is
+             # not automatically split by the all2all comm op and stays un-split
+             # once the block forward has finished:
+             #   un-split input -> split output
+             #                  -> after block forward
+             #                  -> un-split input
+             #   un-split input -> split output
+             #   ...
+             "transformer_blocks.*": {
+                 "image_rotary_emb": [
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                 ],
+             },
+             # transformer forward while using CP, since it is not split here.
+             # Then, the final proj_out will gather the split output:
+             #   split input (previous split output)
+             #                  -> all gather
+             #                  -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @functools.wraps(CogView4AttnProcessor.__call__)
+ def __patch_CogView4AttnProcessor__call__(
+     self: CogView4AttnProcessor,
+     attn: Attention,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     dtype = encoder_hidden_states.dtype
+
+     batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+     batch_size, image_seq_length, embed_dim = hidden_states.shape
+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+     # 1. QKV projections
+     query = attn.to_q(hidden_states)
+     key = attn.to_k(hidden_states)
+     value = attn.to_v(hidden_states)
+
+     # NOTE(DefTruth): no transpose
+     query = query.unflatten(2, (attn.heads, -1))
+     key = key.unflatten(2, (attn.heads, -1))
+     value = value.unflatten(2, (attn.heads, -1))
+
+     # 2. QK normalization
+     if attn.norm_q is not None:
+         query = attn.norm_q(query).to(dtype=dtype)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key).to(dtype=dtype)
+
+     # 3. Rotary positional embeddings applied to the latent stream
+     if image_rotary_emb is not None:
+         from diffusers.models.embeddings import apply_rotary_emb
+
+         query[:, text_seq_length:] = apply_rotary_emb(
+             query[:, text_seq_length:],
+             image_rotary_emb,
+             use_real_unbind_dim=-2,
+             sequence_dim=1,
+         )
+         key[:, text_seq_length:] = apply_rotary_emb(
+             key[:, text_seq_length:],
+             image_rotary_emb,
+             use_real_unbind_dim=-2,
+             sequence_dim=1,
+         )
+
+     # 4. Attention
+     if attention_mask is not None:
+         text_attn_mask = attention_mask
+         assert (
+             text_attn_mask.dim() == 2
+         ), "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+         text_attn_mask = text_attn_mask.float().to(query.device)
+         mix_attn_mask = torch.ones(
+             (batch_size, text_seq_length + image_seq_length),
+             device=query.device,
+         )
+         mix_attn_mask[:, :text_seq_length] = text_attn_mask  # [B, seq_len]
+         # TODO(DefTruth): Permute mix_attn_mask if context parallel is used.
+         # For example, if world size = 2: [E, H] -> [E_0, H_0, E_1, H_1]
+         mix_attn_mask = mix_attn_mask.unsqueeze(2)  # [B, seq_len, 1]
+         attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(
+             1, 2
+         )  # [B, seq_len, seq_len]
+         attention_mask = (
+             (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+         )  # [B, 1, seq_len, seq_len]
+         if (
+             hasattr(self, "_parallel_config")
+             and self._parallel_config is not None
+         ):
+             raise NotImplementedError(
+                 "Attention mask with context parallelism for CogView4 "
+                 "is not implemented yet."
+             )
+
+     # NOTE(DefTruth): Apply dispatch_attention_fn instead of sdpa directly
+     hidden_states = dispatch_attention_fn(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         dropout_p=0.0,
+         is_causal=False,
+         backend=getattr(self, "_attention_backend", None),
+         parallel_config=getattr(self, "_parallel_config", None),
+     )
+     hidden_states = hidden_states.flatten(2, 3)
+     hidden_states = hidden_states.type_as(query)
+
+     # 5. Output projection
+     hidden_states = attn.to_out[0](hidden_states)
+     hidden_states = attn.to_out[1](hidden_states)
+
+     encoder_hidden_states, hidden_states = hidden_states.split(
+         [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+     )
+     return hidden_states, encoder_hidden_states
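
The `_cp_plan` dictionaries above describe where activations are split and gathered along the sequence dimension. The following is a minimal single-process sketch (plain PyTorch, no distributed hooks; `world_size` and the tensor shapes are illustrative) of what `ContextParallelInput(split_dim=1)` and `ContextParallelOutput(gather_dim=1)` imply for a `[B, S, D]` tensor:

    import torch

    # Single-process illustration of the sequence-dim split/gather expressed
    # by the plans above (world_size and shapes are made up for the example).
    world_size = 2
    hidden_states = torch.randn(1, 8, 16)            # [B, S, D]

    # ContextParallelInput(split_dim=1): each rank keeps S / world_size tokens.
    shards = hidden_states.chunk(world_size, dim=1)
    assert shards[0].shape == (1, 4, 16)

    # ContextParallelOutput(gather_dim=1): proj_out sees the full sequence again.
    gathered = torch.cat(shards, dim=1)
    assert torch.equal(gathered, hidden_states)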
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py
@@ -0,0 +1,123 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.consisid_transformer_3d import (
+     ConsisIDTransformer3DModel,
+ )
+ from diffusers.models.transformers.cogvideox_transformer_3d import (
+     CogVideoXAttnProcessor2_0,
+ )
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source:\n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+ from .cp_plan_cogvideox import __patch_CogVideoXAttnProcessor2_0__call__
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("ConsisID")
+ class CosisIDContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The native diffusers CP plan is not supported
+         # for ConsisID yet.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, ConsisIDTransformer3DModel
+             ), "Transformer must be an instance of ConsisIDTransformer3DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # ConsisID uses the same attention processor as CogVideoX.
+         CogVideoXAttnProcessor2_0.__call__ = (
+             __patch_CogVideoXAttnProcessor2_0__call__
+         )
+         # Also need to patch the parallel config and attention backend.
+         if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+             CogVideoXAttnProcessor2_0._parallel_config = None
+         if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+             CogVideoXAttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here, which may differ
+         # slightly from the native diffusers implementation for some models.
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local heads, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # Pattern of the rest of the transformer_blocks, split_output=False:
+             #   split input (previous split output) -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local heads, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # The `encoder_hidden_states` changes after each block forward,
+             # so we split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest
+             # of the blocks. The `out` tensor of the local attn is split into
+             # `hidden_states` and `encoder_hidden_states` after each block
+             # forward, so both are automatically split by the all2all comm op
+             # after the local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Pattern of image_rotary_emb: split at every block, because it is
+             # not automatically split by the all2all comm op and stays un-split
+             # once the block forward has finished:
+             #   un-split input -> split output
+             #                  -> after block forward
+             #                  -> un-split input
+             #   un-split input -> split output
+             #   ...
+             "transformer_blocks.*": {
+                 "image_rotary_emb": [
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                 ],
+             },
+             # NOTE: We should gather both hidden_states and encoder_hidden_states
+             # at the end of the last block, because the subsequent op is:
+             #   hidden_states = torch.cat([encoder_hidden_states, hidden_states])
+             f"transformer_blocks.{len(transformer.transformer_blocks) - 1}": [
+                 ContextParallelOutput(gather_dim=1, expected_dims=3),
+                 ContextParallelOutput(gather_dim=1, expected_dims=3),
+             ],
+         }
+         return _cp_plan
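
The ConsisID plan differs from the CogView plans in its gather point: both `hidden_states` and `encoder_hidden_states` are gathered at the last block rather than at `proj_out`, because the model concatenates the two streams immediately afterwards. A single-process sketch of that constraint (shard shapes are illustrative, not taken from the model):

    import torch

    # Per-rank shards after the last block's all2all (shapes are illustrative).
    world_size = 2
    hidden_shards = [torch.randn(1, 4, 8) for _ in range(world_size)]
    encoder_shards = [torch.randn(1, 2, 8) for _ in range(world_size)]

    # ContextParallelOutput(gather_dim=1) on the last block, for both outputs:
    hidden_states = torch.cat(hidden_shards, dim=1)           # [1, 8, 8]
    encoder_hidden_states = torch.cat(encoder_shards, dim=1)  # [1, 4, 8]

    # The subsequent op noted in the plan's comment needs both full sequences:
    merged = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    assert merged.shape == (1, 12, 8)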
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py
@@ -0,0 +1,94 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.dit_transformer_2d import (
+     DiTTransformer2DModel,
+ )
+ from diffusers.models.attention_processor import (
+     Attention,
+     AttnProcessor2_0,
+ )  # sdpa
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source:\n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+ from .cp_plan_pixart import (
+     __patch_AttnProcessor2_0__call__,
+     __patch_Attention_prepare_attention_mask__,
+ )
+
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("DiT")
+ class DiTContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         assert transformer is not None, "Transformer must be provided."
+         assert isinstance(
+             transformer, DiTTransformer2DModel
+         ), "Transformer must be an instance of DiTTransformer2DModel"
+
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Apply the monkey patch that fixes attention mask preparation at class level.
+         Attention.prepare_attention_mask = (
+             __patch_Attention_prepare_attention_mask__
+         )
+         AttnProcessor2_0.__call__ = __patch_AttnProcessor2_0__call__
+         if not hasattr(AttnProcessor2_0, "_parallel_config"):
+             AttnProcessor2_0._parallel_config = None
+         if not hasattr(AttnProcessor2_0, "_attention_backend"):
+             AttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here, which may differ
+         # slightly from the native diffusers implementation for some models.
+
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local heads, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # (only split hidden_states, not encoder_hidden_states)
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Then, the final proj_out will gather the split output:
+             #   split input (previous split output)
+             #                  -> all gather
+             #                  -> un-split output
+             "proj_out_2": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
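
Each planner above is attached to a model family via `ContextParallelismPlannerRegister.register(...)`. The actual registry lives in `cp_plan_registers.py` (+84 lines, not shown in this diff); the snippet below is only an illustrative sketch of the decorator-registry pattern it implies, with all names assumed rather than taken from the package:

    from typing import Callable, Dict, Type

    class PlannerRegistry:
        """Illustrative registry sketch; not cache-dit's real implementation."""

        _planners: Dict[str, Type] = {}

        @classmethod
        def register(cls, name: str) -> Callable[[Type], Type]:
            def _wrap(planner_cls: Type) -> Type:
                cls._planners[name] = planner_cls  # e.g. "DiT" -> planner class
                return planner_cls
            return _wrap

        @classmethod
        def get(cls, name: str) -> Type:
            return cls._planners[name]

    @PlannerRegistry.register("DiT")
    class DummyPlanner:
        pass

    assert PlannerRegistry.get("DiT") is DummyPlanner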
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py
@@ -0,0 +1,88 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers import FluxTransformer2DModel
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source:\n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("Flux")
+ class FluxContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, FluxTransformer2DModel
+             ), "Transformer must be an instance of FluxTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Otherwise, use the custom CP plan defined here, which may differ
+         # slightly from the native diffusers implementation for some models.
+         _cp_plan = {
+             # This is a transformer-level CP plan for Flux: it applies only a
+             # single split hook (pre_forward) on the transformer forward and
+             # gathers the output after the transformer forward.
+             # Pattern of the transformer forward, split_output=False:
+             #   un-split input -> split input (inside the transformer)
+             # Pattern of the transformer_blocks and single_transformer_blocks:
+             #   split input (previous split output) -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local heads, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # The `hidden_states` and `encoder_hidden_states` stay split after
+             # each block forward (namely, automatically split by the all2all
+             # comm op after attn) for all blocks. img_ids and txt_ids are only
+             # split once at the very beginning and stay split through the whole
+             # transformer forward; the all2all comm op only happens on the `out`
+             # tensor after the local attn, not on img_ids and txt_ids.
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "img_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+                 "txt_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+             },
+             # Then, the final proj_out will gather the split output:
+             #   split input (previous split output)
+             #                  -> all gather
+             #                  -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
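
Unlike the per-block plans above, the Flux plan hooks the whole transformer via the empty-string key and gathers at `proj_out`. The plan keys follow PyTorch module paths, as the toy module below illustrates (the class and layer sizes are made up for the example):

    import torch.nn as nn

    # Toy stand-in for a DiT-style transformer; only the module paths matter here.
    class TinyTransformer(nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer_blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))
            self.proj_out = nn.Linear(8, 8)

    paths = dict(TinyTransformer().named_modules()).keys()
    assert "" in paths                      # root module: Flux's transformer-level hook
    assert "transformer_blocks.0" in paths  # per-block hook, as in the CogView/DiT plans
    assert "proj_out" in paths              # gather point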