cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
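The listing above shows that the 1.0.3 packages `cache_dit/cache_factory` and `cache_dit/custom_ops` were renamed to `cache_dit/caching` and `cache_dit/kernels` in 1.0.14. A minimal, hedged sketch of how downstream code might tolerate the rename, assuming only that the submodules visible in the listing (e.g. `cache_interface.py`) stay importable from their new locations:

```python
# Hedged sketch: cope with the 1.0.3 -> 1.0.14 package rename shown in the
# file list above (cache_dit.cache_factory -> cache_dit.caching). Only the
# module paths are taken from the diff; no other API is assumed.
try:
    from cache_dit.caching import cache_interface  # 1.0.14 layout
except ImportError:  # older wheel (<= 1.0.3)
    from cache_dit.cache_factory import cache_interface  # 1.0.3 layout

print(cache_interface.__name__)
```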
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py
@@ -0,0 +1,285 @@
+ import torch
+ import functools
+ from typing import Optional
+ import torch.nn.functional as F
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.pixart_transformer_2d import (
+     PixArtTransformer2DModel,
+ )
+ from diffusers.models.attention_processor import (
+     Attention,
+     AttnProcessor2_0,
+ )  # sdpa
+ from diffusers.utils import deprecate
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source:\n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("PixArt")
+ class PixArtContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         assert transformer is not None, "Transformer must be provided."
+         assert isinstance(
+             transformer, PixArtTransformer2DModel
+         ), "Transformer must be an instance of PixArtTransformer2DModel"
+
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Apply monkey patch to fix attention mask preparation at class level
+         Attention.prepare_attention_mask = (
+             __patch_Attention_prepare_attention_mask__
+         )
+         AttnProcessor2_0.__call__ = __patch_AttnProcessor2_0__call__
+         if not hasattr(AttnProcessor2_0, "_parallel_config"):
+             AttnProcessor2_0._parallel_config = None
+         if not hasattr(AttnProcessor2_0, "_attention_backend"):
+             AttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here, which may be
+         # a little different from the native diffusers implementation
+         # for some models.
+
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local head, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # (only split hidden_states, not encoder_hidden_states)
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Pattern of all blocks, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local head, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # (only split encoder_hidden_states, not hidden_states;
+             # hidden_states has already been split in the previous
+             # block by the all2all comm op after attn)
+             # The `encoder_hidden_states` will [NOT] be changed after each block forward,
+             # so we need to split it at [ALL] blocks via the inserted split hook.
+             "transformer_blocks.*": {
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Then, the final proj_out will gather the split output.
+             #   split input (previous split output)
+             #   -> all gather
+             #   -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @functools.wraps(Attention.prepare_attention_mask)
+ def __patch_Attention_prepare_attention_mask__(
+     self: Attention,
+     attention_mask: torch.Tensor,
+     target_length: int,
+     batch_size: int,
+     out_dim: int = 3,
+     # NOTE(DefTruth): Allow specifying head_size for CP
+     head_size: Optional[int] = None,
+ ) -> torch.Tensor:
+     if head_size is None:
+         head_size = self.heads
+     if attention_mask is None:
+         return attention_mask
+
+     current_length: int = attention_mask.shape[-1]
+     if current_length != target_length:
+         if attention_mask.device.type == "mps":
+             # HACK: MPS does not support padding by greater than the dimension
+             # of the input tensor. Instead, we can manually construct the
+             # padding tensor.
+             padding_shape = (
+                 attention_mask.shape[0],
+                 attention_mask.shape[1],
+                 target_length,
+             )
+             padding = torch.zeros(
+                 padding_shape,
+                 dtype=attention_mask.dtype,
+                 device=attention_mask.device,
+             )
+             attention_mask = torch.cat([attention_mask, padding], dim=2)
+         else:
+             # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+             #   we want to instead pad by (0, remaining_length), where remaining_length is:
+             #   remaining_length: int = target_length - current_length
+             # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+             attention_mask = F.pad(
+                 attention_mask, (0, target_length), value=0.0
+             )
+
+     if out_dim == 3:
+         if attention_mask.shape[0] < batch_size * head_size:
+             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+     elif out_dim == 4:
+         attention_mask = attention_mask.unsqueeze(1)
+         attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+     return attention_mask
+
+
+ @functools.wraps(AttnProcessor2_0.__call__)
+ def __patch_AttnProcessor2_0__call__(
+     self: AttnProcessor2_0,
+     attn: Attention,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: Optional[torch.Tensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     temb: Optional[torch.Tensor] = None,
+     *args,
+     **kwargs,
+ ) -> torch.Tensor:
+     if len(args) > 0 or kwargs.get("scale", None) is not None:
+         deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+         deprecate("scale", "1.0.0", deprecation_message)
+
+     residual = hidden_states
+     if attn.spatial_norm is not None:
+         hidden_states = attn.spatial_norm(hidden_states, temb)
+
+     input_ndim = hidden_states.ndim
+
+     if input_ndim == 4:
+         batch_size, channel, height, width = hidden_states.shape
+         hidden_states = hidden_states.view(
+             batch_size, channel, height * width
+         ).transpose(1, 2)
+
+     batch_size, sequence_length, _ = (
+         hidden_states.shape
+         if encoder_hidden_states is None
+         else encoder_hidden_states.shape
+     )
+
+     if attention_mask is not None:
+         if self._parallel_config is None:
+             attention_mask = attn.prepare_attention_mask(
+                 attention_mask, sequence_length, batch_size
+             )
+             attention_mask = attention_mask.view(
+                 batch_size, attn.heads, -1, attention_mask.shape[-1]
+             )
+         else:
+             # NOTE(DefTruth): Fix attention mask preparation for context parallelism.
+             # Please note that in context parallelism, sequence_length is the local
+             # sequence length on each rank. So we need to adjust the target_length
+             # accordingly. The head_size is also adjusted based on the world size
+             # in order to make sdpa work correctly; otherwise, the sdpa op will raise
+             # an error due to the mismatch between the attention_mask shape and the
+             # expected shape.
+             cp_config = getattr(
+                 self._parallel_config, "context_parallel_config", None
+             )
+             if cp_config is not None and cp_config._world_size > 1:
+                 head_size = attn.heads // cp_config._world_size
+                 attention_mask = attn.prepare_attention_mask(
+                     attention_mask,
+                     sequence_length * cp_config._world_size,
+                     batch_size,
+                     3,
+                     head_size,
+                 )
+                 attention_mask = attention_mask.view(
+                     batch_size, head_size, -1, attention_mask.shape[-1]
+                 )
+
+     if attn.group_norm is not None:
+         hidden_states = attn.group_norm(
+             hidden_states.transpose(1, 2)
+         ).transpose(1, 2)
+
+     query = attn.to_q(hidden_states)
+
+     if encoder_hidden_states is None:
+         encoder_hidden_states = hidden_states
+     elif attn.norm_cross:
+         encoder_hidden_states = attn.norm_encoder_hidden_states(
+             encoder_hidden_states
+         )
+
+     key = attn.to_k(encoder_hidden_states)
+     value = attn.to_v(encoder_hidden_states)
+
+     inner_dim = key.shape[-1]
+     head_dim = inner_dim // attn.heads
+
+     # NOTE(DefTruth): no transpose now
+     query = query.view(batch_size, -1, attn.heads, head_dim)
+     key = key.view(batch_size, -1, attn.heads, head_dim)
+     value = value.view(batch_size, -1, attn.heads, head_dim)
+
+     if attn.norm_q is not None:
+         query = attn.norm_q(query)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key)
+
+     # NOTE(DefTruth): Use dispatch_attention_fn to support different backends
+     hidden_states = dispatch_attention_fn(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         dropout_p=0.0,
+         is_causal=False,
+         backend=getattr(self, "_attention_backend", None),
+         parallel_config=getattr(self, "_parallel_config", None),
+     )
+     hidden_states = hidden_states.flatten(2, 3)
+     hidden_states = hidden_states.to(query.dtype)
+
+     # linear proj
+     hidden_states = attn.to_out[0](hidden_states)
+     # dropout
+     hidden_states = attn.to_out[1](hidden_states)
+
+     if input_ndim == 4:
+         hidden_states = hidden_states.transpose(-1, -2).reshape(
+             batch_size, channel, height, width
+         )
+
+     if attn.residual_connection:
+         hidden_states = hidden_states + residual
+
+     hidden_states = hidden_states / attn.rescale_output_factor
+
+     return hidden_states
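A small numeric sketch of the shape bookkeeping performed by the patched `prepare_attention_mask` call above under context parallelism, assuming 16 attention heads, a world size of 2, and a full sequence length of 1024 (the concrete numbers are illustrative, not taken from the diff):

```python
# Illustrative numbers only: how the patch above rescales head count and
# target length when a context-parallel config with _world_size > 1 is set.
heads = 16
world_size = 2
local_seq_len = 1024 // world_size            # tokens held by each rank: 512

head_size = heads // world_size               # local heads after all-to-all: 8
target_length = local_seq_len * world_size    # mask prepared for the full length: 1024

# The patched processor then views the prepared mask as
# (batch_size, head_size, -1, target_length) before calling SDPA.
print(head_size, target_length)               # 8 1024
```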
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py
@@ -0,0 +1,104 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source:\n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("QwenImage")
+ class QwenImageContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: Set it to False to use the custom CP plan defined here.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             from diffusers import QwenImageTransformer2DModel
+
+             assert isinstance(
+                 transformer, QwenImageTransformer2DModel
+             ), "Transformer must be an instance of QwenImageTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Otherwise, use the custom CP plan defined here, which may be
+         # a little different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Here is a Transformer-level CP plan for QwenImage, which will
+             # only apply a single split hook (pre_forward) on the forward
+             # of the Transformer, and gather the output after the Transformer forward.
+             # Pattern of transformer forward, split_output=False:
+             #   un-split input -> split input (inside transformer)
+             # Pattern of the transformer_blocks:
+             #   split input (previous split output) -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local head, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # The `hidden_states` and `encoder_hidden_states` will still keep
+             # themselves split after each block forward (namely, automatically
+             # split by the all2all comm op after attn) for all blocks.
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 # NOTE: Due to the joint attention implementation of
+                 # QwenImageTransformerBlock, we must split the
+                 # encoder_hidden_states as well.
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 # NOTE: But encoder_hidden_states_mask seems never used in
+                 # QwenImageTransformerBlock, so we do not split it here.
+                 # "encoder_hidden_states_mask": ContextParallelInput(
+                 #     split_dim=1, expected_dims=2, split_output=False
+                 # ),
+             },
+             # Pattern of pos_embed, split_output=True (split output rather than input):
+             #   un-split input
+             #   -> keep input un-split
+             #   -> rope
+             #   -> split output
+             "pos_embed": {
+                 0: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+             },
+             # Then, the final proj_out will gather the split output.
+             #   split input (previous split output)
+             #   -> all gather
+             #   -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
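For reference, a short, hedged illustration of what a `ContextParallelInput(split_dim=1, expected_dims=3)` entry in the plan above implies for the pre-forward split hook, using `torch.chunk` as a local stand-in for the sharding (the world size of 2 is an arbitrary example value):

```python
import torch

# A (batch, seq, dim) tensor, as expected_dims=3 indicates.
hidden_states = torch.randn(1, 8, 64)

world_size = 2  # example value; each rank keeps one shard of the sequence dim
shards = torch.chunk(hidden_states, world_size, dim=1)  # split_dim=1
print([tuple(s.shape) for s in shards])  # [(1, 4, 64), (1, 4, 64)]
```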
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py
@@ -0,0 +1,84 @@
+ import torch
+ import logging
+ from abc import abstractmethod
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source:\n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ __all__ = [
+     "ContextParallelismPlanner",
+     "ContextParallelismPlannerRegister",
+ ]
+
+
+ class ContextParallelismPlanner:
+     # Prefer the native diffusers implementation if available
+     _cp_planner_preferred_native_diffusers: bool = True
+
+     @abstractmethod
+     def apply(
+         self,
+         # NOTE: Keep this kwarg for future extensions
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         # NOTE: This method should only return the CP plan dictionary.
+         raise NotImplementedError(
+             "apply method must be implemented by subclasses"
+         )
+
+
+ class ContextParallelismPlannerRegister:
+     _cp_planner_registry: dict[str, ContextParallelismPlanner] = {}
+
+     @classmethod
+     def register(cls, name: str):
+         def decorator(planner_cls: type[ContextParallelismPlanner]):
+             assert (
+                 name not in cls._cp_planner_registry
+             ), f"ContextParallelismPlanner with name {name} is already registered."
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(f"Registering ContextParallelismPlanner: {name}")
+             cls._cp_planner_registry[name] = planner_cls
+             return planner_cls
+
+         return decorator
+
+     @classmethod
+     def get_planner(
+         cls, transformer: str | torch.nn.Module | ModelMixin
+     ) -> type[ContextParallelismPlanner]:
+         if isinstance(transformer, (torch.nn.Module, ModelMixin)):
+             name = transformer.__class__.__name__
+         else:
+             name = transformer
+         planner_cls = None
+         for planner_name in cls._cp_planner_registry:
+             if name.startswith(planner_name):
+                 planner_cls = cls._cp_planner_registry.get(planner_name)
+                 break
+         if planner_cls is None:
+             raise ValueError(f"No planner registered under name: {name}")
+         return planner_cls
+
+     @classmethod
+     def supported_planners(
+         cls,
+     ) -> tuple[int, list[str]]:
+         val_planners = cls._cp_planner_registry.keys()
+         return len(val_planners), [p for p in val_planners]
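A hedged usage sketch for the registry defined above: register a planner under a hypothetical prefix and resolve it by class name. `MyTransformer2DModel` and the empty plan it returns are illustrative placeholders, not real diffusers or cache-dit objects:

```python
# Register a planner under the prefix "MyTransformer"; get_planner() matches
# registered prefixes against the transformer class name via startswith().
@ContextParallelismPlannerRegister.register("MyTransformer")
class MyTransformerContextParallelismPlanner(ContextParallelismPlanner):
    def apply(self, transformer=None, **kwargs):
        return {}  # a real planner returns a ContextParallelModelPlan dict

planner_cls = ContextParallelismPlannerRegister.get_planner("MyTransformer2DModel")
plan = planner_cls().apply()
print(planner_cls.__name__, plan)  # MyTransformerContextParallelismPlanner {}
```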
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py
@@ -0,0 +1,101 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source:\n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ # TODO: Add WanVACETransformer3DModel context parallelism planner.
+ # NOTE: Maybe use the full name to avoid a name conflict between
+ # WanTransformer3DModel and WanVACETransformer3DModel?
+ @ContextParallelismPlannerRegister.register("WanTransformer3D")
+ class WanContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             from diffusers import WanTransformer3DModel
+
+             assert isinstance(
+                 transformer, WanTransformer3DModel
+             ), "Transformer must be an instance of WanTransformer3DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Otherwise, use the custom CP plan defined here, which may be
+         # a little different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Pattern of rope, split_output=True (split output rather than input):
+             #   un-split input
+             #   -> keep input un-split
+             #   -> rope
+             #   -> split output
+             "rope": {
+                 0: ContextParallelInput(
+                     split_dim=1, expected_dims=4, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=1, expected_dims=4, split_output=True
+                 ),
+             },
+             # Pattern of blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local head, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # (only split hidden_states, not encoder_hidden_states)
+             "blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Pattern of all blocks, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local head, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # (only split encoder_hidden_states, not hidden_states;
+             # hidden_states has already been split in the previous
+             # block by the all2all comm op after attn)
+             # The `encoder_hidden_states` will [NOT] be changed after each block forward,
+             # so we need to split it at [ALL] blocks via the inserted split hook.
+             "blocks.*": {
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Then, the final proj_out will gather the split output.
+             #   split input (previous split output)
+             #   -> all gather
+             #   -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
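And the gather side, a hedged counterpart to the `ContextParallelOutput(gather_dim=1, expected_dims=3)` entry that closes each plan above, with a plain `torch.cat` standing in for the all-gather collective run by the post-forward hook:

```python
import torch

world_size = 2  # example value
# Per-rank (batch, local_seq, dim) shards produced by the last block.
shards = [torch.randn(1, 4, 64) for _ in range(world_size)]

# The post-forward hook on proj_out gathers them back along gather_dim=1.
full_output = torch.cat(shards, dim=1)
print(tuple(full_output.shape))  # (1, 8, 64)
```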
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py
@@ -0,0 +1,117 @@
+ # Docstring references: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/_modeling_parallel.py#L185
+ # A dictionary where keys denote the input to be split across the context parallel region, and the
+ # value denotes the sharding configuration.
+ # If the key is a string, it denotes the name of the parameter in the forward function.
+ # If the key is an integer, split_output must be set to True, and it denotes the index of the output
+ # to be split across the context parallel region.
+ # ContextParallelInputType = Dict[
+ #     Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+ # ]
+
+ # A dictionary where keys denote the output to be gathered across the context parallel region, and the
+ # value denotes the gathering configuration.
+ # ContextParallelOutputType = Union[
+ #     ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+ # ]
+
+ # A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+ # the module should be split/gathered across the context parallel region.
+ # ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+ # Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+ #
+ # Each model should define a _cp_plan attribute that contains information on how to shard/gather
+ # tensors at different stages of the forward:
+ #
+ # ```python
+ # _cp_plan = {
+ #     "": {
+ #         "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ #         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ #         "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ #     },
+ #     "pos_embed": {
+ #         0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ #         1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ #     },
+ #     "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ # }
+ # ```
+ #
+ # The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+ # split/gathered according to this at the respective module level. Here, the following happens:
+ # - "":
+ #   we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+ #   the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
+ # - "pos_embed":
+ #   we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
+ #   we can individually specify how they should be split
+ # - "proj_out":
+ #   before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+ #   layer forward has run).
+ #
+ # ContextParallelInput:
+ #   specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+ #
+ # ContextParallelOutput:
+ #   specifies how to gather the input tensor in the post-forward hook of the layer it is attached to
+
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+ from .cp_plan_flux import FluxContextParallelismPlanner
+ from .cp_plan_qwen_image import QwenImageContextParallelismPlanner
+ from .cp_plan_wan import WanContextParallelismPlanner
+ from .cp_plan_ltxvideo import LTXVideoContextParallelismPlanner
+ from .cp_plan_hunyuan import HunyuanImageContextParallelismPlanner
+ from .cp_plan_hunyuan import HunyuanVideoContextParallelismPlanner
+ from .cp_plan_cogvideox import CogVideoXContextParallelismPlanner
+ from .cp_plan_cogview import CogView3PlusContextParallelismPlanner
+ from .cp_plan_cogview import CogView4ContextParallelismPlanner
+ from .cp_plan_cosisid import CosisIDContextParallelismPlanner
+ from .cp_plan_chroma import ChromaContextParallelismPlanner
+ from .cp_plan_pixart import PixArtContextParallelismPlanner
+ from .cp_plan_dit import DiTContextParallelismPlanner
+
+ try:
+     import nunchaku  # noqa: F401
+
+     _nunchaku_available = True
+ except ImportError:
+     _nunchaku_available = False
+
+ if _nunchaku_available:
+     from .cp_plan_nunchaku import (  # noqa: F401
+         NunchakuFluxContextParallelismPlanner,
+     )
+     from .cp_plan_nunchaku import (  # noqa: F401
+         NunchakuQwenImageContextParallelismPlanner,
+     )
+
+
+ __all__ = [
+     "ContextParallelismPlanner",
+     "ContextParallelismPlannerRegister",
+     "FluxContextParallelismPlanner",
+     "QwenImageContextParallelismPlanner",
+     "WanContextParallelismPlanner",
+     "LTXVideoContextParallelismPlanner",
+     "HunyuanImageContextParallelismPlanner",
+     "HunyuanVideoContextParallelismPlanner",
+     "CogVideoXContextParallelismPlanner",
+     "CogView3PlusContextParallelismPlanner",
+     "CogView4ContextParallelismPlanner",
+     "CosisIDContextParallelismPlanner",
+     "ChromaContextParallelismPlanner",
+     "PixArtContextParallelismPlanner",
+     "DiTContextParallelismPlanner",
+ ]
+
+ if _nunchaku_available:
+     __all__.extend(
+         [
+             "NunchakuFluxContextParallelismPlanner",
+             "NunchakuQwenImageContextParallelismPlanner",
+         ]
+     )
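A final hedged sketch tying the pieces together: resolving one of the planners imported above by model class name and building its CP plan without instantiating the model (the Wan planner tolerates `transformer=None` and falls back to its custom plan). The import path follows the file layout in the listing and assumes `diffusers>=0.36.dev0` is installed:

```python
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_planners import (
    ContextParallelismPlannerRegister,
)

# "WanTransformer3DModel" matches the "WanTransformer3D" prefix it was
# registered under, so the lookup resolves to WanContextParallelismPlanner.
planner_cls = ContextParallelismPlannerRegister.get_planner("WanTransformer3DModel")
cp_plan = planner_cls().apply(transformer=None)
print(sorted(cp_plan.keys()))  # ['blocks.*', 'blocks.0', 'proj_out', 'rope']
```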