cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
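Most of the churn above is the package rename from `cache_dit.cache_factory` to `cache_dit.caching` (and `custom_ops` to `kernels`). Below is a minimal, hypothetical compatibility shim for downstream code that imported the old module paths; the submodule names are taken from the file list above, but whether such a fallback is needed at all depends on which symbols your code imports.

# Hypothetical shim: prefer the new `cache_dit.caching` layout (>= 1.0.14)
# and fall back to the old `cache_dit.cache_factory` layout (<= 1.0.3).
try:
    from cache_dit.caching import block_adapters, forward_pattern
except ImportError:  # older cache-dit releases
    from cache_dit.cache_factory import block_adapters, forward_pattern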
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py
@@ -0,0 +1,729 @@
+ import torch
+ import functools
+ from typing import Optional, Dict, Any, Union, Tuple
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.transformers.transformer_hunyuan_video import (
+     HunyuanVideoTransformer3DModel,
+     HunyuanVideoAttnProcessor2_0,
+ )
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+ try:
+     from diffusers import HunyuanImageTransformer2DModel
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires diffusers>=0.36.dev0. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @ContextParallelismPlannerRegister.register("HunyuanImage")
+ class HunyuanImageContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The native diffusers CP plan is not yet supported
+         # for HunyuanImage.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, HunyuanImageTransformer2DModel
+             ), "Transformer must be an instance of HunyuanImageTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Apply monkey patch to fix attention mask preparation while using CP
+         assert isinstance(transformer, HunyuanImageTransformer2DModel)
+         HunyuanImageTransformer2DModel.forward = (
+             __patch__HunyuanImageTransformer2DModel_forward__
+         )
+
+         # Otherwise, use the custom CP plan defined here; it may be
+         # a little different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Pattern of rope, split_output=True (split output rather than input):
+             # un-split input
+             # -> keep input un-split
+             # -> rope
+             # -> split output
+             "rope": {
+                 0: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+             },
+             # Pattern of transformer_blocks.0, split_output=False:
+             # un-split input -> split -> to_qkv/...
+             # -> all2all
+             # -> attn (local head, full seqlen)
+             # -> all2all
+             # -> split output
+             # Pattern of the rest transformer_blocks, single_transformer_blocks:
+             # split input (previous split output) -> to_qkv/...
+             # -> all2all
+             # -> attn (local head, full seqlen)
+             # -> all2all
+             # -> split output
+             # The `encoder_hidden_states` will be changed after each block forward,
+             # so we need to split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest blocks.
+             # The `out` tensor of local attn will be split into `hidden_states` and
+             # `encoder_hidden_states` after each block forward, thus both of them
+             # will be automatically split by the all2all comm op after local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # NOTE: We have to handle the `attention_mask` carefully in the monkey-patched
+             # transformer forward while using CP, since it is not split here.
+             # Then, the final proj_out will gather the split output:
+             # split input (previous split output)
+             # -> all gather
+             # -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hunyuanimage.py#L806
+ @functools.wraps(HunyuanImageTransformer2DModel.forward)
+ def __patch__HunyuanImageTransformer2DModel_forward__(
+     self: HunyuanImageTransformer2DModel,
+     hidden_states: torch.Tensor,
+     timestep: torch.LongTensor,
+     encoder_hidden_states: torch.Tensor,
+     encoder_attention_mask: torch.Tensor,
+     timestep_r: Optional[torch.LongTensor] = None,
+     encoder_hidden_states_2: Optional[torch.Tensor] = None,
+     encoder_attention_mask_2: Optional[torch.Tensor] = None,
+     guidance: Optional[torch.Tensor] = None,
+     attention_kwargs: Optional[Dict[str, Any]] = None,
+     return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+     if attention_kwargs is not None:
+         attention_kwargs = attention_kwargs.copy()
+         lora_scale = attention_kwargs.pop("scale", 1.0)
+     else:
+         lora_scale = 1.0
+
+     if USE_PEFT_BACKEND:
+         # weight the lora layers by setting `lora_scale` for each PEFT layer
+         scale_lora_layers(self, lora_scale)
+     else:
+         if (
+             attention_kwargs is not None
+             and attention_kwargs.get("scale", None) is not None
+         ):
+             logger.warning(
+                 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+             )
+
+     if hidden_states.ndim == 4:
+         batch_size, channels, height, width = hidden_states.shape
+         sizes = (height, width)
+     elif hidden_states.ndim == 5:
+         batch_size, channels, frame, height, width = hidden_states.shape
+         sizes = (frame, height, width)
+     else:
+         raise ValueError(
+             f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}"
+         )
+
+     post_patch_sizes = tuple(
+         d // p for d, p in zip(sizes, self.config.patch_size)
+     )
+
+     # 1. RoPE
+     image_rotary_emb = self.rope(hidden_states)
+
+     # 2. Conditional embeddings
+     encoder_attention_mask = encoder_attention_mask.bool()
+     temb = self.time_guidance_embed(
+         timestep, guidance=guidance, timestep_r=timestep_r
+     )
+     hidden_states = self.x_embedder(hidden_states)
+     encoder_hidden_states = self.context_embedder(
+         encoder_hidden_states, timestep, encoder_attention_mask
+     )
+
+     if (
+         self.context_embedder_2 is not None
+         and encoder_hidden_states_2 is not None
+     ):
+         encoder_hidden_states_2 = self.context_embedder_2(
+             encoder_hidden_states_2
+         )
+
+         encoder_attention_mask_2 = encoder_attention_mask_2.bool()
+
+         # reorder and combine text tokens: combine valid tokens first, then padding
+         new_encoder_hidden_states = []
+         new_encoder_attention_mask = []
+
+         for text, text_mask, text_2, text_mask_2 in zip(
+             encoder_hidden_states,
+             encoder_attention_mask,
+             encoder_hidden_states_2,
+             encoder_attention_mask_2,
+         ):
+             # Concatenate: [valid_byt5, valid_mllm, invalid_byt5, invalid_mllm]
+             new_encoder_hidden_states.append(
+                 torch.cat(
+                     [
+                         text_2[text_mask_2],  # valid byt5
+                         text[text_mask],  # valid mllm
+                         text_2[~text_mask_2],  # invalid byt5
+                         text[~text_mask],  # invalid mllm
+                     ],
+                     dim=0,
+                 )
+             )
+
+             # Apply same reordering to attention masks
+             new_encoder_attention_mask.append(
+                 torch.cat(
+                     [
+                         text_mask_2[text_mask_2],
+                         text_mask[text_mask],
+                         text_mask_2[~text_mask_2],
+                         text_mask[~text_mask],
+                     ],
+                     dim=0,
+                 )
+             )
+
+         encoder_hidden_states = torch.stack(new_encoder_hidden_states)
+         encoder_attention_mask = torch.stack(new_encoder_attention_mask)
+
+     attention_mask = torch.nn.functional.pad(
+         encoder_attention_mask, (hidden_states.shape[1], 0), value=True
+     )
+     # NOTE(DefTruth): Permute attention_mask if context parallel is used.
+     # For example, if world size = 2: [H, E] -> [H_0, E_0, H_1, E_1]
+     if self._parallel_config is not None:
+         cp_config = getattr(
+             self._parallel_config, "context_parallel_config", None
+         )
+         if cp_config is not None and cp_config._world_size > 1:
+             hidden_mask = attention_mask[:, : hidden_states.shape[1]]
+             encoder_mask = attention_mask[:, hidden_states.shape[1] :]
+             hidden_mask_splits = torch.chunk(
+                 hidden_mask, cp_config._world_size, dim=1
+             )
+             encoder_mask_splits = torch.chunk(
+                 encoder_mask, cp_config._world_size, dim=1
+             )
+             new_attention_mask_splits = []
+             for i in range(cp_config._world_size):
+                 new_attention_mask_splits.append(hidden_mask_splits[i])
+                 new_attention_mask_splits.append(encoder_mask_splits[i])
+             attention_mask = torch.cat(new_attention_mask_splits, dim=1)
+
+     attention_mask = attention_mask.unsqueeze(1).unsqueeze(
+         2
+     )  # [B, N] -> [B, 1, 1, N]
+
+     # 3. Transformer blocks
+     if torch.is_grad_enabled() and self.gradient_checkpointing:
+         for block in self.transformer_blocks:
+             hidden_states, encoder_hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     attention_mask=attention_mask,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+             )
+
+         for block in self.single_transformer_blocks:
+             hidden_states, encoder_hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     attention_mask=attention_mask,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+             )
+
+     else:
+         for block in self.transformer_blocks:
+             hidden_states, encoder_hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 temb,
+                 attention_mask=attention_mask,
+                 image_rotary_emb=image_rotary_emb,
+             )
+
+         for block in self.single_transformer_blocks:
+             hidden_states, encoder_hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 temb,
+                 attention_mask=attention_mask,
+                 image_rotary_emb=image_rotary_emb,
+             )
+
+     # 4. Output projection
+     hidden_states = self.norm_out(hidden_states, temb)
+     hidden_states = self.proj_out(hidden_states)
+
+     # 5. unpatchify
+     # reshape: [batch_size, *post_patch_dims, channels, *patch_size]
+     out_channels = self.config.out_channels
+     reshape_dims = (
+         [batch_size]
+         + list(post_patch_sizes)
+         + [out_channels]
+         + list(self.config.patch_size)
+     )
+     hidden_states = hidden_states.reshape(*reshape_dims)
+
+     # create permutation pattern: batch, channels, then interleave post_patch and patch dims
+     # For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
+     # For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
+     ndim = len(post_patch_sizes)
+     permute_pattern = [0, ndim + 1]  # batch, channels
+     for i in range(ndim):
+         permute_pattern.extend(
+             [i + 1, ndim + 2 + i]
+         )  # post_patch_sizes[i], patch_sizes[i]
+     hidden_states = hidden_states.permute(*permute_pattern)
+
+     # flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
+     # batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
+     final_dims = [batch_size, out_channels] + [
+         post_patch * patch
+         for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
+     ]
+     hidden_states = hidden_states.reshape(*final_dims)
+
+     if USE_PEFT_BACKEND:
+         # remove `lora_scale` from each PEFT layer
+         unscale_lora_layers(self, lora_scale)
+
+     if not return_dict:
+         return (hidden_states,)
+
+     return Transformer2DModelOutput(sample=hidden_states)
+
+
+ @ContextParallelismPlannerRegister.register("HunyuanVideo")
+ class HunyuanVideoContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+
+         # NOTE: The native diffusers CP plan is not yet supported
+         # for HunyuanVideo.
+         self._cp_planner_preferred_native_diffusers = False
+
+         if (
+             transformer is not None
+             and self._cp_planner_preferred_native_diffusers
+         ):
+             assert isinstance(
+                 transformer, HunyuanVideoTransformer3DModel
+             ), "Transformer must be an instance of HunyuanVideoTransformer3DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 if transformer._cp_plan is not None:
+                     return transformer._cp_plan
+
+         # Apply monkey patch to fix attention mask preparation while using CP
+         assert isinstance(transformer, HunyuanVideoTransformer3DModel)
+         HunyuanVideoTransformer3DModel.forward = (
+             __patch__HunyuanVideoTransformer3DModel_forward__
+         )
+         HunyuanVideoAttnProcessor2_0.__call__ = (
+             __patch_HunyuanVideoAttnProcessor2_0__call__
+         )
+         # Also need to patch the parallel config and attention backend
+         if not hasattr(HunyuanVideoAttnProcessor2_0, "_parallel_config"):
+             HunyuanVideoAttnProcessor2_0._parallel_config = None
+         if not hasattr(HunyuanVideoAttnProcessor2_0, "_attention_backend"):
+             HunyuanVideoAttnProcessor2_0._attention_backend = None
+
+         # Otherwise, use the custom CP plan defined here; it may be
+         # a little different from the native diffusers implementation
+         # for some models.
+         _cp_plan = {
+             # Pattern of rope, split_output=True (split output rather than input):
+             # un-split input
+             # -> keep input un-split
+             # -> rope
+             # -> split output
+             "rope": {
+                 0: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+             },
+             # Pattern of transformer_blocks.0, split_output=False:
+             # un-split input -> split -> to_qkv/...
+             # -> all2all
+             # -> attn (local head, full seqlen)
+             # -> all2all
+             # -> split output
+             # Pattern of the rest transformer_blocks, single_transformer_blocks:
+             # split input (previous split output) -> to_qkv/...
+             # -> all2all
+             # -> attn (local head, full seqlen)
+             # -> all2all
+             # -> split output
+             # The `encoder_hidden_states` will be changed after each block forward,
+             # so we need to split it at the first block and keep it split (namely,
+             # automatically split by the all2all op after attn) for the rest blocks.
+             # The `out` tensor of local attn will be split into `hidden_states` and
+             # `encoder_hidden_states` after each block forward, thus both of them
+             # will be automatically split by the all2all comm op after local attn.
+             "transformer_blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             # NOTE: We have to handle the `attention_mask` carefully in the monkey-patched
+             # transformer forward while using CP, since it is not split here.
+             # Then, the final proj_out will gather the split output:
+             # split input (previous split output)
+             # -> all gather
+             # -> un-split output
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hunyuan_video.py#L1032
+ @functools.wraps(HunyuanVideoTransformer3DModel.forward)
+ def __patch__HunyuanVideoTransformer3DModel_forward__(
+     self: HunyuanVideoTransformer3DModel,
+     hidden_states: torch.Tensor,
+     timestep: torch.LongTensor,
+     encoder_hidden_states: torch.Tensor,
+     encoder_attention_mask: torch.Tensor,
+     pooled_projections: torch.Tensor,
+     guidance: torch.Tensor = None,
+     attention_kwargs: Optional[Dict[str, Any]] = None,
+     return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+     if attention_kwargs is not None:
+         attention_kwargs = attention_kwargs.copy()
+         lora_scale = attention_kwargs.pop("scale", 1.0)
+     else:
+         lora_scale = 1.0
+
+     if USE_PEFT_BACKEND:
+         # weight the lora layers by setting `lora_scale` for each PEFT layer
+         scale_lora_layers(self, lora_scale)
+     else:
+         if (
+             attention_kwargs is not None
+             and attention_kwargs.get("scale", None) is not None
+         ):
+             logger.warning(
+                 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+             )
+
+     batch_size, num_channels, num_frames, height, width = hidden_states.shape
+     p, p_t = self.config.patch_size, self.config.patch_size_t
+     post_patch_num_frames = num_frames // p_t
+     post_patch_height = height // p
+     post_patch_width = width // p
+     first_frame_num_tokens = 1 * post_patch_height * post_patch_width
+
+     # 1. RoPE
+     image_rotary_emb = self.rope(hidden_states)
+
+     # 2. Conditional embeddings
+     temb, token_replace_emb = self.time_text_embed(
+         timestep, pooled_projections, guidance
+     )
+
+     hidden_states = self.x_embedder(hidden_states)
+     encoder_hidden_states = self.context_embedder(
+         encoder_hidden_states, timestep, encoder_attention_mask
+     )
+
+     # 3. Attention mask preparation
+     latent_sequence_length = hidden_states.shape[1]
+     condition_sequence_length = encoder_hidden_states.shape[1]
+     sequence_length = latent_sequence_length + condition_sequence_length
+     attention_mask = torch.ones(
+         batch_size,
+         sequence_length,
+         device=hidden_states.device,
+         dtype=torch.bool,
+     )  # [B, N]
+     effective_condition_sequence_length = encoder_attention_mask.sum(
+         dim=1, dtype=torch.int
+     )  # [B,]
+     effective_sequence_length = (
+         latent_sequence_length + effective_condition_sequence_length
+     )
+     indices = torch.arange(
+         sequence_length, device=hidden_states.device
+     ).unsqueeze(
+         0
+     )  # [1, N]
+     mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
+     attention_mask = attention_mask.masked_fill(mask_indices, False)
+     # NOTE(DefTruth): Permute attention_mask if context parallel is used.
+     # For example, if world size = 2: [H, E] -> [H_0, E_0, H_1, E_1]
+     if self._parallel_config is not None:
+         cp_config = getattr(
+             self._parallel_config, "context_parallel_config", None
+         )
+         if cp_config is not None and cp_config._world_size > 1:
+             hidden_mask = attention_mask[:, :latent_sequence_length]
+             encoder_mask = attention_mask[:, latent_sequence_length:]
+             hidden_mask_splits = torch.chunk(
+                 hidden_mask, cp_config._world_size, dim=1
+             )
+             encoder_mask_splits = torch.chunk(
+                 encoder_mask, cp_config._world_size, dim=1
+             )
+             new_attention_mask_splits = []
+             for i in range(cp_config._world_size):
+                 new_attention_mask_splits.append(hidden_mask_splits[i])
+                 new_attention_mask_splits.append(encoder_mask_splits[i])
+             attention_mask = torch.cat(new_attention_mask_splits, dim=1)
+
+     attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, N]
+
+     # 4. Transformer blocks
+     if torch.is_grad_enabled() and self.gradient_checkpointing:
+         for block in self.transformer_blocks:
+             hidden_states, encoder_hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     attention_mask,
+                     image_rotary_emb,
+                     token_replace_emb,
+                     first_frame_num_tokens,
+                 )
+             )
+
+         for block in self.single_transformer_blocks:
+             hidden_states, encoder_hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     attention_mask,
+                     image_rotary_emb,
+                     token_replace_emb,
+                     first_frame_num_tokens,
+                 )
+             )
+
+     else:
+         for block in self.transformer_blocks:
+             hidden_states, encoder_hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 temb,
+                 attention_mask,
+                 image_rotary_emb,
+                 token_replace_emb,
+                 first_frame_num_tokens,
+             )
+
+         for block in self.single_transformer_blocks:
+             hidden_states, encoder_hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 temb,
+                 attention_mask,
+                 image_rotary_emb,
+                 token_replace_emb,
+                 first_frame_num_tokens,
+             )
+
+     # 5. Output projection
+     hidden_states = self.norm_out(hidden_states, temb)
+     hidden_states = self.proj_out(hidden_states)
+
+     hidden_states = hidden_states.reshape(
+         batch_size,
+         post_patch_num_frames,
+         post_patch_height,
+         post_patch_width,
+         -1,
+         p_t,
+         p,
+         p,
+     )
+     hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+     hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+     if USE_PEFT_BACKEND:
+         # remove `lora_scale` from each PEFT layer
+         unscale_lora_layers(self, lora_scale)
+
+     if not return_dict:
+         return (hidden_states,)
+
+     return Transformer2DModelOutput(sample=hidden_states)
+
+
+ @functools.wraps(HunyuanVideoAttnProcessor2_0.__call__)
+ def __patch_HunyuanVideoAttnProcessor2_0__call__(
+     self: HunyuanVideoAttnProcessor2_0,
+     attn: Attention,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: Optional[torch.Tensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     if attn.add_q_proj is None and encoder_hidden_states is not None:
+         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+     # 1. QKV projections
+     query = attn.to_q(hidden_states)
+     key = attn.to_k(hidden_states)
+     value = attn.to_v(hidden_states)
+
+     # NOTE(DefTruth): no transpose
+     query = query.unflatten(2, (attn.heads, -1))
+     key = key.unflatten(2, (attn.heads, -1))
+     value = value.unflatten(2, (attn.heads, -1))
+
+     # 2. QK normalization
+     if attn.norm_q is not None:
+         query = attn.norm_q(query)
+     if attn.norm_k is not None:
+         key = attn.norm_k(key)
+
+     # 3. Rotational positional embeddings applied to latent stream
+     if image_rotary_emb is not None:
+         from diffusers.models.embeddings import apply_rotary_emb
+
+         # NOTE(DefTruth): Monkey patch for encoder conditional RoPE
+         if attn.add_q_proj is None and encoder_hidden_states is not None:
+             query = torch.cat(
+                 [
+                     apply_rotary_emb(
+                         query[:, : -encoder_hidden_states.shape[1]],
+                         image_rotary_emb,
+                         sequence_dim=1,
+                     ),
+                     query[:, -encoder_hidden_states.shape[1] :],
+                 ],
+                 dim=1,
+             )
+             key = torch.cat(
+                 [
+                     apply_rotary_emb(
+                         key[:, : -encoder_hidden_states.shape[1]],
+                         image_rotary_emb,
+                         sequence_dim=1,
+                     ),
+                     key[:, -encoder_hidden_states.shape[1] :],
+                 ],
+                 dim=1,
+             )
+         else:
+             query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+             key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+     # 4. Encoder condition QKV projection and normalization
+     if attn.add_q_proj is not None and encoder_hidden_states is not None:
+         encoder_query = attn.add_q_proj(encoder_hidden_states)
+         encoder_key = attn.add_k_proj(encoder_hidden_states)
+         encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+         # NOTE(DefTruth): no transpose
+         encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+         encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+         encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+         if attn.norm_added_q is not None:
+             encoder_query = attn.norm_added_q(encoder_query)
+         if attn.norm_added_k is not None:
+             encoder_key = attn.norm_added_k(encoder_key)
+
+         query = torch.cat([query, encoder_query], dim=1)
+         key = torch.cat([key, encoder_key], dim=1)
+         value = torch.cat([value, encoder_value], dim=1)
+
+     # 5. Attention
+     # NOTE(DefTruth): use dispatch_attention_fn
+     hidden_states = dispatch_attention_fn(
+         query,
+         key,
+         value,
+         attn_mask=attention_mask,
+         dropout_p=0.0,
+         is_causal=False,
+         backend=getattr(self, "_attention_backend", None),
+         parallel_config=getattr(self, "_parallel_config", None),
+     )
+     # NOTE(DefTruth): no transpose
+     hidden_states = hidden_states.flatten(2, 3)
+     hidden_states = hidden_states.to(query.dtype)
+
+     # 6. Output projection
+     if encoder_hidden_states is not None:
+         hidden_states, encoder_hidden_states = (
+             hidden_states[:, : -encoder_hidden_states.shape[1]],
+             hidden_states[:, -encoder_hidden_states.shape[1] :],
+         )
+
+         if getattr(attn, "to_out", None) is not None:
+             hidden_states = attn.to_out[0](hidden_states)
+             hidden_states = attn.to_out[1](hidden_states)
+
+         if getattr(attn, "to_add_out", None) is not None:
+             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+     return hidden_states, encoder_hidden_states
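
Both monkey-patched forwards above rely on the same trick: the joint [latent, text] attention mask is built before context-parallel sharding, so it has to be re-ordered into [H_0, E_0, H_1, E_1, ...] so that each rank's mask slice lines up with the hidden and encoder tokens it actually holds after the all2all. The following standalone sketch (plain PyTorch, independent of cache-dit; the helper name `interleave_joint_mask` is made up for illustration) reproduces that chunk-and-interleave step on a toy mask.

import torch

def interleave_joint_mask(
    attention_mask: torch.Tensor,  # [B, H + E] bool mask: latent part, then text part
    latent_len: int,
    world_size: int,
) -> torch.Tensor:
    """Reorder [H, E] -> [H_0, E_0, H_1, E_1, ...] along the sequence dim."""
    hidden_mask = attention_mask[:, :latent_len]
    encoder_mask = attention_mask[:, latent_len:]
    hidden_splits = torch.chunk(hidden_mask, world_size, dim=1)
    encoder_splits = torch.chunk(encoder_mask, world_size, dim=1)
    pieces = []
    for h, e in zip(hidden_splits, encoder_splits):
        pieces.extend([h, e])
    return torch.cat(pieces, dim=1)

# Toy check: 1 sample, 4 latent tokens (all valid), 4 text tokens (1 valid).
mask = torch.tensor([[True, True, True, True, True, False, False, False]])
print(interleave_joint_mask(mask, latent_len=4, world_size=2))
# tensor([[ True,  True,  True, False,  True,  True, False, False]])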