cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff shows the changes between publicly available package versions as published to their respective public registries. It is provided for informational purposes only.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
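Most of the moves above are a package rename: cache_factory becomes caching (and custom_ops becomes kernels), with the module layout otherwise preserved. As a hedged illustration of what this implies for downstream code, the new-style imports below are taken verbatim from the pattern_3_4_5.py diff that follows; the commented old-style paths are an assumption about the 0.3.2 layout, not something shown in this diff.

```python
# New 1.0.14 import paths (as used inside the diff below).
from cache_dit.caching import ForwardPattern
from cache_dit.caching.cache_types import CacheType

# Assumed 0.3.2 equivalents under the old package name (hypothetical):
# from cache_dit.cache_factory import ForwardPattern
# from cache_dit.cache_factory.cache_types import CacheType
```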
cache_dit/caching/cache_blocks/pattern_3_4_5.py (new file, +543 -0)
@@ -0,0 +1,543 @@
+import torch
+
+from cache_dit.caching import ForwardPattern
+from cache_dit.caching.cache_contexts.cache_manager import (
+    ContextNotExistError,
+)
+from cache_dit.caching.cache_blocks.pattern_base import (
+    CachedBlocks_Pattern_Base,
+)
+from cache_dit.caching.cache_contexts.prune_context import PrunedContext
+from cache_dit.caching.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.caching.cache_types import CacheType
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+
+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        new_encoder_hidden_states = None
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_block_outputs(
+        self, hidden_states: torch.Tensor | tuple
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # Process the outputs for the block.
+        new_encoder_hidden_states = None
+        if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+            if len(hidden_states) == 2:
+                if isinstance(hidden_states[1], torch.Tensor):
+                    hidden_states, new_encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, new_encoder_hidden_states = (
+                            new_encoder_hidden_states,
+                            hidden_states,
+                        )
+                elif isinstance(hidden_states[0], torch.Tensor):
+                    hidden_states = hidden_states[0]
+                else:
+                    raise ValueError("Unexpected hidden_states format.")
+            else:
+                assert (
+                    len(hidden_states) == 1
+                ), f"Unexpected output length: {len(hidden_states)}"
+                hidden_states = hidden_states[0]
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+    ) -> (
+        torch.Tensor
+        | tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, None]
+    ):
+        if self.forward_pattern.Return_H_Only:
+            return hidden_states
+        else:
+            if self.forward_pattern.Return_H_First:
+                return (hidden_states, new_encoder_hidden_states)
+            else:
+                return (new_encoder_hidden_states, hidden_states)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip cache.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )
+
+        original_hidden_states = hidden_states
+        # Call first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
+            hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        Fn_hidden_states_residual = self._get_Fn_residual(
+            original_hidden_states, hidden_states
+        )
+        del original_hidden_states
+
+        self.context_manager.mark_step_begin()
+        # Residual L1 diff or Hidden States L1 diff
+        can_use_cache = self.context_manager.can_cache(
+            (
+                Fn_hidden_states_residual
+                if not self.context_manager.is_l1_diff_enabled()
+                else hidden_states
+            ),
+            parallelized=self._is_parallelized(),
+            prefix=(
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.context_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
+            ),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            self.context_manager.add_cached_step()
+            del Fn_hidden_states_residual
+            hidden_states, new_encoder_hidden_states = (
+                self.context_manager.apply_cache(
+                    hidden_states,
+                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
+                    prefix=(
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            if self.context_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
+        else:
+            self.context_manager.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.cache_prefix}_Fn_residual",
+            )
+            if self.context_manager.is_l1_diff_enabled():
+                # for hidden states L1 diff
+                self.context_manager.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.cache_prefix}_Fn_hidden_states",
+                )
+            del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
+            old_encoder_hidden_states = new_encoder_hidden_states
+            (
+                hidden_states,
+                new_encoder_hidden_states,
+                hidden_states_residual,
+            ) = self.call_Mn_blocks(  # middle
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+
+            torch._dynamo.graph_break()
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix=f"{self.cache_prefix}_Bn_residual",
+                )
+            else:
+                self.context_manager.set_Bn_buffer(
+                    hidden_states,
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                )
+
+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
+            if self.context_manager.is_encoder_cache_residual():
+                if new_encoder_hidden_states is not None:
+                    self.context_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_residual",
+                    )
+            else:
+                if new_encoder_hidden_states is not None:
+                    self.context_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                    )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            if self.context_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
+
+        torch._dynamo.graph_break()
+
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
+
+    def call_Fn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        new_encoder_hidden_states = None
+        for block in self._Fn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+    def call_Mn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        new_encoder_hidden_states = None
+        for block in self._Mn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        # compute hidden_states residual
+        hidden_states = hidden_states.contiguous()
+        hidden_states_residual = hidden_states - original_hidden_states.to(
+            hidden_states.device
+        )
+
+        return (
+            hidden_states,
+            new_encoder_hidden_states,
+            hidden_states_residual,
+        )
+
+    def call_Bn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        new_encoder_hidden_states = None
+        if self.context_manager.Bn_compute_blocks() == 0:
+            return hidden_states, new_encoder_hidden_states
+
+        for block in self._Bn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_3_4_5):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip prune.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )
+
+        self.context_manager.mark_step_begin()
+
+        if self._check_if_context_parallel_enabled(self.transformer_blocks[0]):
+            raise RuntimeError(
+                "Block level Context parallelism is not supported in PrunedBlocks."
+            )
+
+        # Call all blocks with prune strategy to process the hidden states.
+        new_encoder_hidden_states = None
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, new_encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                new_encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = new_encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, new_encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    new_encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(
+                    hidden_states, new_encoder_hidden_states
+                )
+            )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    new_encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    new_encoder_hidden_states = (
+                        new_encoder_hidden_states.contiguous()
+                    )
+                    new_encoder_hidden_states_residual = (
+                        new_encoder_hidden_states
+                        - original_encoder_hidden_states
+                    )
+                else:
+                    new_encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if new_encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+            torch._dynamo.graph_break()
+
+        return hidden_states, new_encoder_hidden_states
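For readers skimming this file, the forward pass above follows a dual-block caching scheme: run the first few (Fn) blocks, compare their residual against the previous step's, and when the change is small enough, approximate the remaining blocks by reusing their cached residual instead of recomputing them. The snippet below is a greatly simplified, standalone sketch of that decision rule only; SimpleResidualCache, its relative-L1 threshold, and the toy blocks are illustrative assumptions and are not part of cache-dit's API.

```python
import torch


class SimpleResidualCache:
    """Toy illustration of the Fn/Bn residual-caching rule (not cache-dit API)."""

    def __init__(self, threshold: float = 0.08):
        self.threshold = threshold
        self.prev_fn_residual = None  # Fn residual saved on the last computed step
        self.bn_residual = None       # residual of the remaining blocks, ditto

    def can_cache(self, fn_residual: torch.Tensor) -> bool:
        # Relative L1 distance between this step's Fn residual and the previous one.
        if self.prev_fn_residual is None or self.bn_residual is None:
            return False
        diff = (fn_residual - self.prev_fn_residual).abs().mean()
        norm = self.prev_fn_residual.abs().mean() + 1e-6
        return (diff / norm).item() < self.threshold

    def forward(self, fn_blocks, rest_blocks, hidden_states: torch.Tensor):
        original = hidden_states
        for block in fn_blocks:                  # roughly call_Fn_blocks
            hidden_states = block(hidden_states)
        fn_residual = hidden_states - original   # roughly Fn_hidden_states_residual

        if self.can_cache(fn_residual):          # roughly context_manager.can_cache
            hidden_states = hidden_states + self.bn_residual  # roughly apply_cache
        else:
            self.prev_fn_residual = fn_residual  # roughly set_Fn_buffer
            before_rest = hidden_states
            for block in rest_blocks:            # roughly call_Mn_blocks (+ Bn blocks)
                hidden_states = block(hidden_states)
            self.bn_residual = hidden_states - before_rest  # roughly set_Bn_buffer
        return hidden_states


if __name__ == "__main__":
    torch.manual_seed(0)
    blocks = [torch.nn.Linear(16, 16) for _ in range(8)]
    cache = SimpleResidualCache(threshold=0.5)
    x = torch.randn(2, 16)
    for _ in range(4):  # pretend these are denoising steps
        x = cache.forward(blocks[:2], blocks[2:], x)
```

The real implementation above additionally handles encoder hidden states, configurable Bn compute blocks, L1-diff mode, and torch.compile graph breaks, and the PrunedBlocks variant applies the same can-reuse test per block rather than per step.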