cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
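
Most of the moves above are a single package-level rename: modules under cache_dit/cache_factory/ now live under cache_dit/caching/ (and cache_dit/custom_ops/ under cache_dit/kernels/). As a hedged illustration only (the exact public re-exports from cache_dit/__init__.py are not shown in this diff), deep imports that previously targeted the old package layout would move roughly as follows:

    # Hypothetical migration sketch for the cache_factory -> caching rename; verify
    # against cache_dit/__init__.py and the 1.0.14 METADATA before relying on it.
    # 0.3.2-style deep imports (old layout, per the manifest above):
    #   from cache_dit.cache_factory.cache_types import CacheType
    #   from cache_dit.cache_factory.forward_pattern import ForwardPattern
    # 1.0.14-style deep imports (paths taken from the new pattern_base.py below):
    from cache_dit.caching.cache_types import CacheType
    from cache_dit.caching import ForwardPattern

    print(CacheType.DBCache, ForwardPattern.Pattern_0)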
cache_dit/caching/cache_blocks/pattern_base.py
@@ -0,0 +1,748 @@
+ import inspect
+ import logging
+ import torch
+ import torch.distributed as dist
+ from diffusers.hooks import HookRegistry
+
+ try:
+     from diffusers.hooks.context_parallel import ContextParallelSplitHook
+ except ImportError:
+     ContextParallelSplitHook = None
+     raise UserWarning(
+         "Context parallelism requires the 'diffusers>=0.36.dev0'."
+         "Please install latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from cache_dit.caching.cache_contexts.cache_context import CachedContext
+ from cache_dit.caching.cache_contexts.prune_context import PrunedContext
+ from cache_dit.caching.cache_contexts.cache_manager import (
+     CachedContextManager,
+     ContextNotExistError,
+ )
+ from cache_dit.caching.cache_contexts.prune_manager import (
+     PrunedContextManager,
+ )
+ from cache_dit.caching import ForwardPattern
+ from cache_dit.caching.cache_types import CacheType
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class CachedBlocks_Pattern_Base(torch.nn.Module):
+     _supported_patterns = [
+         ForwardPattern.Pattern_0,
+         ForwardPattern.Pattern_1,
+         ForwardPattern.Pattern_2,
+     ]
+
+     def __init__(
+         self,
+         # 0. Transformer blocks configuration
+         transformer_blocks: torch.nn.ModuleList,
+         transformer: torch.nn.Module = None,
+         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+         check_forward_pattern: bool = True,
+         check_num_outputs: bool = True,
+         # 1. Cache context configuration
+         cache_prefix: str = None, # maybe un-need.
+         cache_context: CachedContext | str = None,
+         context_manager: CachedContextManager = None,
+         cache_type: CacheType = CacheType.DBCache,
+         **kwargs,
+     ):
+         super().__init__()
+
+         # 0. Transformer blocks configuration
+         self.transformer = transformer
+         self.transformer_blocks = transformer_blocks
+         self.forward_pattern = forward_pattern
+         self.check_forward_pattern = check_forward_pattern
+         self.check_num_outputs = check_num_outputs
+         # 1. Cache context configuration
+         self.cache_prefix = cache_prefix
+         self.cache_context = cache_context
+         self.context_manager = context_manager
+         self.cache_type = cache_type
+
+         self._check_forward_pattern()
+         self._check_cache_type()
+         logger.info(
+             f"Match Blocks: {self.__class__.__name__}, for "
+             f"{self.cache_prefix}, cache_context: {self.cache_context}, "
+             f"context_manager: {self.context_manager.name}."
+         )
+
+     def _check_forward_pattern(self):
+         if not self.check_forward_pattern:
+             logger.warning(
+                 f"Skipped Forward Pattern Check: {self.forward_pattern}"
+             )
+             return
+
+         assert (
+             self.forward_pattern.Supported
+             and self.forward_pattern in self._supported_patterns
+         ), f"Pattern {self.forward_pattern} is not supported now!"
+
+         if self.transformer_blocks is not None:
+             for block in self.transformer_blocks:
+                 # Special case for HiDreamBlock
+                 if hasattr(block, "block"):
+                     if isinstance(block.block, torch.nn.Module):
+                         block = block.block
+
+                 forward_parameters = set(
+                     inspect.signature(block.forward).parameters.keys()
+                 )
+
+                 if self.check_num_outputs:
+                     num_outputs = str(
+                         inspect.signature(block.forward).return_annotation
+                     ).count("torch.Tensor")
+
+                     if num_outputs > 0:
+                         assert len(self.forward_pattern.Out) == num_outputs, (
+                             f"The number of block's outputs is {num_outputs} don't not "
+                             f"match the number of the pattern: {self.forward_pattern}, "
+                             f"Out: {len(self.forward_pattern.Out)}."
+                         )
+
+                 for required_param in self.forward_pattern.In:
+                     assert (
+                         required_param in forward_parameters
+                     ), f"The input parameters must contains: {required_param}."
+
+     @torch.compiler.disable
+     def _check_cache_type(self):
+         assert (
+             self.cache_type == CacheType.DBCache
+         ), f"Cache type {self.cache_type} is not supported for CachedBlocks."
+
+     @torch.compiler.disable
+     def _check_cache_params(self):
+         self._check_cache_type()
+         assert self.context_manager.Fn_compute_blocks() <= len(
+             self.transformer_blocks
+         ), (
+             f"Fn_compute_blocks {self.context_manager.Fn_compute_blocks()} must be less than "
+             f"the number of transformer blocks {len(self.transformer_blocks)}"
+         )
+         assert self.context_manager.Bn_compute_blocks() <= len(
+             self.transformer_blocks
+         ), (
+             f"Bn_compute_blocks {self.context_manager.Bn_compute_blocks()} must be less than "
+             f"the number of transformer blocks {len(self.transformer_blocks)}"
+         )
+
+     def call_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         # Call all blocks to process the hidden states without cache.
+         for block in self.transformer_blocks:
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             hidden_states, encoder_hidden_states = self._process_block_outputs(
+                 hidden_states, encoder_hidden_states
+             )
+         return hidden_states, encoder_hidden_states
+
+     @torch.compiler.disable
+     def _process_block_outputs(
+         self,
+         hidden_states: torch.Tensor | tuple,
+         encoder_hidden_states: torch.Tensor | None,
+     ) -> tuple[torch.Tensor, torch.Tensor | None]:
+         if not isinstance(hidden_states, torch.Tensor):
+             hidden_states, encoder_hidden_states = hidden_states
+             if not self.forward_pattern.Return_H_First:
+                 hidden_states, encoder_hidden_states = (
+                     encoder_hidden_states,
+                     hidden_states,
+                 )
+         return hidden_states, encoder_hidden_states
+
+     @torch.compiler.disable
+     def _process_forward_outputs(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor | None,
+     ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor:
+         return (
+             hidden_states
+             if self.forward_pattern.Return_H_Only
+             else (
+                 (hidden_states, encoder_hidden_states)
+                 if self.forward_pattern.Return_H_First
+                 else (encoder_hidden_states, hidden_states)
+             )
+         )
+
+     @torch.compiler.disable
+     def _check_if_context_parallel_enabled(
+         self,
+         module: torch.nn.Module,
+     ) -> bool:
+         if ContextParallelSplitHook is None:
+             return False
+         if hasattr(module, "_diffusers_hook"):
+             _diffusers_hook: HookRegistry = module._diffusers_hook
+             for hook in _diffusers_hook.hooks.values():
+                 if isinstance(hook, ContextParallelSplitHook):
+                     return True
+         return False
+
+     def _get_Fn_residual(
+         self,
+         original_hidden_states: torch.Tensor,
+         hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
+         # NOTE: Make cases compatible with context parallelism while using
+         # block level cp plan, e.g., WanTransformer3DModel. The shape of
+         # `original_hidden_states` and `hidden_states` after Fn maybe
+         # different due to seqlen split in context parallelism.
+         if self._check_if_context_parallel_enabled(
+             self.transformer_blocks[0]
+         ) and (original_hidden_states.shape != hidden_states.shape):
+             # Force use `hidden_states` as the Fn states residual for subsequent
+             # dynamic cache processing if the shape is different.
+             Fn_hidden_states_residual = hidden_states
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(
+                     f"Context parallelism is enabled in Fn blocks, and the shape of "
+                     f"original_hidden_states {original_hidden_states.shape} and "
+                     f"hidden_states {hidden_states.shape} are different after Fn blocks. "
+                     f"Use hidden_states as Fn_hidden_states_residual directly."
+                 )
+         else:
+             Fn_hidden_states_residual = (
+                 hidden_states - original_hidden_states.to(hidden_states.device)
+             )
+         return Fn_hidden_states_residual
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         # Use it's own cache context.
+         try:
+             self.context_manager.set_context(self.cache_context)
+             self._check_cache_params()
+         except ContextNotExistError as e:
+             logger.warning(f"Cache context not exist: {e}, skip cache.")
+             # Call all blocks to process the hidden states.
+             hidden_states, encoder_hidden_states = self.call_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             return self._process_forward_outputs(
+                 hidden_states,
+                 encoder_hidden_states,
+             )
+
+         original_hidden_states = hidden_states
+         # Call first `n` blocks to process the hidden states for
+         # more stable diff calculation.
+         hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+             hidden_states,
+             encoder_hidden_states,
+             *args,
+             **kwargs,
+         )
+
+         Fn_hidden_states_residual = self._get_Fn_residual(
+             original_hidden_states, hidden_states
+         )
+         del original_hidden_states
+
+         self.context_manager.mark_step_begin()
+         # Residual L1 diff or Hidden States L1 diff
+         can_use_cache = self.context_manager.can_cache(
+             (
+                 Fn_hidden_states_residual
+                 if not self.context_manager.is_l1_diff_enabled()
+                 else hidden_states
+             ),
+             parallelized=self._is_parallelized(),
+             prefix=(
+                 f"{self.cache_prefix}_Fn_residual"
+                 if not self.context_manager.is_l1_diff_enabled()
+                 else f"{self.cache_prefix}_Fn_hidden_states"
+             ),
+         )
+
+         torch._dynamo.graph_break()
+         if can_use_cache:
+             self.context_manager.add_cached_step()
+             del Fn_hidden_states_residual
+             hidden_states, encoder_hidden_states = (
+                 self.context_manager.apply_cache(
+                     hidden_states,
+                     encoder_hidden_states,
+                     prefix=(
+                         f"{self.cache_prefix}_Bn_residual"
+                         if self.context_manager.is_cache_residual()
+                         else f"{self.cache_prefix}_Bn_hidden_states"
+                     ),
+                     encoder_prefix=(
+                         f"{self.cache_prefix}_Bn_residual"
+                         if self.context_manager.is_encoder_cache_residual()
+                         else f"{self.cache_prefix}_Bn_hidden_states"
+                     ),
+                 )
+             )
+             torch._dynamo.graph_break()
+             # Call last `n` blocks to further process the hidden states
+             # for higher precision.
+             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+         else:
+             self.context_manager.set_Fn_buffer(
+                 Fn_hidden_states_residual,
+                 prefix=f"{self.cache_prefix}_Fn_residual",
+             )
+             if self.context_manager.is_l1_diff_enabled():
+                 # for hidden states L1 diff
+                 self.context_manager.set_Fn_buffer(
+                     hidden_states,
+                     f"{self.cache_prefix}_Fn_hidden_states",
+                 )
+             del Fn_hidden_states_residual
+             torch._dynamo.graph_break()
+             (
+                 hidden_states,
+                 encoder_hidden_states,
+                 hidden_states_residual,
+                 encoder_hidden_states_residual,
+             ) = self.call_Mn_blocks( # middle
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             torch._dynamo.graph_break()
+             if self.context_manager.is_cache_residual():
+                 self.context_manager.set_Bn_buffer(
+                     hidden_states_residual,
+                     prefix=f"{self.cache_prefix}_Bn_residual",
+                 )
+             else:
+                 self.context_manager.set_Bn_buffer(
+                     hidden_states,
+                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                 )
+
+             if self.context_manager.is_encoder_cache_residual():
+                 self.context_manager.set_Bn_encoder_buffer(
+                     encoder_hidden_states_residual,
+                     prefix=f"{self.cache_prefix}_Bn_residual",
+                 )
+             else:
+                 self.context_manager.set_Bn_encoder_buffer(
+                     encoder_hidden_states,
+                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                 )
+             torch._dynamo.graph_break()
+             # Call last `n` blocks to further process the hidden states
+             # for higher precision.
+             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+
+         # patch cached stats for blocks or remove it.
+         torch._dynamo.graph_break()
+
+         return self._process_forward_outputs(
+             hidden_states,
+             encoder_hidden_states,
+         )
+
+     @torch.compiler.disable
+     def _is_parallelized(self):
+         # Compatible with distributed inference.
+         return any(
+             (
+                 all(
+                     (
+                         self.transformer is not None,
+                         getattr(self.transformer, "_is_parallelized", False),
+                     )
+                 ),
+                 (dist.is_initialized() and dist.get_world_size() > 1),
+             )
+         )
+
+     @torch.compiler.disable
+     def _is_in_cache_step(self):
+         # Check if the current step is in cache steps.
+         # If so, we can skip some Bn blocks and directly
+         # use the cached values.
+         return (
+             self.context_manager.get_current_step()
+             in self.context_manager.get_cached_steps()
+         ) or (
+             self.context_manager.get_current_step()
+             in self.context_manager.get_cfg_cached_steps()
+         )
+
+     @torch.compiler.disable
+     def _Fn_blocks(self):
+         # Select first `n` blocks to process the hidden states for
+         # more stable diff calculation.
+         # Fn: [0,...,n-1]
+         selected_Fn_blocks = self.transformer_blocks[
+             : self.context_manager.Fn_compute_blocks()
+         ]
+         return selected_Fn_blocks
+
+     @torch.compiler.disable
+     def _Mn_blocks(self): # middle blocks
+         # M(N-2n): only transformer_blocks [n,...,N-n], middle
+         if self.context_manager.Bn_compute_blocks() == 0: # WARN: x[:-0] = []
+             selected_Mn_blocks = self.transformer_blocks[
+                 self.context_manager.Fn_compute_blocks() :
+             ]
+         else:
+             selected_Mn_blocks = self.transformer_blocks[
+                 self.context_manager.Fn_compute_blocks() : -self.context_manager.Bn_compute_blocks()
+             ]
+         return selected_Mn_blocks
+
+     @torch.compiler.disable
+     def _Bn_blocks(self):
+         # Bn: transformer_blocks [N-n+1,...,N-1]
+         selected_Bn_blocks = self.transformer_blocks[
+             -self.context_manager.Bn_compute_blocks() :
+         ]
+         return selected_Bn_blocks
+
+     def call_Fn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         for block in self._Fn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             hidden_states, encoder_hidden_states = self._process_block_outputs(
+                 hidden_states, encoder_hidden_states
+             )
+
+         return hidden_states, encoder_hidden_states
+
+     def call_Mn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+         for block in self._Mn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             hidden_states, encoder_hidden_states = self._process_block_outputs(
+                 hidden_states, encoder_hidden_states
+             )
+
+         # compute hidden_states residual
+         hidden_states = hidden_states.contiguous()
+
+         hidden_states_residual = hidden_states - original_hidden_states
+
+         if (
+             encoder_hidden_states is not None
+             and original_encoder_hidden_states is not None
+         ):
+             encoder_hidden_states = encoder_hidden_states.contiguous()
+             encoder_hidden_states_residual = (
+                 encoder_hidden_states - original_encoder_hidden_states
+             )
+         else:
+             encoder_hidden_states_residual = None
+
+         return (
+             hidden_states,
+             encoder_hidden_states,
+             hidden_states_residual,
+             encoder_hidden_states_residual,
+         )
+
+     def call_Bn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         if self.context_manager.Bn_compute_blocks() == 0:
+             return hidden_states, encoder_hidden_states
+
+         for block in self._Bn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             hidden_states, encoder_hidden_states = self._process_block_outputs(
+                 hidden_states, encoder_hidden_states
+             )
+
+         return hidden_states, encoder_hidden_states
+
+
+ class PrunedBlocks_Pattern_Base(CachedBlocks_Pattern_Base):
+     pruned_blocks_step: int = 0 # number of pruned blocks in current step
+
+     def __init__(
+         self,
+         # 0. Transformer blocks configuration
+         transformer_blocks: torch.nn.ModuleList,
+         transformer: torch.nn.Module = None,
+         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+         check_forward_pattern: bool = True,
+         check_num_outputs: bool = True,
+         # 1. Prune context configuration
+         cache_prefix: str = None, # maybe un-need.
+         cache_context: PrunedContext | str = None,
+         context_manager: PrunedContextManager = None,
+         cache_type: CacheType = CacheType.DBPrune,
+         **kwargs,
+     ):
+         super().__init__(
+             # 0. Transformer blocks configuration
+             transformer_blocks,
+             transformer=transformer,
+             forward_pattern=forward_pattern,
+             check_forward_pattern=check_forward_pattern,
+             check_num_outputs=check_num_outputs,
+             # 1. Cache context configuration
+             cache_prefix=cache_prefix,
+             cache_context=cache_context,
+             context_manager=context_manager,
+             cache_type=cache_type,
+             **kwargs,
+         )
+         assert isinstance(
+             self.context_manager, PrunedContextManager
+         ), "context_manager must be PrunedContextManager for PrunedBlocks."
+         self.context_manager: PrunedContextManager = (
+             self.context_manager
+         ) # For type hint
+
+     @torch.compiler.disable
+     def _check_cache_type(self):
+         assert (
+             self.cache_type == CacheType.DBPrune
+         ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         self.pruned_blocks_step: int = 0 # reset for each step
+
+         # Use it's own cache context.
+         try:
+             self.context_manager.set_context(self.cache_context)
+             self._check_cache_params()
+         except ContextNotExistError as e:
+             logger.warning(f"Cache context not exist: {e}, skip prune.")
+             # Fallback to call all blocks to process the hidden states w/o prune.
+             hidden_states, encoder_hidden_states = self.call_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             return self._process_forward_outputs(
+                 hidden_states,
+                 encoder_hidden_states,
+             )
+
+         self.context_manager.mark_step_begin()
+
+         if self._check_if_context_parallel_enabled(self.transformer_blocks[0]):
+             raise RuntimeError(
+                 "Block level Context parallelism is not supported in PrunedBlocks."
+             )
+
+         # Call all blocks with prune strategy to process the hidden states.
+         for i, block in enumerate(self.transformer_blocks):
+             hidden_states, encoder_hidden_states = self.compute_or_prune(
+                 i,
+                 block,
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+
+         self.context_manager.add_pruned_block(self.pruned_blocks_step)
+         self.context_manager.add_actual_block(self.num_blocks)
+
+         return self._process_forward_outputs(
+             hidden_states,
+             encoder_hidden_states,
+         )
+
+     @property
+     @torch.compiler.disable
+     def num_blocks(self):
+         return len(self.transformer_blocks)
+
+     @torch.compiler.disable
+     def _skip_prune(self, block_id: int) -> bool:
+         # Wrap for non compiled mode.
+         return block_id in self.context_manager.get_non_prune_blocks_ids(
+             self.num_blocks
+         )
+
+     @torch.compiler.disable
+     def _maybe_prune(
+         self,
+         block_id: int, # Block index in the transformer blocks
+         hidden_states: torch.Tensor, # hidden_states or residual
+         prefix: str = "Bn_original", # prev step name for single blocks
+     ):
+         # Wrap for non compiled mode.
+         can_use_prune = False
+         if not self._skip_prune(block_id):
+             can_use_prune = self.context_manager.can_prune(
+                 hidden_states, # curr step
+                 parallelized=self._is_parallelized(),
+                 prefix=prefix, # prev step
+             )
+         self.pruned_blocks_step += int(can_use_prune)
+         return can_use_prune
+
+     def compute_or_prune(
+         self,
+         block_id: int, # Block index in the transformer blocks
+         # Below are the inputs to the block
+         block, # The transformer block to be executed
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+
+         can_use_prune = self._maybe_prune(
+             block_id,
+             hidden_states,
+             prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+         )
+
+         # Prune steps: Prune current block and reuse the cached
+         # residuals for hidden states approximate.
+         torch._dynamo.graph_break()
+         if can_use_prune:
+             self.context_manager.add_pruned_step()
+             hidden_states, encoder_hidden_states = (
+                 self.context_manager.apply_prune(
+                     hidden_states,
+                     encoder_hidden_states,
+                     prefix=(
+                         f"{self.cache_prefix}_{block_id}_Bn_residual"
+                         if self.context_manager.is_cache_residual()
+                         else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                     ),
+                     encoder_prefix=(
+                         f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                         if self.context_manager.is_encoder_cache_residual()
+                         else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                     ),
+                 )
+             )
+             torch._dynamo.graph_break()
+         else:
+             # Normal steps: Compute the block and cache the residuals.
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             hidden_states, encoder_hidden_states = self._process_block_outputs(
+                 hidden_states, encoder_hidden_states
+             )
+             if not self._skip_prune(block_id):
+                 hidden_states = hidden_states.contiguous()
+                 hidden_states_residual = hidden_states - original_hidden_states
+
+                 if (
+                     encoder_hidden_states is not None
+                     and original_encoder_hidden_states is not None
+                 ):
+                     encoder_hidden_states = encoder_hidden_states.contiguous()
+                     encoder_hidden_states_residual = (
+                         encoder_hidden_states - original_encoder_hidden_states
+                     )
+                 else:
+                     encoder_hidden_states_residual = None
+
+                 self.context_manager.set_Fn_buffer(
+                     original_hidden_states,
+                     prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                 )
+                 if self.context_manager.is_cache_residual():
+                     self.context_manager.set_Bn_buffer(
+                         hidden_states_residual,
+                         prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                     )
+                 else:
+                     self.context_manager.set_Bn_buffer(
+                         hidden_states,
+                         prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                     )
+                 if encoder_hidden_states_residual is not None:
+                     if self.context_manager.is_encoder_cache_residual():
+                         self.context_manager.set_Bn_encoder_buffer(
+                             encoder_hidden_states_residual,
+                             prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                         )
+                     else:
+                         self.context_manager.set_Bn_encoder_buffer(
+                             encoder_hidden_states_residual,
+                             prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                         )
+             torch._dynamo.graph_break()
+
+         return hidden_states, encoder_hidden_states
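
For orientation, the Fn/Mn/Bn split that drives CachedBlocks_Pattern_Base above reduces to plain list slicing over the transformer blocks. The standalone sketch below is not part of the package; it mirrors _Fn_blocks/_Mn_blocks/_Bn_blocks with ordinary lists, and the block count and Fn/Bn settings are made-up values for illustration.

    # Standalone illustration (not package code) of the Fn / Mn / Bn partition used
    # by CachedBlocks_Pattern_Base. Fn_compute_blocks / Bn_compute_blocks values are
    # hypothetical; in the package they come from the cache context manager.
    blocks = list(range(28))  # stand-in for a 28-block transformer_blocks ModuleList
    Fn_compute_blocks, Bn_compute_blocks = 8, 0

    Fn = blocks[:Fn_compute_blocks]  # first n blocks, always computed; drive the diff check
    if Bn_compute_blocks == 0:  # note the WARN above: x[:-0] would be []
        Mn = blocks[Fn_compute_blocks:]
    else:
        Mn = blocks[Fn_compute_blocks:-Bn_compute_blocks]
    # call_Bn_blocks returns early when Bn_compute_blocks == 0; the guard here stands
    # in for that early return.
    Bn = blocks[-Bn_compute_blocks:] if Bn_compute_blocks > 0 else []

    # On a "cache hit" step only Fn and Bn are executed and the Mn output is
    # approximated from the cached residual; otherwise all three segments run.
    print(len(Fn), len(Mn), len(Bn))  # 8 20 0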