cache-dit 1.0.3-py3-none-any.whl → 1.0.14-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
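Most of the entries above are a package-level rename: cache_dit/cache_factory/** moves to cache_dit/caching/** (and cache_dit/custom_ops/** to cache_dit/kernels/**), alongside new parallelism/, quantize/backends/, and summary modules. The following is a minimal, hypothetical migration sketch for downstream code that imported the old private module paths directly; whether cache_dit.caching re-exports the same names that cache_dit.cache_factory did is an assumption inferred from the renamed files in this list, and the stable top-level entry points (e.g. cache_dit.enable_cache) are expected to be unaffected.

# Hypothetical compatibility shim for code written against cache-dit's
# pre-1.0.14 private module layout. The `cache_dit.caching` re-export is an
# assumption based on the renamed files in this diff, not a documented API.
try:
    # 1.0.14+ layout: `cache_factory` was renamed to `caching`.
    from cache_dit.caching import ForwardPattern
except ImportError:
    # 1.0.3 layout (the import used by the deleted file shown below).
    from cache_dit.cache_factory import ForwardPattern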
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py (deleted)
@@ -1,306 +0,0 @@
- import torch
-
- from cache_dit.cache_factory import ForwardPattern
- from cache_dit.cache_factory.cache_contexts.cache_manager import (
-     CacheNotExistError,
- )
- from cache_dit.cache_factory.cache_blocks.pattern_base import (
-     CachedBlocks_Pattern_Base,
- )
- from cache_dit.logger import init_logger
-
- logger = init_logger(__name__)
-
-
- class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
-     _supported_patterns = [
-         ForwardPattern.Pattern_3,
-         ForwardPattern.Pattern_4,
-         ForwardPattern.Pattern_5,
-     ]
-
-     def call_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         # Call all blocks to process the hidden states without cache.
-         new_encoder_hidden_states = None
-         for block in self.transformer_blocks:
-             hidden_states = block(
-                 hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             hidden_states, new_encoder_hidden_states = (
-                 self._process_block_outputs(hidden_states)
-             )
-
-         return hidden_states, new_encoder_hidden_states
-
-     @torch.compiler.disable
-     def _process_block_outputs(
-         self, hidden_states: torch.Tensor | tuple
-     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-         # Process the outputs for the block.
-         new_encoder_hidden_states = None
-         if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-             if len(hidden_states) == 2:
-                 if isinstance(hidden_states[1], torch.Tensor):
-                     hidden_states, new_encoder_hidden_states = hidden_states
-                     if not self.forward_pattern.Return_H_First:
-                         hidden_states, new_encoder_hidden_states = (
-                             new_encoder_hidden_states,
-                             hidden_states,
-                         )
-                 elif isinstance(hidden_states[0], torch.Tensor):
-                     hidden_states = hidden_states[0]
-                 else:
-                     raise ValueError("Unexpected hidden_states format.")
-             else:
-                 assert (
-                     len(hidden_states) == 1
-                 ), f"Unexpected output length: {len(hidden_states)}"
-                 hidden_states = hidden_states[0]
-         return hidden_states, new_encoder_hidden_states
-
-     @torch.compiler.disable
-     def _process_forward_outputs(
-         self,
-         hidden_states: torch.Tensor,
-         new_encoder_hidden_states: torch.Tensor | None,
-     ) -> (
-         torch.Tensor
-         | tuple[torch.Tensor, torch.Tensor]
-         | tuple[torch.Tensor, None]
-     ):
-         if self.forward_pattern.Return_H_Only:
-             return hidden_states
-         else:
-             if self.forward_pattern.Return_H_First:
-                 return (hidden_states, new_encoder_hidden_states)
-             else:
-                 return (new_encoder_hidden_states, hidden_states)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         # Use it's own cache context.
-         try:
-             self.cache_manager.set_context(self.cache_context)
-             self._check_cache_params()
-         except CacheNotExistError as e:
-             logger.warning(f"Cache context not exist: {e}, skip cache.")
-             hidden_states, new_encoder_hidden_states = self.call_blocks(
-                 hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             return self._process_forward_outputs(
-                 hidden_states, new_encoder_hidden_states
-             )
-
-         original_hidden_states = hidden_states
-         # Call first `n` blocks to process the hidden states for
-         # more stable diff calculation.
-         hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
-             hidden_states,
-             *args,
-             **kwargs,
-         )
-
-         Fn_hidden_states_residual = hidden_states - original_hidden_states.to(
-             hidden_states.device
-         )
-         del original_hidden_states
-
-         self.cache_manager.mark_step_begin()
-         # Residual L1 diff or Hidden States L1 diff
-         can_use_cache = self.cache_manager.can_cache(
-             (
-                 Fn_hidden_states_residual
-                 if not self.cache_manager.is_l1_diff_enabled()
-                 else hidden_states
-             ),
-             parallelized=self._is_parallelized(),
-             prefix=(
-                 f"{self.cache_prefix}_Fn_residual"
-                 if not self.cache_manager.is_l1_diff_enabled()
-                 else f"{self.cache_prefix}_Fn_hidden_states"
-             ),
-         )
-
-         torch._dynamo.graph_break()
-         if can_use_cache:
-             self.cache_manager.add_cached_step()
-             del Fn_hidden_states_residual
-             hidden_states, new_encoder_hidden_states = (
-                 self.cache_manager.apply_cache(
-                     hidden_states,
-                     new_encoder_hidden_states,  # encoder_hidden_states not use cache
-                     prefix=(
-                         f"{self.cache_prefix}_Bn_residual"
-                         if self.cache_manager.is_cache_residual()
-                         else f"{self.cache_prefix}_Bn_hidden_states"
-                     ),
-                     encoder_prefix=(
-                         f"{self.cache_prefix}_Bn_residual"
-                         if self.cache_manager.is_encoder_cache_residual()
-                         else f"{self.cache_prefix}_Bn_hidden_states"
-                     ),
-                 )
-             )
-             torch._dynamo.graph_break()
-             # Call last `n` blocks to further process the hidden states
-             # for higher precision.
-             if self.cache_manager.Bn_compute_blocks() > 0:
-                 hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
-                     hidden_states,
-                     *args,
-                     **kwargs,
-                 )
-         else:
-             self.cache_manager.set_Fn_buffer(
-                 Fn_hidden_states_residual,
-                 prefix=f"{self.cache_prefix}_Fn_residual",
-             )
-             if self.cache_manager.is_l1_diff_enabled():
-                 # for hidden states L1 diff
-                 self.cache_manager.set_Fn_buffer(
-                     hidden_states,
-                     f"{self.cache_prefix}_Fn_hidden_states",
-                 )
-             del Fn_hidden_states_residual
-             torch._dynamo.graph_break()
-             old_encoder_hidden_states = new_encoder_hidden_states
-             (
-                 hidden_states,
-                 new_encoder_hidden_states,
-                 hidden_states_residual,
-             ) = self.call_Mn_blocks(  # middle
-                 hidden_states,
-                 *args,
-                 **kwargs,
-             )
-
-             torch._dynamo.graph_break()
-             if self.cache_manager.is_cache_residual():
-                 self.cache_manager.set_Bn_buffer(
-                     hidden_states_residual,
-                     prefix=f"{self.cache_prefix}_Bn_residual",
-                 )
-             else:
-                 self.cache_manager.set_Bn_buffer(
-                     hidden_states,
-                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                 )
-
-             if new_encoder_hidden_states is not None:
-                 new_encoder_hidden_states_residual = (
-                     new_encoder_hidden_states - old_encoder_hidden_states
-                 )
-             if self.cache_manager.is_encoder_cache_residual():
-                 if new_encoder_hidden_states is not None:
-                     self.cache_manager.set_Bn_encoder_buffer(
-                         new_encoder_hidden_states_residual,
-                         prefix=f"{self.cache_prefix}_Bn_residual",
-                     )
-             else:
-                 if new_encoder_hidden_states is not None:
-                     self.cache_manager.set_Bn_encoder_buffer(
-                         new_encoder_hidden_states_residual,
-                         prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                     )
-             torch._dynamo.graph_break()
-             # Call last `n` blocks to further process the hidden states
-             # for higher precision.
-             if self.cache_manager.Bn_compute_blocks() > 0:
-                 hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
-                     hidden_states,
-                     *args,
-                     **kwargs,
-                 )
-
-         torch._dynamo.graph_break()
-
-         return self._process_forward_outputs(
-             hidden_states,
-             new_encoder_hidden_states,
-         )
-
-     def call_Fn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         new_encoder_hidden_states = None
-         for block in self._Fn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             hidden_states, new_encoder_hidden_states = (
-                 self._process_block_outputs(hidden_states)
-             )
-
-         return hidden_states, new_encoder_hidden_states
-
-     def call_Mn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         original_hidden_states = hidden_states
-         new_encoder_hidden_states = None
-         for block in self._Mn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 *args,
-                 **kwargs,
-             )
-
-             hidden_states, new_encoder_hidden_states = (
-                 self._process_block_outputs(hidden_states)
-             )
-
-         # compute hidden_states residual
-         hidden_states = hidden_states.contiguous()
-         hidden_states_residual = hidden_states - original_hidden_states.to(
-             hidden_states.device
-         )
-
-         return (
-             hidden_states,
-             new_encoder_hidden_states,
-             hidden_states_residual,
-         )
-
-     def call_Bn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         new_encoder_hidden_states = None
-         if self.cache_manager.Bn_compute_blocks() == 0:
-             return hidden_states, new_encoder_hidden_states
-
-         for block in self._Bn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 *args,
-                 **kwargs,
-             )
-
-             hidden_states, new_encoder_hidden_states = (
-                 self._process_block_outputs(hidden_states)
-             )
-
-         return hidden_states, new_encoder_hidden_states