cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
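The listing above amounts to a package-level reorganization: everything under `cache_dit/cache_factory/` moves to `cache_dit/caching/`, `cache_dit/custom_ops/` becomes `cache_dit/kernels/`, and new `parallelism/` (context/tensor-parallel plans) and `quantize/backends/` trees are added. For downstream code that imported the old module paths directly, a version-tolerant import along the lines of the sketch below may help; note that the 1.0.x import paths here are an assumption inferred from the renamed files in the listing (e.g. `caching/forward_pattern.py`), not from verified 1.0.14 documentation.

```python
# Hedged migration sketch: the 1.0.x locations are assumed from the renames in
# the file list above (cache_factory -> caching); verify against the installed
# cache-dit version before relying on them.
try:
    # cache-dit >= 1.0.x layout (assumed)
    from cache_dit.caching.forward_pattern import ForwardPattern
    from cache_dit.caching.cache_types import CacheType
except ImportError:
    # cache-dit 0.3.x layout
    from cache_dit.cache_factory.forward_pattern import ForwardPattern
    from cache_dit.cache_factory.cache_types import CacheType
```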
--- a/cache_dit/cache_factory/cache_blocks/pattern_base.py
+++ /dev/null
@@ -1,404 +0,0 @@
- import inspect
- import torch
- import torch.distributed as dist
-
- from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
- from cache_dit.cache_factory.cache_contexts.cache_manager import (
-     CachedContextManager,
- )
- from cache_dit.cache_factory import ForwardPattern
- from cache_dit.logger import init_logger
-
- logger = init_logger(__name__)
-
-
- class CachedBlocks_Pattern_Base(torch.nn.Module):
-     _supported_patterns = [
-         ForwardPattern.Pattern_0,
-         ForwardPattern.Pattern_1,
-         ForwardPattern.Pattern_2,
-     ]
-
-     def __init__(
-         self,
-         # 0. Transformer blocks configuration
-         transformer_blocks: torch.nn.ModuleList,
-         transformer: torch.nn.Module = None,
-         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-         check_forward_pattern: bool = True,
-         check_num_outputs: bool = True,
-         # 1. Cache context configuration
-         cache_prefix: str = None,  # maybe un-need.
-         cache_context: CachedContext | str = None,
-         cache_manager: CachedContextManager = None,
-         **kwargs,
-     ):
-         super().__init__()
-
-         # 0. Transformer blocks configuration
-         self.transformer = transformer
-         self.transformer_blocks = transformer_blocks
-         self.forward_pattern = forward_pattern
-         self.check_forward_pattern = check_forward_pattern
-         self.check_num_outputs = check_num_outputs
-         # 1. Cache context configuration
-         self.cache_prefix = cache_prefix
-         self.cache_context = cache_context
-         self.cache_manager = cache_manager
-
-         self._check_forward_pattern()
-         logger.info(
-             f"Match Cached Blocks: {self.__class__.__name__}, for "
-             f"{self.cache_prefix}, cache_context: {self.cache_context}, "
-             f"cache_manager: {self.cache_manager.name}."
-         )
-
-     def _check_forward_pattern(self):
-         if not self.check_forward_pattern:
-             logger.warning(
-                 f"Skipped Forward Pattern Check: {self.forward_pattern}"
-             )
-             return
-
-         assert (
-             self.forward_pattern.Supported
-             and self.forward_pattern in self._supported_patterns
-         ), f"Pattern {self.forward_pattern} is not supported now!"
-
-         if self.transformer_blocks is not None:
-             for block in self.transformer_blocks:
-                 # Special case for HiDreamBlock
-                 if hasattr(block, "block"):
-                     if isinstance(block.block, torch.nn.Module):
-                         block = block.block
-
-                 forward_parameters = set(
-                     inspect.signature(block.forward).parameters.keys()
-                 )
-
-                 if self.check_num_outputs:
-                     num_outputs = str(
-                         inspect.signature(block.forward).return_annotation
-                     ).count("torch.Tensor")
-
-                     if num_outputs > 0:
-                         assert len(self.forward_pattern.Out) == num_outputs, (
-                             f"The number of block's outputs is {num_outputs} don't not "
-                             f"match the number of the pattern: {self.forward_pattern}, "
-                             f"Out: {len(self.forward_pattern.Out)}."
-                         )
-
-                 for required_param in self.forward_pattern.In:
-                     assert (
-                         required_param in forward_parameters
-                     ), f"The input parameters must contains: {required_param}."
-
-     @torch.compiler.disable
-     def _check_cache_params(self):
-         assert self.cache_manager.Fn_compute_blocks() <= len(
-             self.transformer_blocks
-         ), (
-             f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
-             f"the number of transformer blocks {len(self.transformer_blocks)}"
-         )
-         assert self.cache_manager.Bn_compute_blocks() <= len(
-             self.transformer_blocks
-         ), (
-             f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
-             f"the number of transformer blocks {len(self.transformer_blocks)}"
-         )
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         # Use it's own cache context.
-         self.cache_manager.set_context(self.cache_context)
-         self._check_cache_params()
-
-         original_hidden_states = hidden_states
-         # Call first `n` blocks to process the hidden states for
-         # more stable diff calculation.
-         hidden_states, encoder_hidden_states = self.call_Fn_blocks(
-             hidden_states,
-             encoder_hidden_states,
-             *args,
-             **kwargs,
-         )
-
-         Fn_hidden_states_residual = hidden_states - original_hidden_states
-         del original_hidden_states
-
-         self.cache_manager.mark_step_begin()
-         # Residual L1 diff or Hidden States L1 diff
-         can_use_cache = self.cache_manager.can_cache(
-             (
-                 Fn_hidden_states_residual
-                 if not self.cache_manager.is_l1_diff_enabled()
-                 else hidden_states
-             ),
-             parallelized=self._is_parallelized(),
-             prefix=(
-                 f"{self.cache_prefix}_Fn_residual"
-                 if not self.cache_manager.is_l1_diff_enabled()
-                 else f"{self.cache_prefix}_Fn_hidden_states"
-             ),
-         )
-
-         torch._dynamo.graph_break()
-         if can_use_cache:
-             self.cache_manager.add_cached_step()
-             del Fn_hidden_states_residual
-             hidden_states, encoder_hidden_states = (
-                 self.cache_manager.apply_cache(
-                     hidden_states,
-                     encoder_hidden_states,
-                     prefix=(
-                         f"{self.cache_prefix}_Bn_residual"
-                         if self.cache_manager.is_cache_residual()
-                         else f"{self.cache_prefix}_Bn_hidden_states"
-                     ),
-                     encoder_prefix=(
-                         f"{self.cache_prefix}_Bn_residual"
-                         if self.cache_manager.is_encoder_cache_residual()
-                         else f"{self.cache_prefix}_Bn_hidden_states"
-                     ),
-                 )
-             )
-             torch._dynamo.graph_break()
-             # Call last `n` blocks to further process the hidden states
-             # for higher precision.
-             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-         else:
-             self.cache_manager.set_Fn_buffer(
-                 Fn_hidden_states_residual,
-                 prefix=f"{self.cache_prefix}_Fn_residual",
-             )
-             if self.cache_manager.is_l1_diff_enabled():
-                 # for hidden states L1 diff
-                 self.cache_manager.set_Fn_buffer(
-                     hidden_states,
-                     f"{self.cache_prefix}_Fn_hidden_states",
-                 )
-             del Fn_hidden_states_residual
-             torch._dynamo.graph_break()
-             (
-                 hidden_states,
-                 encoder_hidden_states,
-                 hidden_states_residual,
-                 encoder_hidden_states_residual,
-             ) = self.call_Mn_blocks(  # middle
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             torch._dynamo.graph_break()
-             if self.cache_manager.is_cache_residual():
-                 self.cache_manager.set_Bn_buffer(
-                     hidden_states_residual,
-                     prefix=f"{self.cache_prefix}_Bn_residual",
-                 )
-             else:
-                 self.cache_manager.set_Bn_buffer(
-                     hidden_states,
-                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                 )
-
-             if self.cache_manager.is_encoder_cache_residual():
-                 self.cache_manager.set_Bn_encoder_buffer(
-                     encoder_hidden_states_residual,
-                     prefix=f"{self.cache_prefix}_Bn_residual",
-                 )
-             else:
-                 self.cache_manager.set_Bn_encoder_buffer(
-                     encoder_hidden_states,
-                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
-                 )
-             torch._dynamo.graph_break()
-             # Call last `n` blocks to further process the hidden states
-             # for higher precision.
-             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-
-         # patch cached stats for blocks or remove it.
-         torch._dynamo.graph_break()
-
-         return (
-             hidden_states
-             if self.forward_pattern.Return_H_Only
-             else (
-                 (hidden_states, encoder_hidden_states)
-                 if self.forward_pattern.Return_H_First
-                 else (encoder_hidden_states, hidden_states)
-             )
-         )
-
-     @torch.compiler.disable
-     def _is_parallelized(self):
-         # Compatible with distributed inference.
-         return any(
-             (
-                 all(
-                     (
-                         self.transformer is not None,
-                         getattr(self.transformer, "_is_parallelized", False),
-                     )
-                 ),
-                 (dist.is_initialized() and dist.get_world_size() > 1),
-             )
-         )
-
-     @torch.compiler.disable
-     def _is_in_cache_step(self):
-         # Check if the current step is in cache steps.
-         # If so, we can skip some Bn blocks and directly
-         # use the cached values.
-         return (
-             self.cache_manager.get_current_step()
-             in self.cache_manager.get_cached_steps()
-         ) or (
-             self.cache_manager.get_current_step()
-             in self.cache_manager.get_cfg_cached_steps()
-         )
-
-     @torch.compiler.disable
-     def _Fn_blocks(self):
-         # Select first `n` blocks to process the hidden states for
-         # more stable diff calculation.
-         # Fn: [0,...,n-1]
-         selected_Fn_blocks = self.transformer_blocks[
-             : self.cache_manager.Fn_compute_blocks()
-         ]
-         return selected_Fn_blocks
-
-     @torch.compiler.disable
-     def _Mn_blocks(self):  # middle blocks
-         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-         if self.cache_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-             selected_Mn_blocks = self.transformer_blocks[
-                 self.cache_manager.Fn_compute_blocks() :
-             ]
-         else:
-             selected_Mn_blocks = self.transformer_blocks[
-                 self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
-             ]
-         return selected_Mn_blocks
-
-     @torch.compiler.disable
-     def _Bn_blocks(self):
-         # Bn: transformer_blocks [N-n+1,...,N-1]
-         selected_Bn_blocks = self.transformer_blocks[
-             -self.cache_manager.Bn_compute_blocks() :
-         ]
-         return selected_Bn_blocks
-
-     def call_Fn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         for block in self._Fn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             if not isinstance(hidden_states, torch.Tensor):
-                 hidden_states, encoder_hidden_states = hidden_states
-                 if not self.forward_pattern.Return_H_First:
-                     hidden_states, encoder_hidden_states = (
-                         encoder_hidden_states,
-                         hidden_states,
-                     )
-
-         return hidden_states, encoder_hidden_states
-
-     def call_Mn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         original_hidden_states = hidden_states
-         original_encoder_hidden_states = encoder_hidden_states
-         for block in self._Mn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             if not isinstance(hidden_states, torch.Tensor):
-                 hidden_states, encoder_hidden_states = hidden_states
-                 if not self.forward_pattern.Return_H_First:
-                     hidden_states, encoder_hidden_states = (
-                         encoder_hidden_states,
-                         hidden_states,
-                     )
-
-         # compute hidden_states residual
-         hidden_states = hidden_states.contiguous()
-
-         hidden_states_residual = hidden_states - original_hidden_states
-
-         if (
-             encoder_hidden_states is not None
-             and original_encoder_hidden_states is not None
-         ):
-             encoder_hidden_states = encoder_hidden_states.contiguous()
-             encoder_hidden_states_residual = (
-                 encoder_hidden_states - original_encoder_hidden_states
-             )
-         else:
-             encoder_hidden_states_residual = None
-
-         return (
-             hidden_states,
-             encoder_hidden_states,
-             hidden_states_residual,
-             encoder_hidden_states_residual,
-         )
-
-     def call_Bn_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         if self.cache_manager.Bn_compute_blocks() == 0:
-             return hidden_states, encoder_hidden_states
-
-         for block in self._Bn_blocks():
-             hidden_states = block(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             if not isinstance(hidden_states, torch.Tensor):
-                 hidden_states, encoder_hidden_states = hidden_states
-                 if not self.forward_pattern.Return_H_First:
-                     hidden_states, encoder_hidden_states = (
-                         encoder_hidden_states,
-                         hidden_states,
-                     )
-
-         return hidden_states, encoder_hidden_states
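The removed `CachedBlocks_Pattern_Base` above implements the DBCache split of `transformer_blocks` into three segments: the first `Fn_compute_blocks` blocks are always computed to obtain a stable residual diff, the middle blocks are either computed or replaced by the cached residual, and the last `Bn_compute_blocks` blocks are recomputed to refine the approximation. A minimal, self-contained sketch of just that slicing, mirroring `_Fn_blocks`, `_Mn_blocks`, and `_Bn_blocks` (including the `Bn == 0` edge case noted in the original comment), not the library API:

```python
# Standalone sketch of the Fn / Mn / Bn partition used above; `blocks`, `Fn`,
# and `Bn` stand in for transformer_blocks, Fn_compute_blocks() and
# Bn_compute_blocks(). Illustrative only.
def split_blocks(blocks, Fn: int, Bn: int):
    Fn_blocks = blocks[:Fn]  # always computed: stabilizes the residual diff
    if Bn == 0:
        # WARN from the original code: blocks[Fn:-0] would be empty,
        # so Bn == 0 needs a special case.
        Mn_blocks, Bn_blocks = blocks[Fn:], []
    else:
        Mn_blocks, Bn_blocks = blocks[Fn:-Bn], blocks[-Bn:]  # cacheable / refinement
    return Fn_blocks, Mn_blocks, Bn_blocks


# Example: 28 blocks with the documented F8B0 default -> 8 Fn compute blocks,
# 20 cacheable middle blocks, and no Bn refinement blocks.
Fn, Mn, Bn = split_blocks(list(range(28)), Fn=8, Bn=0)
assert (len(Fn), len(Mn), len(Bn)) == (8, 20, 0)
```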
--- a/cache_dit/cache_factory/cache_blocks/utils.py
+++ /dev/null
@@ -1,41 +0,0 @@
- import torch
-
- from typing import Any
- from cache_dit.cache_factory import CachedContext
- from cache_dit.cache_factory import CachedContextManager
-
-
- def patch_cached_stats(
-     module: torch.nn.Module | Any,
-     cache_context: CachedContext | str = None,
-     cache_manager: CachedContextManager = None,
- ):
-     # Patch the cached stats to the module, the cached stats
-     # will be reset for each calling of pipe.__call__(**kwargs).
-     if module is None or cache_manager is None:
-         return
-
-     if cache_context is not None:
-         cache_manager.set_context(cache_context)
-
-     # TODO: Patch more cached stats to the module
-     module._cached_steps = cache_manager.get_cached_steps()
-     module._residual_diffs = cache_manager.get_residual_diffs()
-     module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
-     module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
-
-
- def remove_cached_stats(
-     module: torch.nn.Module | Any,
- ):
-     if module is None:
-         return
-
-     if hasattr(module, "_cached_steps"):
-         del module._cached_steps
-     if hasattr(module, "_residual_diffs"):
-         del module._residual_diffs
-     if hasattr(module, "_cfg_cached_steps"):
-         del module._cfg_cached_steps
-     if hasattr(module, "_cfg_residual_diffs"):
-         del module._cfg_residual_diffs
--- a/cache_dit/cache_factory/cache_contexts/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
- from cache_dit.cache_factory.cache_contexts.calibrators import (
-     Calibrator,
-     CalibratorBase,
-     CalibratorConfig,
-     TaylorSeerCalibratorConfig,
-     FoCaCalibratorConfig,
- )
- from cache_dit.cache_factory.cache_contexts.cache_context import (
-     CachedContext,
-     BasicCacheConfig,
- )
- from cache_dit.cache_factory.cache_contexts.cache_manager import (
-     CachedContextManager,
- )
--- a/cache_dit/cache_factory/cache_interface.py
+++ /dev/null
@@ -1,217 +0,0 @@
- from typing import Any, Tuple, List, Union, Optional
- from diffusers import DiffusionPipeline
- from cache_dit.cache_factory.cache_types import CacheType
- from cache_dit.cache_factory.block_adapters import BlockAdapter
- from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
- from cache_dit.cache_factory.cache_adapters import CachedAdapter
- from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
- from cache_dit.cache_factory.cache_contexts import CalibratorConfig
- from cache_dit.cache_factory.params_modifier import ParamsModifier
-
- from cache_dit.logger import init_logger
-
- logger = init_logger(__name__)
-
-
- def enable_cache(
-     # DiffusionPipeline or BlockAdapter
-     pipe_or_adapter: Union[
-         DiffusionPipeline,
-         BlockAdapter,
-     ],
-     # Basic DBCache config: BasicCacheConfig
-     cache_config: BasicCacheConfig = BasicCacheConfig(),
-     # Calibrator config: TaylorSeerCalibratorConfig, etc.
-     calibrator_config: Optional[CalibratorConfig] = None,
-     # Modify cache context params for specific blocks.
-     params_modifiers: Optional[
-         Union[
-             ParamsModifier,
-             List[ParamsModifier],
-             List[List[ParamsModifier]],
-         ]
-     ] = None,
-     # Other cache context kwargs: Deprecated cache kwargs
-     **kwargs,
- ) -> Union[
-     DiffusionPipeline,
-     BlockAdapter,
- ]:
-     r"""
-     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
-     that match the specific Input and Output patterns).
-
-     For a good balance between performance and precision, DBCache is configured by default
-     with F8B0, 8 warmup steps, and unlimited cached steps.
-
-     Args:
-         pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
-             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
-             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
-             for the usgae of BlockAdapter.
-         cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
-             Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params listed belows:
-             Fn_compute_blocks: (`int`, *required*, defaults to 8):
-                 Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-                 at time step t, enabling the calculation of a more stable L1 diff and delivering more
-                 accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
-                 for more details of DBCache.
-             Bn_compute_blocks: (`int`, *required*, defaults to 0):
-                 Further fuses approximate information in the **last n** Transformer blocks to enhance
-                 prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
-                 that use residual cache.
-             residual_diff_threshold (`float`, *required*, defaults to 0.08):
-                 the value of residual diff threshold, a higher value leads to faster performance at the
-                 cost of lower precision.
-             max_warmup_steps (`int`, *required*, defaults to 8):
-                 DBCache does not apply the caching strategy when the number of running steps is less than
-                 or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-             max_cached_steps (`int`, *required*, defaults to -1):
-                 DBCache disables the caching strategy when the previous cached steps exceed this value to
-                 prevent precision degradation.
-             max_continuous_cached_steps (`int`, *required*, defaults to -1):
-                 DBCache disables the caching strategy when the previous continous cached steps exceed this value to
-                 prevent precision degradation.
-             enable_separate_cfg (`bool`, *required*, defaults to None):
-                 Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-                 and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
-                 CogVideoX, HunyuanVideo, Mochi, etc.
-             cfg_compute_first (`bool`, *required*, defaults to False):
-                 Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
-                 1, 3, 5, ... -> CFG step.
-             cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-                 Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-                 use the computed diff from current non-CFG transformer step for current CFG step.
-         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
-             Config for calibrator, if calibrator_config is not None, means that user want to use DBCache
-             with specific calibrator, such as taylorseer, foca, and so on.
-         params_modifiers ('ParamsModifier', *optional*, defaults to None):
-             Modify cache context params for specific blocks. The configurable params listed belows:
-             cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
-                 The same as 'cache_config' param in cache_dit.enable_cache() interface.
-             calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
-                 The same as 'calibrator_config' param in cache_dit.enable_cache() interface.
-             **kwargs: (`dict`, *optional*, defaults to {}):
-                 The same as 'kwargs' param in cache_dit.enable_cache() interface.
-         kwargs (`dict`, *optional*, defaults to {})
-             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
-             for more details.
-
-     Examples:
-     ```py
-     >>> import cache_dit
-     >>> from diffusers import DiffusionPipeline
-     >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")  # Can be any diffusion pipeline
-     >>> cache_dit.enable_cache(pipe)  # One-line code with default cache options.
-     >>> output = pipe(...)  # Just call the pipe as normal.
-     >>> stats = cache_dit.summary(pipe)  # Then, get the summary of cache acceleration stats.
-     >>> cache_dit.disable_cache(pipe)  # Disable cache and run original pipe.
-     """
-     # Collect cache context kwargs
-     cache_context_kwargs = {}
-     if (cache_type := cache_context_kwargs.pop("cache_type", None)) is not None:
-         if cache_type == CacheType.NONE:
-             return pipe_or_adapter
-
-     # WARNING: Deprecated cache config params. These parameters are now retained
-     # for backward compatibility but will be removed in the future.
-     deprecated_cache_kwargs = {
-         "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
-         "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
-         "max_warmup_steps": kwargs.get("max_warmup_steps", None),
-         "max_cached_steps": kwargs.get("max_cached_steps", None),
-         "max_continuous_cached_steps": kwargs.get(
-             "max_continuous_cached_steps", None
-         ),
-         "residual_diff_threshold": kwargs.get("residual_diff_threshold", None),
-         "enable_separate_cfg": kwargs.get("enable_separate_cfg", None),
-         "cfg_compute_first": kwargs.get("cfg_compute_first", None),
-         "cfg_diff_compute_separate": kwargs.get(
-             "cfg_diff_compute_separate", None
-         ),
-     }
-
-     deprecated_cache_kwargs = {
-         k: v for k, v in deprecated_cache_kwargs.items() if v is not None
-     }
-
-     if deprecated_cache_kwargs:
-         logger.warning(
-             "Manually settup DBCache context without BasicCacheConfig is "
-             "deprecated and will be removed in the future, please use "
-             "`cache_config` parameter instead!"
-         )
-         if cache_config is not None:
-             cache_config.update(**deprecated_cache_kwargs)
-         else:
-             cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
-
-     if cache_config is not None:
-         cache_context_kwargs["cache_config"] = cache_config
-
-     # WARNING: Deprecated taylorseer params. These parameters are now retained
-     # for backward compatibility but will be removed in the future.
-     if (
-         kwargs.get("enable_taylorseer", None) is not None
-         or kwargs.get("enable_encoder_taylorseer", None) is not None
-     ):
-         logger.warning(
-             "Manually settup TaylorSeer calibrator without TaylorSeerCalibratorConfig is "
-             "deprecated and will be removed in the future, please use "
-             "`calibrator_config` parameter instead!"
-         )
-         from cache_dit.cache_factory.cache_contexts.calibrators import (
-             TaylorSeerCalibratorConfig,
-         )
-
-         calibrator_config = TaylorSeerCalibratorConfig(
-             enable_calibrator=kwargs.get("enable_taylorseer"),
-             enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer"),
-             calibrator_cache_type=kwargs.get(
-                 "taylorseer_cache_type", "residual"
-             ),
-             taylorseer_order=kwargs.get("taylorseer_order", 1),
-         )
-
-     if calibrator_config is not None:
-         cache_context_kwargs["calibrator_config"] = calibrator_config
-
-     if params_modifiers is not None:
-         cache_context_kwargs["params_modifiers"] = params_modifiers
-
-     if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
-         return CachedAdapter.apply(
-             pipe_or_adapter,
-             **cache_context_kwargs,
-         )
-     else:
-         raise ValueError(
-             f"type: {type(pipe_or_adapter)} is not valid, "
-             "Please pass DiffusionPipeline or BlockAdapter"
-             "for the 1's position param: pipe_or_adapter"
-         )
-
-
- def disable_cache(
-     pipe_or_adapter: Union[
-         DiffusionPipeline,
-         BlockAdapter,
-     ],
- ):
-     CachedAdapter.maybe_release_hooks(pipe_or_adapter)
-     logger.warning(
-         f"Cache Acceleration is disabled for: "
-         f"{pipe_or_adapter.__class__.__name__}."
-     )
-
-
- def supported_pipelines(
-     **kwargs,
- ) -> Tuple[int, List[str]]:
-     return BlockAdapterRegistry.supported_pipelines(**kwargs)
-
-
- def get_adapter(
-     pipe: DiffusionPipeline | str | Any,
- ) -> BlockAdapter:
-     return BlockAdapterRegistry.get_adapter(pipe)
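The deprecation warnings in the removed `enable_cache` above point users away from the flat kwargs (`Fn_compute_blocks=...`, `enable_taylorseer=True`, ...) and toward the `cache_config` / `calibrator_config` objects. A hedged sketch of that configuration path follows; the class and field names come from the removed 0.3.2 code, while the top-level `cache_dit` re-exports used here are assumed rather than verified against 1.0.14.

```python
# Sketch of the config-object path the removed docstring and warnings describe.
# BasicCacheConfig / TaylorSeerCalibratorConfig and their fields appear in the
# 0.3.2 code above; the top-level re-exports are assumed, not verified.
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

cache_dit.enable_cache(
    pipe,
    cache_config=cache_dit.BasicCacheConfig(  # assumed top-level re-export
        Fn_compute_blocks=8,            # the documented F8B0 default
        Bn_compute_blocks=0,
        residual_diff_threshold=0.08,   # higher -> faster, less precise
        max_warmup_steps=8,
        max_cached_steps=-1,            # unlimited cached steps
    ),
    calibrator_config=cache_dit.TaylorSeerCalibratorConfig(  # optional calibrator
        taylorseer_order=1,
    ),
)

image = pipe("a small corgi astronaut")  # call the pipeline as normal
stats = cache_dit.summary(pipe)          # cache acceleration stats
cache_dit.disable_cache(pipe)            # restore the original pipeline
```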
--- a/cache_dit/cache_factory/patch_functors/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
- from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
- from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
- from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
- from cache_dit.cache_factory.patch_functors.functor_chroma import (
-     ChromaPatchFunctor,
- )
- from cache_dit.cache_factory.patch_functors.functor_hidream import (
-     HiDreamPatchFunctor,
- )
- from cache_dit.cache_factory.patch_functors.functor_hunyuan_dit import (
-     HunyuanDiTPatchFunctor,
- )