cache-dit 1.0.3-py3-none-any.whl → 1.0.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
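
The headline structural change in this release is the rename of the `cache_dit.cache_factory` package to `cache_dit.caching`, alongside new `parallelism` (context/tensor parallel backends), `quantize/backends`, and `kernels` packages. A minimal sketch of what the rename implies for downstream imports, based only on the moves listed above (the exact re-exported symbols are not shown in this diff):

# cache-dit 1.0.3 layout (old path):
from cache_dit.cache_factory import cache_interface

# cache-dit 1.0.14 layout (new path):
from cache_dit.caching import cache_interface

The hunks reproduced below add context-parallelism plans for LTXVideo and for the Nunchaku 4-bit Flux/QwenImage transformers.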
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py (new file)
@@ -0,0 +1,264 @@
+ import torch
+ import functools
+ from typing import Optional
+ import torch.nn.functional as F
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.transformers.transformer_ltx import (
+     LTXVideoTransformer3DModel,
+     LTXAttention,
+     AttentionModuleMixin,
+     LTXVideoAttnProcessor,
+     apply_rotary_emb,
+ )
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("LTXVideo")
+ class LTXVideoContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         assert transformer is not None, "Transformer must be provided."
+         assert isinstance(
+             transformer, LTXVideoTransformer3DModel
+         ), "Transformer must be an instance of LTXVideoTransformer3DModel"
+
+         # NOTE: The attention_mask preparation in LTXAttention is buggy while
+         # using context parallelism in diffusers v0.36.0.dev0, so we disable
+         # the preference to use the native diffusers implementation here.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Apply monkey patches to fix attention mask preparation at class level
+         assert issubclass(LTXAttention, AttentionModuleMixin)
+         LTXAttention.prepare_attention_mask = (
+             __patch__LTXAttention_prepare_attention_mask__
+         )
+         LTXVideoAttnProcessor.__call__ = __patch__LTXVideoAttnProcessor__call__
+
+         # Otherwise, use the custom CP plan defined here, which may differ
+         # slightly from the native diffusers implementation for some models.
+
+         _cp_plan = {
+             # This is a transformer-level CP plan for LTXVideo: it applies a
+             # single split hook (pre_forward) on the transformer's forward and
+             # gathers the output after the transformer forward.
+             # Pattern of transformer forward, split_output=False:
+             #   un-split input -> split input (inside transformer)
+             # Pattern of the transformer_blocks, single_transformer_blocks:
+             #   split input (previous split output) -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local heads, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # The `hidden_states` and `encoder_hidden_states` stay split after
+             # each block forward: hidden_states is re-split automatically by the
+             # all2all comm op after attn, and encoder_hidden_states remains split
+             # after the transformer forward entrypoint, for all blocks.
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 # NOTE: encoder_attention_mask (namely, attention_mask in cross-attn)
+                 # should never be split across seqlen while using context parallelism
+                 # for LTXVideoTransformer3DModel. It doesn't contribute to any
+                 # computation, parallel or not. So we comment it out here and handle
+                 # the head split correctly in the patched attention processor when
+                 # context parallelism is enabled.
+                 # "encoder_attention_mask": ContextParallelInput(
+                 #     split_dim=1, expected_dims=2, split_output=False
+                 # ),
+             },
+             # Pattern of rope, split_output=True (split output rather than input):
+             #   un-split input
+             #   -> keep input un-split
+             #   -> rope
+             #   -> split output
+             "rope": {
+                 0: ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=True
+                 ),
+             },
+             # Then, the final proj_out gathers the split output.
+             #   split input (previous split output)
+             #   -> all gather
+             #   -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @functools.wraps(LTXAttention.prepare_attention_mask)
+ def __patch__LTXAttention_prepare_attention_mask__(
+     self: LTXAttention,
+     attention_mask: torch.Tensor,
+     target_length: int,
+     batch_size: int,
+     out_dim: int = 3,
+     # NOTE(DefTruth): Allow specifying head_size for CP
+     head_size: Optional[int] = None,
+ ) -> torch.Tensor:
+     """
+     Prepare the attention mask for the attention computation.
+
+     Args:
+         attention_mask (`torch.Tensor`): The attention mask to prepare.
+         target_length (`int`): The target length of the attention mask.
+         batch_size (`int`): The batch size for repeating the attention mask.
+         out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+         head_size (`int`, *optional*): Number of heads to repeat the mask by;
+             defaults to `self.heads` when not given (overridden under CP).
+
+     Returns:
+         `torch.Tensor`: The prepared attention mask.
+     """
+     if head_size is None:
+         head_size = self.heads
+     if attention_mask is None:
+         return attention_mask
+
+     current_length: int = attention_mask.shape[-1]
+     if current_length != target_length:
+         if attention_mask.device.type == "mps":
+             # HACK: MPS does not support padding by greater than the dimension
+             # of the input tensor. Instead, we can manually construct the
+             # padding tensor.
+             padding_shape = (
+                 attention_mask.shape[0],
+                 attention_mask.shape[1],
+                 target_length,
+             )
+             padding = torch.zeros(
+                 padding_shape,
+                 dtype=attention_mask.dtype,
+                 device=attention_mask.device,
+             )
+             attention_mask = torch.cat([attention_mask, padding], dim=2)
+         else:
+             # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+             # we want to instead pad by (0, remaining_length), where remaining_length is:
+             # remaining_length: int = target_length - current_length
+             # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+             attention_mask = F.pad(
+                 attention_mask, (0, target_length), value=0.0
+             )
+
+     if out_dim == 3:
+         if attention_mask.shape[0] < batch_size * head_size:
+             attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+     elif out_dim == 4:
+         attention_mask = attention_mask.unsqueeze(1)
+         attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+     return attention_mask
+
+
+ @functools.wraps(LTXVideoAttnProcessor.__call__)
+ def __patch__LTXVideoAttnProcessor__call__(
+     self: LTXVideoAttnProcessor,
+     attn: "LTXAttention",
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: Optional[torch.Tensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     batch_size, sequence_length, _ = (
+         hidden_states.shape
+         if encoder_hidden_states is None
+         else encoder_hidden_states.shape
+     )
+
+     if attention_mask is not None:
+         if self._parallel_config is None:
+             attention_mask = attn.prepare_attention_mask(
+                 attention_mask, sequence_length, batch_size
+             )
+             attention_mask = attention_mask.view(
+                 batch_size, attn.heads, -1, attention_mask.shape[-1]
+             )
+         else:
+             # NOTE(DefTruth): Fix attention mask preparation for context parallelism
+             cp_config = getattr(
+                 self._parallel_config, "context_parallel_config", None
+             )
+             if cp_config is not None and cp_config._world_size > 1:
+                 head_size = attn.heads // cp_config._world_size
+                 attention_mask = attn.prepare_attention_mask(
+                     attention_mask,
+                     sequence_length * cp_config._world_size,
+                     batch_size,
+                     3,
+                     head_size,
+                 )
+                 attention_mask = attention_mask.view(
+                     batch_size, head_size, -1, attention_mask.shape[-1]
+                 )
+
+     if encoder_hidden_states is None:
+         encoder_hidden_states = hidden_states
+
+     query = attn.to_q(hidden_states)
+     key = attn.to_k(encoder_hidden_states)
+     value = attn.to_v(encoder_hidden_states)
+
+     query = attn.norm_q(query)
+     key = attn.norm_k(key)
+
+     if image_rotary_emb is not None:
+         query = apply_rotary_emb(query, image_rotary_emb)
+         key = apply_rotary_emb(key, image_rotary_emb)
+
+     query = query.unflatten(2, (attn.heads, -1))
+     key = key.unflatten(2, (attn.heads, -1))
+     value = value.unflatten(2, (attn.heads, -1))
+
+     hidden_states = dispatch_attention_fn(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         dropout_p=0.0,
+         is_causal=False,
+         backend=self._attention_backend,
+         parallel_config=self._parallel_config,
+     )
+     hidden_states = hidden_states.flatten(2, 3)
+     hidden_states = hidden_states.to(query.dtype)
+
+     hidden_states = attn.to_out[0](hidden_states)
+     hidden_states = attn.to_out[1](hidden_states)
+     return hidden_states
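
A minimal sketch of how the plan produced by `LTXVideoContextParallelismPlanner.apply` above might be attached to a transformer; the registry and dispatch plumbing live in `cp_plan_registers.py` / `cp_planners.py` (not shown in this hunk), and the checkpoint id and loading arguments here are illustrative assumptions:

import torch
from diffusers.models.transformers.transformer_ltx import LTXVideoTransformer3DModel
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_ltxvideo import (
    LTXVideoContextParallelismPlanner,
)

# Assumed checkpoint, for illustration only.
transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)

# apply() monkey-patches LTXAttention / LTXVideoAttnProcessor at class level and
# returns a diffusers ContextParallelModelPlan dict keyed by submodule name.
planner = LTXVideoContextParallelismPlanner()
cp_plan = planner.apply(transformer=transformer)

# Mirror the `_cp_plan` attribute that the planner itself checks, so diffusers'
# context-parallel hooks can pick the plan up.
transformer._cp_plan = cp_plan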
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py (new file)
@@ -0,0 +1,407 @@
+ import torch
+ import functools
+ from typing import Optional, Tuple
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+ from diffusers.models.transformers.transformer_qwenimage import (
+     apply_rotary_emb_qwen,
+ )
+
+ try:
+     from nunchaku.models.transformers.transformer_flux_v2 import (
+         NunchakuFluxAttention,
+         NunchakuFluxFA2Processor,
+         NunchakuFluxTransformer2DModelV2,
+     )
+     from nunchaku.ops.fused import fused_qkv_norm_rottary
+     from nunchaku.models.transformers.transformer_qwenimage import (
+         NunchakuQwenAttention,
+         NunchakuQwenImageNaiveFA2Processor,
+         NunchakuQwenImageTransformer2DModel,
+     )
+ except ImportError:
+     raise ImportError(
+         "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
+         "requires the 'nunchaku' package. Please install nunchaku before using "
+         "context parallelism with nunchaku 4-bit models."
+     )
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("NunchakuFlux")
+ class NunchakuFluxContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+
+             assert isinstance(
+                 transformer, NunchakuFluxTransformer2DModelV2
+             ), "Transformer must be an instance of NunchakuFluxTransformer2DModelV2"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         NunchakuFluxFA2Processor.__call__ = (
+             __patch_NunchakuFluxFA2Processor__call__
+         )
+         # Also need to patch the parallel config and attention backend
+         if not hasattr(NunchakuFluxFA2Processor, "_parallel_config"):
+             NunchakuFluxFA2Processor._parallel_config = None
+         if not hasattr(NunchakuFluxFA2Processor, "_attention_backend"):
+             NunchakuFluxFA2Processor._attention_backend = None
+         if not hasattr(NunchakuFluxAttention, "_parallel_config"):
+             NunchakuFluxAttention._parallel_config = None
+         if not hasattr(NunchakuFluxAttention, "_attention_backend"):
+             NunchakuFluxAttention._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here, which may differ
+         # slightly from the native diffusers implementation for some models.
+         _cp_plan = {
+             # This is a transformer-level CP plan for Flux: it applies a single
+             # split hook (pre_forward) on the transformer's forward and gathers
+             # the output after the transformer forward.
+             # Pattern of transformer forward, split_output=False:
+             #   un-split input -> split input (inside transformer)
+             # Pattern of the transformer_blocks, single_transformer_blocks:
+             #   split input (previous split output) -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local heads, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # The `hidden_states` and `encoder_hidden_states` stay split after
+             # each block forward (namely, re-split automatically by the all2all
+             # comm op after attn) for all blocks.
+             # img_ids and txt_ids are only split once at the very beginning and
+             # remain split through the whole transformer forward. The all2all
+             # comm op only happens on the `out` tensor after local attn, not on
+             # img_ids and txt_ids.
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "img_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+                 "txt_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+             },
+             # Then, the final proj_out gathers the split output.
+             #   split input (previous split output)
+             #   -> all gather
+             #   -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @functools.wraps(NunchakuFluxFA2Processor.__call__)
+ def __patch_NunchakuFluxFA2Processor__call__(
+     self: NunchakuFluxFA2Processor,
+     attn: NunchakuFluxAttention,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: Optional[torch.Tensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
+     **kwargs,
+ ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
+     # The original implementation of NunchakuFluxFA2Processor.__call__ is kept
+     # largely unchanged; only the attention computation is modified below to
+     # support context parallelism.
+     if attention_mask is not None:
+         raise NotImplementedError("attention_mask is not supported")
+
+     batch_size, _, channels = hidden_states.shape
+     assert channels == attn.heads * attn.head_dim
+     qkv = fused_qkv_norm_rottary(
+         hidden_states,
+         attn.to_qkv,
+         attn.norm_q,
+         attn.norm_k,
+         (
+             image_rotary_emb[0]
+             if isinstance(image_rotary_emb, tuple)
+             else image_rotary_emb
+         ),
+     )
+
+     if attn.added_kv_proj_dim is not None:
+         assert encoder_hidden_states is not None
+         assert isinstance(image_rotary_emb, tuple)
+         qkv_context = fused_qkv_norm_rottary(
+             encoder_hidden_states,
+             attn.add_qkv_proj,
+             attn.norm_added_q,
+             attn.norm_added_k,
+             image_rotary_emb[1],
+         )
+         qkv = torch.cat([qkv_context, qkv], dim=1)
+
+     query, key, value = qkv.chunk(3, dim=-1)
+     # Original implementation:
+     # query = query.view(batch_size, -1, attn.heads, attn.head_dim).transpose(
+     #     1, 2
+     # )
+     # key = key.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
+     # value = value.view(batch_size, -1, attn.heads, attn.head_dim).transpose(
+     #     1, 2
+     # )
+     # hidden_states = F.scaled_dot_product_attention(
+     #     query,
+     #     key,
+     #     value,
+     #     attn_mask=attention_mask,
+     #     dropout_p=0.0,
+     #     is_causal=False,
+     # )
+     # hidden_states = hidden_states.transpose(1, 2).reshape(
+     #     batch_size, -1, attn.heads * attn.head_dim
+     # )
+     # hidden_states = hidden_states.to(query.dtype)
+
+     # NOTE(DefTruth): Monkey patch to support context parallelism
+     query = query.view(batch_size, -1, attn.heads, attn.head_dim)
+     key = key.view(batch_size, -1, attn.heads, attn.head_dim)
+     value = value.view(batch_size, -1, attn.heads, attn.head_dim)
+
+     hidden_states = dispatch_attention_fn(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         backend=getattr(self, "_attention_backend", None),
+         parallel_config=getattr(self, "_parallel_config", None),
+     )
+     hidden_states = hidden_states.flatten(2, 3)
+     hidden_states = hidden_states.to(query.dtype)
+
+     if encoder_hidden_states is not None:
+         encoder_hidden_states, hidden_states = (
+             hidden_states[:, : encoder_hidden_states.shape[1]],
+             hidden_states[:, encoder_hidden_states.shape[1] :],
+         )
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+         encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+         return hidden_states, encoder_hidden_states
+     else:
+         # for single transformer block, we split the proj_out into two linear layers
+         hidden_states = attn.to_out(hidden_states)
+         return hidden_states
+
+
+ @ContextParallelismPlannerRegister.register("NunchakuQwenImage")
+ class NunchakuQwenImageContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+
+             assert isinstance(
+                 transformer, NunchakuQwenImageTransformer2DModel
+             ), "Transformer must be an instance of NunchakuQwenImageTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         NunchakuQwenImageNaiveFA2Processor.__call__ = (
+             __patch_NunchakuQwenImageNaiveFA2Processor__call__
+         )
+         # Also need to patch the parallel config and attention backend
+         if not hasattr(NunchakuQwenImageNaiveFA2Processor, "_parallel_config"):
+             NunchakuQwenImageNaiveFA2Processor._parallel_config = None
+         if not hasattr(
+             NunchakuQwenImageNaiveFA2Processor, "_attention_backend"
+         ):
+             NunchakuQwenImageNaiveFA2Processor._attention_backend = None
+         if not hasattr(NunchakuQwenAttention, "_parallel_config"):
+             NunchakuQwenAttention._parallel_config = None
+         if not hasattr(NunchakuQwenAttention, "_attention_backend"):
+             NunchakuQwenAttention._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here, which may differ
+         # slightly from the native diffusers implementation for some models.
+         _cp_plan = {
+             # This is a transformer-level CP plan for QwenImage: it applies a
+             # single split hook (pre_forward) on the transformer's forward and
+             # gathers the output after the transformer forward.
+             # Pattern of transformer forward, split_output=False:
+             #   un-split input -> split input (inside transformer)
+             # Pattern of the transformer_blocks, single_transformer_blocks:
+             #   split input (previous split output) -> to_qkv/...
+             #   -> all2all
+             #   -> attn (local heads, full seqlen)
+             #   -> all2all
+             #   -> split output
+             # The `hidden_states` and `encoder_hidden_states` stay split after
+             # each block forward (namely, re-split automatically by the all2all
+             # comm op after attn) for all blocks.
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 # NOTE: Due to the joint attention implementation of
+                 # QwenImageTransformerBlock, we must split the
+                 # encoder_hidden_states as well.
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 # NOTE: But encoder_hidden_states_mask seems never used in
+                 # QwenImageTransformerBlock, so we do not split it here.
+                 # "encoder_hidden_states_mask": ContextParallelInput(
+                 #     split_dim=1, expected_dims=2, split_output=False
+                 # ),
+             },
+             # Pattern of pos_embed, split_output=True (split output rather than input):
+             #   un-split input
+             #   -> keep input un-split
+             #   -> rope
+             #   -> split output
+             "pos_embed": {
+                 0: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+             },
+             # Then, the final proj_out gathers the split output.
+             #   split input (previous split output)
+             #   -> all gather
+             #   -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @functools.wraps(NunchakuQwenImageNaiveFA2Processor.__call__)
+ def __patch_NunchakuQwenImageNaiveFA2Processor__call__(
+     self,
+     attn,
+     hidden_states: torch.FloatTensor,
+     encoder_hidden_states: torch.FloatTensor = None,
+     encoder_hidden_states_mask: torch.FloatTensor = None,
+     attention_mask: Optional[torch.FloatTensor] = None,
+     image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
+     if encoder_hidden_states is None:
+         raise ValueError(
+             "NunchakuQwenImageFA2Processor requires encoder_hidden_states (text stream)"
+         )
+
+     seq_txt = encoder_hidden_states.shape[1]
+
+     # Compute QKV for image stream (sample projections)
+     img_qkv = attn.to_qkv(hidden_states)
+     img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
+
+     # Compute QKV for text stream (context projections)
+     txt_qkv = attn.add_qkv_proj(encoder_hidden_states)
+     txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
+
+     # Reshape for multi-head attention
+     img_query = img_query.unflatten(-1, (attn.heads, -1))  # [B, L, H, D]
+     img_key = img_key.unflatten(-1, (attn.heads, -1))
+     img_value = img_value.unflatten(-1, (attn.heads, -1))
+
+     txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+     txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+     txt_value = txt_value.unflatten(-1, (attn.heads, -1))
+
+     # Apply QK normalization
+     assert attn.norm_q is not None
+     img_query = attn.norm_q(img_query)
+     assert attn.norm_k is not None
+     img_key = attn.norm_k(img_key)
+     assert attn.norm_added_q is not None
+     txt_query = attn.norm_added_q(txt_query)
+     assert attn.norm_added_k is not None
+     txt_key = attn.norm_added_k(txt_key)
+
+     # Apply rotary embeddings
+     if image_rotary_emb is not None:
+         img_freqs, txt_freqs = image_rotary_emb
+         img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
+         img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
+         txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
+         txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+
+     # Concatenate for joint attention: [text, image]
+     joint_query = torch.cat([txt_query, img_query], dim=1)
+     joint_key = torch.cat([txt_key, img_key], dim=1)
+     joint_value = torch.cat([txt_value, img_value], dim=1)
+
+     # Compute joint attention
+     joint_hidden_states = dispatch_attention_fn(
+         joint_query,
+         joint_key,
+         joint_value,
+         attn_mask=attention_mask,
+         dropout_p=0.0,
+         is_causal=False,
+         # NOTE(DefTruth): Use the patched attention backend and
+         # parallel config to make context parallelism work here.
+         backend=getattr(self, "_attention_backend", None),
+         parallel_config=getattr(self, "_parallel_config", None),
+     )
+
+     # Reshape back
+     joint_hidden_states = joint_hidden_states.flatten(2, 3)
+     joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
+
+     # Split attention outputs back
+     txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
+     img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
+
+     # Apply output projections
+     img_attn_output = attn.to_out[0](img_attn_output)
+     if len(attn.to_out) > 1:
+         img_attn_output = attn.to_out[1](img_attn_output)  # dropout
+
+     txt_attn_output = attn.to_add_out(txt_attn_output)
+
+     return img_attn_output, txt_attn_output
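
For a new model, registering a planner follows the same shape as the two Nunchaku planners above: subclass `ContextParallelismPlanner`, register it under a key, and return a `ContextParallelModelPlan` from `apply`. A minimal sketch under that assumption (the registry key and plan entries below are placeholders, not part of this package):

import torch
from typing import Optional
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelOutput,
    ContextParallelModelPlan,
)
from .cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)


@ContextParallelismPlannerRegister.register("MyDiT")  # placeholder key
class MyDiTContextParallelismPlanner(ContextParallelismPlanner):
    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        # Split the sequence dim of the transformer input once at the boundary
        # and gather the final projection output, mirroring the plans above.
        return {
            "": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }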