cache-dit 1.0.3-py3-none-any.whl → 1.0.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
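Note: the most visible structural change for downstream imports is the cache_factory → caching package rename (plus custom_ops → kernels). A minimal, hypothetical compatibility shim, assuming the same symbols are still exposed under the renamed modules (the symbol name CacheType below is illustrative and not confirmed by this diff):

# Hypothetical shim for the cache_factory -> caching rename; CacheType is an
# illustrative symbol name, not taken from this diff.
try:
    from cache_dit.caching.cache_types import CacheType  # 1.0.14 layout
except ImportError:
    from cache_dit.cache_factory.cache_types import CacheType  # 1.0.3 layout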
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py (new file)
@@ -0,0 +1,202 @@
+ import torch
+ import functools
+ from typing import Optional, Tuple
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.cogvideox_transformer_3d import (
+     CogVideoXAttnProcessor2_0,
+     CogVideoXTransformer3DModel,
+ )
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("CogVideoX")
+ class CogVideoXContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The Diffusers native CP plan is not supported
+         # for CogVideoX yet.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, CogVideoXTransformer3DModel
+             ), "Transformer must be an instance of CogVideoXTransformer3DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         CogVideoXAttnProcessor2_0.__call__ = (
+             __patch_CogVideoXAttnProcessor2_0__call__
+         )
+         # Also need to patch the parallel config and attention backend.
+         if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+             CogVideoXAttnProcessor2_0._parallel_config = None
+         if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+             CogVideoXAttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here; this may be
+         # slightly different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local head, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # Pattern of the rest of the transformer_blocks, split_output=False:
+             #   split input (previous split output) -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local head, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # The `encoder_hidden_states` will change after each block forward,
+             # so we need to split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest of the blocks.
+             # The `out` tensor of the local attn will be split into `hidden_states` and
+             # `encoder_hidden_states` after each block forward, thus both of them
+             # will be automatically split by the all2all comm op after the local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Pattern of the image_rotary_emb: split at every block, because it
+             # is not automatically split by the all2all comm op and stays un-split
+             # after the block forward has finished:
+             #   un-split input -> split output
+             #                  -> after block forward
+             #                  -> un-split input
+             #   un-split input -> split output
+             #   ...
+             "transformer_blocks.*": {
+                 "image_rotary_emb": [
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                 ],
+             },
+             # The hidden_states remain split through the whole transformer
+             # forward while using CP; the final proj_out then gathers the split output.
+             #   split input (previous split output)
+             #                  -> all gather
+             #                  -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @functools.wraps(CogVideoXAttnProcessor2_0.__call__)
+ def __patch_CogVideoXAttnProcessor2_0__call__(
+     self: CogVideoXAttnProcessor2_0,
+     attn: Attention,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     text_seq_length = encoder_hidden_states.size(1)
+
+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+     batch_size, sequence_length, _ = hidden_states.shape
+
+     # NOTE(DefTruth): attention mask is always None in CogVideoX
+     if attention_mask is not None:
+         attention_mask = attn.prepare_attention_mask(
+             attention_mask, sequence_length, batch_size
+         )
+         attention_mask = attention_mask.view(
+             batch_size, attn.heads, -1, attention_mask.shape[-1]
+         )
+
+     query = attn.to_q(hidden_states)
+     key = attn.to_k(hidden_states)
+     value = attn.to_v(hidden_states)
+
+     inner_dim = key.shape[-1]
+     head_dim = inner_dim // attn.heads
+
+     # NOTE(DefTruth): no transpose
+     query = query.view(batch_size, -1, attn.heads, head_dim)
+     key = key.view(batch_size, -1, attn.heads, head_dim)
+     value = value.view(batch_size, -1, attn.heads, head_dim)
+
+     if attn.norm_q is not None:
+         query = attn.norm_q(query)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key)
+
+     # Apply RoPE if needed
+     if image_rotary_emb is not None:
+         from diffusers.models.embeddings import apply_rotary_emb
+
+         query[:, text_seq_length:] = apply_rotary_emb(
+             query[:, text_seq_length:],
+             image_rotary_emb,
+             sequence_dim=1,
+         )
+         if not attn.is_cross_attention:
+             key[:, text_seq_length:] = apply_rotary_emb(
+                 key[:, text_seq_length:],
+                 image_rotary_emb,
+                 sequence_dim=1,
+             )
+
+     # NOTE(DefTruth): Apply dispatch_attention_fn instead of sdpa directly
+     hidden_states = dispatch_attention_fn(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         dropout_p=0.0,
+         is_causal=False,
+         backend=getattr(self, "_attention_backend", None),
+         parallel_config=getattr(self, "_parallel_config", None),
+     )
+     hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+
+     # linear proj
+     hidden_states = attn.to_out[0](hidden_states)
+     # dropout
+     hidden_states = attn.to_out[1](hidden_states)
+
+     encoder_hidden_states, hidden_states = hidden_states.split(
+         [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+     )
+     return hidden_states, encoder_hidden_states
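Note: to make the split/gather semantics described in the CP plan above concrete, here is a small standalone PyTorch illustration (not part of cache-dit) of what ContextParallelInput(split_dim=1) and ContextParallelOutput(gather_dim=1) imply for a [batch, seq, hidden] activation sharded across 2 ranks:

# Standalone illustration (not from cache-dit): split a [B, S, H] tensor on the
# sequence dim across 2 ranks, then gather it back, mirroring the hidden_states
# round trip the plan above describes (split at block 0, gather at proj_out).
import torch

x = torch.randn(2, 8, 16)                # un-split input: [B=2, S=8, H=16]
shards = list(torch.chunk(x, 2, dim=1))  # each rank holds a [2, 4, 16] shard
gathered = torch.cat(shards, dim=1)      # all-gather on gather_dim=1 restores [2, 8, 16]
assert torch.equal(gathered, x)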
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py (new file)
@@ -0,0 +1,299 @@
+ import torch
+ import functools
+ from typing import Optional, Tuple
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.transformer_cogview3plus import (
+     CogView3PlusTransformer2DModel,
+     CogVideoXAttnProcessor2_0,
+ )
+ from diffusers.models.transformers.transformer_cogview4 import (
+     CogView4Transformer2DModel,
+     CogView4AttnProcessor,
+ )
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+ from .cp_plan_cogvideox import __patch_CogVideoXAttnProcessor2_0__call__
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("CogView3Plus")
+ class CogView3PlusContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The Diffusers native CP plan is not supported
+         # for CogView3Plus yet.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, CogView3PlusTransformer2DModel
+             ), "Transformer must be an instance of CogView3PlusTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # CogView3Plus and CogVideoX share the same attention processor
+         CogVideoXAttnProcessor2_0.__call__ = (
+             __patch_CogVideoXAttnProcessor2_0__call__
+         )
+         # Also need to patch the parallel config and attention backend.
+         if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+             CogVideoXAttnProcessor2_0._parallel_config = None
+         if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+             CogVideoXAttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here; this may be
+         # slightly different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local head, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # Pattern of the rest of the transformer_blocks, split_output=False:
+             #   split input (previous split output) -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local head, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # The `encoder_hidden_states` will change after each block forward,
+             # so we need to split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest of the blocks.
+             # The `out` tensor of the local attn will be split into `hidden_states` and
+             # `encoder_hidden_states` after each block forward, thus both of them
+             # will be automatically split by the all2all comm op after the local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # The hidden_states remain split through the whole transformer
+             # forward while using CP; the final proj_out then gathers the split output.
+             #   split input (previous split output)
+             #                  -> all gather
+             #                  -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @ContextParallelismPlannerRegister.register("CogView4")
+ class CogView4ContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The Diffusers native CP plan is not supported
+         # for CogView4 yet.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, CogView4Transformer2DModel
+             ), "Transformer must be an instance of CogView4Transformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         CogView4AttnProcessor.__call__ = __patch_CogView4AttnProcessor__call__
+         # Also need to patch the parallel config and attention backend.
+         if not hasattr(CogView4AttnProcessor, "_parallel_config"):
+             CogView4AttnProcessor._parallel_config = None
+         if not hasattr(CogView4AttnProcessor, "_attention_backend"):
+             CogView4AttnProcessor._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here; this may be
+         # slightly different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local head, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # Pattern of the rest of the transformer_blocks, split_output=False:
+             #   split input (previous split output) -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local head, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # The `encoder_hidden_states` will change after each block forward,
+             # so we need to split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest of the blocks.
+             # The `out` tensor of the local attn will be split into `hidden_states` and
+             # `encoder_hidden_states` after each block forward, thus both of them
+             # will be automatically split by the all2all comm op after the local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Pattern of the image_rotary_emb: split at every block, because it
+             # is not automatically split by the all2all comm op and stays un-split
+             # after the block forward has finished:
+             #   un-split input -> split output
+             #                  -> after block forward
+             #                  -> un-split input
+             #   un-split input -> split output
+             #   ...
+             "transformer_blocks.*": {
+                 "image_rotary_emb": [
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                 ],
+             },
+             # The hidden_states remain split through the whole transformer
+             # forward while using CP; the final proj_out then gathers the split output.
+             #   split input (previous split output)
+             #                  -> all gather
+             #                  -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @functools.wraps(CogView4AttnProcessor.__call__)
+ def __patch_CogView4AttnProcessor__call__(
+     self: CogView4AttnProcessor,
+     attn: Attention,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     dtype = encoder_hidden_states.dtype
+
+     batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+     batch_size, image_seq_length, embed_dim = hidden_states.shape
+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+     # 1. QKV projections
+     query = attn.to_q(hidden_states)
+     key = attn.to_k(hidden_states)
+     value = attn.to_v(hidden_states)
+
+     # NOTE(DefTruth): no transpose
+     query = query.unflatten(2, (attn.heads, -1))
+     key = key.unflatten(2, (attn.heads, -1))
+     value = value.unflatten(2, (attn.heads, -1))
+
+     # 2. QK normalization
+     if attn.norm_q is not None:
+         query = attn.norm_q(query).to(dtype=dtype)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key).to(dtype=dtype)
+
+     # 3. Rotational positional embeddings applied to latent stream
+     if image_rotary_emb is not None:
+         from diffusers.models.embeddings import apply_rotary_emb
+
+         query[:, text_seq_length:] = apply_rotary_emb(
+             query[:, text_seq_length:],
+             image_rotary_emb,
+             use_real_unbind_dim=-2,
+             sequence_dim=1,
+         )
+         key[:, text_seq_length:] = apply_rotary_emb(
+             key[:, text_seq_length:],
+             image_rotary_emb,
+             use_real_unbind_dim=-2,
+             sequence_dim=1,
+         )
+
+     # 4. Attention
+     if attention_mask is not None:
+         text_attn_mask = attention_mask
+         assert (
+             text_attn_mask.dim() == 2
+         ), "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+         text_attn_mask = text_attn_mask.float().to(query.device)
+         mix_attn_mask = torch.ones(
+             (batch_size, text_seq_length + image_seq_length),
+             device=query.device,
+         )
+         mix_attn_mask[:, :text_seq_length] = text_attn_mask  # [B, seq_len]
+         # TODO(DefTruth): Permute mix_attn_mask if context parallelism is used.
+         # For example, if world size = 2: [E, H] -> [E_0, H_0, E_1, H_1]
+         mix_attn_mask = mix_attn_mask.unsqueeze(2)  # [B, seq_len, 1]
+         attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(
+             1, 2
+         )  # [B, seq_len, seq_len]
+         attention_mask = (
+             (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+         )  # [B, 1, seq_len, seq_len]
+         if (
+             hasattr(self, "_parallel_config")
+             and self._parallel_config is not None
+         ):
+             raise NotImplementedError(
+                 "Attention mask with context parallelism for CogView4 "
+                 "is not implemented yet."
+             )
+
+     # NOTE(DefTruth): Apply dispatch_attention_fn instead of sdpa directly
+     hidden_states = dispatch_attention_fn(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         dropout_p=0.0,
+         is_causal=False,
+         backend=getattr(self, "_attention_backend", None),
+         parallel_config=getattr(self, "_parallel_config", None),
+     )
+     hidden_states = hidden_states.flatten(2, 3)
+     hidden_states = hidden_states.type_as(query)
+
+     # 5. Output projection
+     hidden_states = attn.to_out[0](hidden_states)
+     hidden_states = attn.to_out[1](hidden_states)
+
+     encoder_hidden_states, hidden_states = hidden_states.split(
+         [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+     )
+     return hidden_states, encoder_hidden_states
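Note: the patched CogView4 processor above builds a dense [B, 1, S, S] attention mask from a per-token [B, text_seq] mask via an outer product. A tiny standalone sketch of that pattern (shapes and values are illustrative):

# Standalone illustration (not from cache-dit): turn a per-token validity mask
# into a pairwise [B, 1, S, S] attention mask via an outer product, mirroring
# the mix_attn_mask logic in the patched CogView4 processor.
import torch

text_mask = torch.tensor([[1.0, 1.0, 0.0]])      # [B=1, text_seq=3], last token is padding
image_mask = torch.ones(1, 2)                    # image tokens are always valid
mix = torch.cat([text_mask, image_mask], dim=1)  # [1, 5]
mix = mix.unsqueeze(2)                           # [1, 5, 1]
pairwise = mix @ mix.transpose(1, 2)             # [1, 5, 5]
attn_mask = (pairwise > 0).unsqueeze(1)          # [1, 1, 5, 5], True where both tokens are valid
print(attn_mask.shape)                           # torch.Size([1, 1, 5, 5])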
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py (new file)
@@ -0,0 +1,123 @@
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.consisid_transformer_3d import (
+     ConsisIDTransformer3DModel,
+ )
+ from diffusers.models.transformers.cogvideox_transformer_3d import (
+     CogVideoXAttnProcessor2_0,
+ )
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+ from .cp_plan_cogvideox import __patch_CogVideoXAttnProcessor2_0__call__
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("ConsisID")
+ class CosisIDContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The Diffusers native CP plan is not supported
+         # for ConsisID yet.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, ConsisIDTransformer3DModel
+             ), "Transformer must be an instance of ConsisIDTransformer3DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # ConsisID uses the same attention processor as CogVideoX.
+         CogVideoXAttnProcessor2_0.__call__ = (
+             __patch_CogVideoXAttnProcessor2_0__call__
+         )
+         # Also need to patch the parallel config and attention backend.
+         if not hasattr(CogVideoXAttnProcessor2_0, "_parallel_config"):
+             CogVideoXAttnProcessor2_0._parallel_config = None
+         if not hasattr(CogVideoXAttnProcessor2_0, "_attention_backend"):
+             CogVideoXAttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here; this may be
+         # slightly different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Pattern of transformer_blocks.0, split_output=False:
+             #   un-split input -> split -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local head, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # Pattern of the rest of the transformer_blocks, split_output=False:
+             #   split input (previous split output) -> to_qkv/...
+             #                  -> all2all
+             #                  -> attn (local head, full seqlen)
+             #                  -> all2all
+             #                  -> split output
+             # The `encoder_hidden_states` will change after each block forward,
+             # so we need to split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest of the blocks.
+             # The `out` tensor of the local attn will be split into `hidden_states` and
+             # `encoder_hidden_states` after each block forward, thus both of them
+             # will be automatically split by the all2all comm op after the local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # Pattern of the image_rotary_emb: split at every block, because it
+             # is not automatically split by the all2all comm op and stays un-split
+             # after the block forward has finished:
+             #   un-split input -> split output
+             #                  -> after block forward
+             #                  -> un-split input
+             #   un-split input -> split output
+             #   ...
+             "transformer_blocks.*": {
+                 "image_rotary_emb": [
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                     ContextParallelInput(
+                         split_dim=0, expected_dims=2, split_output=False
+                     ),
+                 ],
+             },
+             # NOTE: We should gather both hidden_states and encoder_hidden_states
+             # at the end of the last block, because the subsequent op is:
+             #   hidden_states = torch.cat([encoder_hidden_states, hidden_states])
+             f"transformer_blocks.{len(transformer.transformer_blocks) - 1}": [
+                 ContextParallelOutput(gather_dim=1, expected_dims=3),
+                 ContextParallelOutput(gather_dim=1, expected_dims=3),
+             ],
+         }
+         return _cp_plan