cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
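
The headline change in this release is the package rename from `cache_dit.cache_factory` to `cache_dit.caching` (and `custom_ops` to `kernels`), visible throughout the renames above. A minimal migration sketch for downstream imports, assuming only the module paths changed and the exported names stayed the same, which is what the renamed files above suggest:

```py
# cache-dit 0.3.2 (old paths)
# from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
# from cache_dit.cache_factory.patch_functors import FluxPatchFunctor

# cache-dit 1.0.14 (new paths)
from cache_dit.caching.cache_contexts import BasicCacheConfig
from cache_dit.caching.patch_functors import FluxPatchFunctor
```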
cache_dit/caching/cache_interface.py (new file)
@@ -0,0 +1,358 @@
+ import torch
+ from typing import Any, Tuple, List, Union, Optional
+ from diffusers import DiffusionPipeline, ModelMixin
+ from cache_dit.caching.cache_types import CacheType
+ from cache_dit.caching.block_adapters import BlockAdapter
+ from cache_dit.caching.block_adapters import BlockAdapterRegistry
+ from cache_dit.caching.cache_adapters import CachedAdapter
+ from cache_dit.caching.cache_contexts import BasicCacheConfig
+ from cache_dit.caching.cache_contexts import DBCacheConfig
+ from cache_dit.caching.cache_contexts import DBPruneConfig
+ from cache_dit.caching.cache_contexts import CalibratorConfig
+ from cache_dit.caching.params_modifier import ParamsModifier
+ from cache_dit.parallelism import ParallelismConfig
+ from cache_dit.parallelism import enable_parallelism
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ def enable_cache(
+     # DiffusionPipeline or BlockAdapter
+     pipe_or_adapter: Union[
+         DiffusionPipeline,
+         BlockAdapter,
+         # Transformer-only
+         torch.nn.Module,
+         ModelMixin,
+     ],
+     # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
+     cache_config: Optional[
+         Union[
+             BasicCacheConfig,
+             DBCacheConfig,
+             DBPruneConfig,
+         ]
+     ] = None,
+     # Calibrator config: TaylorSeerCalibratorConfig, etc.
+     calibrator_config: Optional[CalibratorConfig] = None,
+     # Modify cache context params for specific blocks.
+     params_modifiers: Optional[
+         Union[
+             ParamsModifier,
+             List[ParamsModifier],
+             List[List[ParamsModifier]],
+         ]
+     ] = None,
+     # Config for Parallelism
+     parallelism_config: Optional[ParallelismConfig] = None,
+     # Other cache context kwargs: deprecated cache kwargs
+     **kwargs,
+ ) -> Union[
+     DiffusionPipeline,
+     # Transformer-only
+     torch.nn.Module,
+     ModelMixin,
+     BlockAdapter,
+ ]:
+     r"""
+     The `enable_cache` function serves as a unified caching interface designed to optimize the performance
+     of diffusion transformer models by implementing an intelligent caching mechanism known as `DBCache`.
+     This API is engineered to be compatible with nearly `all` diffusion transformer architectures that
+     feature transformer blocks adhering to standard input-output patterns, eliminating the need for
+     architecture-specific modifications.
+
+     By strategically caching intermediate outputs of transformer blocks during the diffusion process,
+     `DBCache` significantly reduces redundant computations without compromising generation quality.
+     The caching mechanism works by tracking residual differences between consecutive steps, allowing
+     the model to reuse previously computed features when these differences fall below a configurable
+     threshold. This approach maintains a balance between computational efficiency and output precision.
+
+     The default configuration (`F8B0, 8 warmup steps, unlimited cached steps`) is carefully tuned to
+     provide an optimal tradeoff for most common use cases. The "F8B0" configuration indicates that
+     the first 8 transformer blocks are used to compute stable feature differences, while no final
+     blocks are employed for additional fusion. The warmup phase ensures the model establishes
+     sufficient feature representation before caching begins, preventing potential degradation of
+     output quality.
+
+     This function seamlessly integrates with both standard diffusion pipelines and custom block
+     adapters, making it versatile for various deployment scenarios, from research prototyping to
+     production environments where inference speed is critical. By abstracting the complexity of
+     caching logic behind a simple interface, it enables developers to enhance model performance
+     with minimal code changes.
+
+     Args:
+         pipe_or_adapter (`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
+             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
+             For example: cache_dit.enable_cache(FluxPipeline(...)).
+
+         cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
+             Basic DBCache config for the cache context, defaults to BasicCacheConfig(). The configurable params are listed below:
+             Fn_compute_blocks (`int`, *required*, defaults to 8):
+                 Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information at time step t,
+                 enabling the calculation of a more stable L1 difference and delivering more accurate information
+                 to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+                 for more details of DBCache.
+             Bn_compute_blocks (`int`, *required*, defaults to 0):
+                 Further fuses approximate information in the **last n** Transformer blocks to enhance
+                 prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+                 that use the residual cache.
+             residual_diff_threshold (`float`, *required*, defaults to 0.08):
+                 The value of the residual diff threshold; a higher value leads to faster performance at the
+                 cost of lower precision.
+             max_warmup_steps (`int`, *required*, defaults to 8):
+                 DBCache does not apply the caching strategy when the number of running steps is less than
+                 or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+             warmup_interval (`int`, *required*, defaults to 1):
+                 Skip interval in warmup steps; e.g., when warmup_interval is 2, only steps 0, 2, 4, ...
+                 of the warmup phase will be computed, the others will use the dynamic cache.
+             max_cached_steps (`int`, *required*, defaults to -1):
+                 DBCache disables the caching strategy when the previous cached steps exceed this value to
+                 prevent precision degradation.
+             max_continuous_cached_steps (`int`, *required*, defaults to -1):
+                 DBCache disables the caching strategy when the previous continuous cached steps exceed this value to
+                 prevent precision degradation.
+             enable_separate_cfg (`bool`, *required*, defaults to None):
+                 Whether to do separate CFG or not, such as Wan 2.1 and Qwen-Image. For models that fuse CFG
+                 and non-CFG into a single forward step, enable_separate_cfg should be set to False, for example:
+                 CogVideoX, HunyuanVideo, Mochi, etc.
+             cfg_compute_first (`bool`, *required*, defaults to False):
+                 Whether to compute the CFG forward first; default is False, meaning:
+                 0, 2, 4, ... -> non-CFG step;
+                 1, 3, 5, ... -> CFG step.
+             cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+                 Whether to compute separate difference values for CFG and non-CFG steps, default is True.
+                 If False, we will use the computed difference from the current non-CFG transformer step
+                 for the current CFG step.
+             num_inference_steps (`int`, *optional*, defaults to None):
+                 num_inference_steps for the DiffusionPipeline, used to adjust some internal settings
+                 for better caching performance. For example, we will refresh the cache once the
+                 executed steps exceed num_inference_steps if num_inference_steps is provided.
+
+         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
+             Config for the calibrator. If calibrator_config is not None, it means the user wants to use DBCache
+             with a specific calibrator, such as taylorseer, foca, and so on.
+
+         params_modifiers ('ParamsModifier', *optional*, defaults to None):
+             Modify cache context params for specific blocks. The configurable params are listed below:
+             cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
+                 The same as the 'cache_config' param in the cache_dit.enable_cache() interface.
+             calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
+                 The same as the 'calibrator_config' param in the cache_dit.enable_cache() interface.
+             **kwargs (`dict`, *optional*, defaults to {}):
+                 The same as the 'kwargs' param in the cache_dit.enable_cache() interface.
+
+         parallelism_config (`ParallelismConfig`, *optional*, defaults to None):
+             Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
+             parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
+             for more details of ParallelismConfig.
+             backend (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+                 Parallelism backend; currently only NATIVE_DIFFUSER and NATIVE_PYTORCH are supported.
+                 For context parallelism, only the NATIVE_DIFFUSER backend is supported; for tensor parallelism,
+                 only the NATIVE_PYTORCH backend is supported.
+             ulysses_size (`int`, *optional*, defaults to None):
+                 The size of the Ulysses cluster. If ulysses_size is not None, enable Ulysses-style parallelism.
+                 This setting is only valid when backend is NATIVE_DIFFUSER.
+             ring_size (`int`, *optional*, defaults to None):
+                 The size of the ring for ring parallelism. If ring_size is not None, enable ring attention.
+                 This setting is only valid when backend is NATIVE_DIFFUSER.
+             tp_size (`int`, *optional*, defaults to None):
+                 The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+                 This setting is only valid when backend is NATIVE_PYTORCH.
+             parallel_kwargs (`dict`, *optional*, defaults to {}):
+                 Additional kwargs for parallelism backends. For example, for the NATIVE_DIFFUSER backend,
+                 it can include `cp_plan` and `attention_backend` arguments for Context Parallelism.
+
+         kwargs (`dict`, *optional*, defaults to {}):
+             Other cache context kwargs; please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_context.py
+             for more details.
+
+     Examples:
+     ```py
+     >>> import cache_dit
+     >>> from diffusers import DiffusionPipeline
+     >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")  # Can be any diffusion pipeline
+     >>> cache_dit.enable_cache(pipe)  # One-line code with default cache options.
+     >>> output = pipe(...)  # Just call the pipe as normal.
+     >>> stats = cache_dit.summary(pipe)  # Then, get the summary of cache acceleration stats.
+     >>> cache_dit.disable_cache(pipe)  # Disable cache and run the original pipe.
+     """
+     # Precheck for compatibility of different configurations
+     if cache_config is None:
+         if parallelism_config is None:
+             # Set default cache config only when parallelism is not enabled
+             logger.info("cache_config is None, using default DBCacheConfig")
+             cache_config = DBCacheConfig()
+         else:
+             # Allow empty cache_config when parallelism is enabled
+             logger.warning(
+                 "Parallelism is enabled and cache_config is None. Please manually "
+                 "set cache_config to avoid potential compatibility issues. "
+                 "Otherwise, cache will not be enabled."
+             )
+
+     # Collect cache context kwargs
+     context_kwargs = {}
+     if (cache_type := context_kwargs.get("cache_type", None)) is not None:
+         if cache_type == CacheType.NONE:
+             return pipe_or_adapter
+
+     # NOTE: Deprecated cache config params. These parameters are now retained
+     # for backward compatibility but will be removed in the future.
+     deprecated_kwargs = {
+         "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
+         "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
+         "max_warmup_steps": kwargs.get("max_warmup_steps", None),
+         "max_cached_steps": kwargs.get("max_cached_steps", None),
+         "max_continuous_cached_steps": kwargs.get(
+             "max_continuous_cached_steps", None
+         ),
+         "residual_diff_threshold": kwargs.get("residual_diff_threshold", None),
+         "enable_separate_cfg": kwargs.get("enable_separate_cfg", None),
+         "cfg_compute_first": kwargs.get("cfg_compute_first", None),
+         "cfg_diff_compute_separate": kwargs.get(
+             "cfg_diff_compute_separate", None
+         ),
+     }
+
+     deprecated_kwargs = {
+         k: v for k, v in deprecated_kwargs.items() if v is not None
+     }
+
+     if deprecated_kwargs:
+         logger.warning(
+             "Manually setting up the DBCache context without BasicCacheConfig is "
+             "deprecated and will be removed in the future, please use the "
+             "`cache_config` parameter instead!"
+         )
+         if cache_config is not None:
+             cache_config.update(**deprecated_kwargs)
+         else:
+             cache_config = BasicCacheConfig(**deprecated_kwargs)
+
+     if cache_config is not None:
+         context_kwargs["cache_config"] = cache_config
+
+     # NOTE: Deprecated taylorseer params. These parameters are now retained
+     # for backward compatibility but will be removed in the future.
+     if cache_config is not None and (
+         kwargs.get("enable_taylorseer", None) is not None
+         or kwargs.get("enable_encoder_taylorseer", None) is not None
+     ):
+         logger.warning(
+             "Manually setting up the TaylorSeer calibrator without TaylorSeerCalibratorConfig is "
+             "deprecated and will be removed in the future, please use the "
+             "`calibrator_config` parameter instead!"
+         )
+         from cache_dit.caching.cache_contexts.calibrators import (
+             TaylorSeerCalibratorConfig,
+         )
+
+         calibrator_config = TaylorSeerCalibratorConfig(
+             enable_calibrator=kwargs.get("enable_taylorseer"),
+             enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer"),
+             calibrator_cache_type=kwargs.get(
+                 "taylorseer_cache_type", "residual"
+             ),
+             taylorseer_order=kwargs.get("taylorseer_order", 1),
+         )
+
+     if calibrator_config is not None:
+         context_kwargs["calibrator_config"] = calibrator_config
+
+     if params_modifiers is not None:
+         context_kwargs["params_modifiers"] = params_modifiers
+
+     if cache_config is not None:
+         if isinstance(
+             pipe_or_adapter,
+             (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin),
+         ):
+             pipe_or_adapter = CachedAdapter.apply(
+                 pipe_or_adapter,
+                 **context_kwargs,
+             )
+         else:
+             raise ValueError(
+                 f"type: {type(pipe_or_adapter)} is not valid, "
+                 "please pass a DiffusionPipeline or BlockAdapter "
+                 "as the first positional param: pipe_or_adapter"
+             )
+     else:
+         logger.warning(
+             "cache_config is None, skip enabling cache for "
+             f"{pipe_or_adapter.__class__.__name__}."
+         )
+
+     # NOTE: Users should always enable parallelism after applying
+     # cache to avoid hook conflicts.
+     if parallelism_config is not None:
+         assert isinstance(
+             parallelism_config, ParallelismConfig
+         ), "parallelism_config should be of type ParallelismConfig."
+
+         transformers = []
+         if isinstance(pipe_or_adapter, DiffusionPipeline):
+             adapter = BlockAdapterRegistry.get_adapter(pipe_or_adapter)
+             if adapter is None:
+                 assert hasattr(pipe_or_adapter, "transformer"), (
+                     "The given DiffusionPipeline does not have "
+                     "a 'transformer' attribute, cannot enable "
+                     "parallelism."
+                 )
+                 transformers = [pipe_or_adapter.transformer]
+             else:
+                 adapter = BlockAdapter.normalize(adapter, unique=False)
+                 transformers = BlockAdapter.flatten(adapter.transformer)
+         else:
+             if not BlockAdapter.is_normalized(pipe_or_adapter):
+                 pipe_or_adapter = BlockAdapter.normalize(
+                     pipe_or_adapter, unique=False
+                 )
+             transformers = BlockAdapter.flatten(pipe_or_adapter.transformer)
+
+         if len(transformers) == 0:
+             logger.warning(
+                 "No transformer is detected in the "
+                 "BlockAdapter, skip enabling parallelism."
+             )
+             return pipe_or_adapter
+
+         if len(transformers) > 1:
+             logger.warning(
+                 "Multiple transformers are detected in the "
+                 "BlockAdapter, all transformers will be "
+                 "enabled for parallelism."
+             )
+         for i, transformer in enumerate(transformers):
+             # Enable parallelism for the transformer in place
+             transformers[i] = enable_parallelism(
+                 transformer, parallelism_config
+             )
+     return pipe_or_adapter
+
+
+ def disable_cache(
+     pipe_or_adapter: Union[
+         DiffusionPipeline,
+         BlockAdapter,
+     ],
+ ):
+     CachedAdapter.maybe_release_hooks(pipe_or_adapter)
+     logger.warning(
+         f"Cache Acceleration is disabled for: "
+         f"{pipe_or_adapter.__class__.__name__}."
+     )
+
+
+ def supported_pipelines(
+     **kwargs,
+ ) -> Tuple[int, List[str]]:
+     return BlockAdapterRegistry.supported_pipelines(**kwargs)
+
+
+ def get_adapter(
+     pipe: DiffusionPipeline | str | Any,
+ ) -> BlockAdapter:
+     return BlockAdapterRegistry.get_adapter(pipe)
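
Beyond the one-liner in the docstring above, the new signature groups options into config objects. A hedged usage sketch, assuming `DBCacheConfig` accepts the `BasicCacheConfig` fields documented above and that `ParallelismConfig(ulysses_size=...)` is run under a distributed launcher (e.g. torchrun); only constructor arguments named elsewhere in this diff are used:

```py
import cache_dit
from diffusers import DiffusionPipeline
from cache_dit.caching.cache_contexts import DBCacheConfig
from cache_dit.caching.cache_contexts.calibrators import TaylorSeerCalibratorConfig
from cache_dit.parallelism import ParallelismConfig

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

cache_dit.enable_cache(
    pipe,
    # F8B0-style DBCache: first 8 blocks probe the residual diff, no back blocks.
    cache_config=DBCacheConfig(
        Fn_compute_blocks=8,
        Bn_compute_blocks=0,
        residual_diff_threshold=0.08,
        max_warmup_steps=8,
    ),
    # Optional TaylorSeer calibrator (same kwargs the deprecation shim passes).
    calibrator_config=TaylorSeerCalibratorConfig(
        enable_calibrator=True,
        taylorseer_order=1,
    ),
    # Optional: Ulysses-style context parallelism (NATIVE_DIFFUSER backend).
    parallelism_config=ParallelismConfig(ulysses_size=2),
)

output = pipe("a cat holding a sign")  # call the pipe as normal
print(cache_dit.summary(pipe))         # cache acceleration stats
cache_dit.disable_cache(pipe)          # restore the original pipe
```

Note the ordering baked into the implementation above: caching hooks are applied first, then `enable_parallelism` is applied to each detected transformer.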
cache_dit/caching/cache_types.py
@@ -6,7 +6,8 @@ logger = init_logger(__name__)

  class CacheType(Enum):
      NONE = "NONE"
-     DBCache = "Dual_Block_Cache"
+     DBCache = "DBCache"  # "Dual_Block_Cache"
+     DBPrune = "DBPrune"  # "Dynamic_Block_Prune"

      @staticmethod
      def type(type_hint: "CacheType | str") -> "CacheType":
@@ -14,6 +15,9 @@ class CacheType(Enum):
              return type_hint
          return cache_type(type_hint)

+     def __str__(self) -> str:
+         return self.value
+


  def cache_type(type_hint: "CacheType | str") -> "CacheType":
@@ -21,7 +25,6 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":

      if isinstance(type_hint, CacheType):
          return type_hint
-
      elif type_hint.upper() in (
          "DUAL_BLOCK_CACHE",
          "DB_CACHE",
@@ -29,6 +32,20 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
          "DB",
      ):
          return CacheType.DBCache
+     elif type_hint.upper() in (
+         "DYNAMIC_BLOCK_PRUNE",
+         "DB_PRUNE",
+         "DBPRUNE",
+         "DBP",
+     ):
+         return CacheType.DBPrune
+     elif type_hint.upper() in (
+         "NONE",
+         "NO_CACHE",
+         "NOCACHE",
+         "NC",
+     ):
+         return CacheType.NONE
      return CacheType.NONE


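The new aliases make the string-to-enum mapping above easy to exercise; a small sketch of how `cache_type()` now resolves user-facing strings:

```py
from cache_dit.caching.cache_types import CacheType, cache_type

assert cache_type("db_cache") is CacheType.DBCache            # existing aliases still work
assert cache_type("dynamic_block_prune") is CacheType.DBPrune  # new DBPrune aliases
assert cache_type("no_cache") is CacheType.NONE                # explicit NONE aliases
assert str(CacheType.DBPrune) == "DBPrune"                     # __str__ now returns the short value
```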
cache_dit/caching/forward_pattern.py
@@ -20,33 +20,33 @@ class ForwardPattern(Enum):

      Pattern_0 = (
          True, # Return_H_First
-         False, # Return_H_Only
-         False, # Forward_H_only
+         False, # Return_H_Only
+         False, # Forward_H_only
          ("hidden_states", "encoder_hidden_states"), # In
          ("hidden_states", "encoder_hidden_states"), # Out
          True, # Supported
      )

      Pattern_1 = (
-         False, # Return_H_First
-         False, # Return_H_Only
-         False, # Forward_H_only
+         False, # Return_H_First
+         False, # Return_H_Only
+         False, # Forward_H_only
          ("hidden_states", "encoder_hidden_states"), # In
          ("encoder_hidden_states", "hidden_states"), # Out
          True, # Supported
      )

      Pattern_2 = (
-         False, # Return_H_First
+         False, # Return_H_First
          True, # Return_H_Only
-         False, # Forward_H_only
+         False, # Forward_H_only
          ("hidden_states", "encoder_hidden_states"), # In
-         ("hidden_states",), # Out
+         ("hidden_states",), # Out
          True, # Supported
      )

      Pattern_3 = (
-         False, # Return_H_First
+         False, # Return_H_First
          True, # Return_H_Only
          True, # Forward_H_only
          ("hidden_states",), # In
@@ -56,18 +56,18 @@ class ForwardPattern(Enum):

      Pattern_4 = (
          True, # Return_H_First
-         False, # Return_H_Only
+         False, # Return_H_Only
          True, # Forward_H_only
-         ("hidden_states",), # In
+         ("hidden_states",), # In
          ("hidden_states", "encoder_hidden_states"), # Out
          True, # Supported
      )

      Pattern_5 = (
-         False, # Return_H_First
-         False, # Return_H_Only
+         False, # Return_H_First
+         False, # Return_H_Only
          True, # Forward_H_only
-         ("hidden_states",), # In
+         ("hidden_states",), # In
          ("encoder_hidden_states", "hidden_states"), # Out
          True, # Supported
      )
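
The hunks above are whitespace-only, but they are a good reminder of what each `Pattern_N` encodes: a 6-tuple of (Return_H_First, Return_H_Only, Forward_H_only, In, Out, Supported), per the inline comments. A hedged sketch of reading that encoding; unpacking `.value` assumes the enum keeps the plain tuple as its value rather than exposing named attributes:

```py
from cache_dit.caching.forward_pattern import ForwardPattern

# Pattern_1 describes a block whose forward takes
# (hidden_states, encoder_hidden_states) and returns
# (encoder_hidden_states, hidden_states).
(
    return_h_first,
    return_h_only,
    forward_h_only,
    inputs,
    outputs,
    supported,
) = ForwardPattern.Pattern_1.value

print(inputs, "->", outputs)
# ('hidden_states', 'encoder_hidden_states') -> ('encoder_hidden_states', 'hidden_states')
```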
cache_dit/caching/params_modifier.py
@@ -1,7 +1,7 @@
  from typing import Optional

- from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
- from cache_dit.cache_factory.cache_contexts import CalibratorConfig
+ from cache_dit.caching.cache_contexts import BasicCacheConfig
+ from cache_dit.caching.cache_contexts import CalibratorConfig

  from cache_dit.logger import init_logger

@@ -11,7 +11,7 @@ logger = init_logger(__name__)
  class ParamsModifier:
      def __init__(
          self,
-         # Basic DBCache config: BasicCacheConfig
+         # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
          cache_config: BasicCacheConfig = None,
          # Calibrator config: TaylorSeerCalibratorConfig, etc.
          calibrator_config: Optional[CalibratorConfig] = None,
@@ -22,7 +22,7 @@ class ParamsModifier:

          # WARNING: Deprecated cache config params. These parameters are now retained
          # for backward compatibility but will be removed in the future.
-         deprecated_cache_kwargs = {
+         deprecated_kwargs = {
              "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
              "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
              "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -40,20 +40,20 @@ class ParamsModifier:
              ),
          }

-         deprecated_cache_kwargs = {
-             k: v for k, v in deprecated_cache_kwargs.items() if v is not None
+         deprecated_kwargs = {
+             k: v for k, v in deprecated_kwargs.items() if v is not None
          }

-         if deprecated_cache_kwargs:
+         if deprecated_kwargs:
              logger.warning(
                  "Manually settup DBCache context without BasicCacheConfig is "
                  "deprecated and will be removed in the future, please use "
                  "`cache_config` parameter instead!"
              )
              if cache_config is not None:
-                 cache_config.update(**deprecated_cache_kwargs)
+                 cache_config.update(**deprecated_kwargs)
              else:
-                 cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
+                 cache_config = BasicCacheConfig(**deprecated_kwargs)

          if cache_config is not None:
              self._context_kwargs["cache_config"] = cache_config
@@ -68,7 +68,7 @@
                  "deprecated and will be removed in the future, please use "
                  "`calibrator_config` parameter instead!"
              )
-             from cache_dit.cache_factory.cache_contexts.calibrators import (
+             from cache_dit.caching.cache_contexts.calibrators import (
                  TaylorSeerCalibratorConfig,
              )

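`ParamsModifier` now follows the same convention as `enable_cache`: pass a config object, or fall back to raw kwargs and accept a deprecation warning. A small sketch of both paths, assuming `BasicCacheConfig` accepts the field names shown in the deprecated-kwargs table above:

```py
from cache_dit.caching.params_modifier import ParamsModifier
from cache_dit.caching.cache_contexts import BasicCacheConfig

# Preferred: explicit config object.
modifier = ParamsModifier(
    cache_config=BasicCacheConfig(residual_diff_threshold=0.12),
)

# Deprecated but still accepted: raw kwargs are collected into a
# BasicCacheConfig and a deprecation warning is logged.
legacy_modifier = ParamsModifier(residual_diff_threshold=0.12)
```

These modifiers are meant to be passed through the `params_modifiers` argument of `enable_cache` to override cache settings for specific blocks.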
cache_dit/caching/patch_functors/__init__.py (new file)
@@ -0,0 +1,15 @@
+ from cache_dit.caching.patch_functors.functor_base import PatchFunctor
+ from cache_dit.caching.patch_functors.functor_dit import DiTPatchFunctor
+ from cache_dit.caching.patch_functors.functor_flux import FluxPatchFunctor
+ from cache_dit.caching.patch_functors.functor_chroma import (
+     ChromaPatchFunctor,
+ )
+ from cache_dit.caching.patch_functors.functor_hidream import (
+     HiDreamPatchFunctor,
+ )
+ from cache_dit.caching.patch_functors.functor_hunyuan_dit import (
+     HunyuanDiTPatchFunctor,
+ )
+ from cache_dit.caching.patch_functors.functor_qwen_image_controlnet import (
+     QwenImageControlNetPatchFunctor,
+ )
cache_dit/caching/patch_functors/functor_chroma.py
@@ -13,7 +13,7 @@ from diffusers.utils import (
      unscale_lora_layers,
  )

- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
      PatchFunctor,
  )
  from cache_dit.logger import init_logger
cache_dit/caching/patch_functors/functor_dit.py
@@ -6,7 +6,7 @@ from diffusers.models.transformers.dit_transformer_2d import (
      DiTTransformer2DModel,
      Transformer2DModelOutput,
  )
- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
      PatchFunctor,
  )
  from cache_dit.logger import init_logger
cache_dit/caching/patch_functors/functor_flux.py
@@ -14,7 +14,7 @@ from diffusers.utils import (
      unscale_lora_layers,
  )

- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
      PatchFunctor,
  )
  from cache_dit.logger import init_logger
cache_dit/caching/patch_functors/functor_hidream.py
@@ -13,7 +13,7 @@ from diffusers.utils import (
      scale_lora_layers,
      unscale_lora_layers,
  )
- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
      PatchFunctor,
  )
  from cache_dit.logger import init_logger
@@ -362,9 +362,7 @@ def __patch_transformer_forward__(
          )
      if hidden_states_masks is not None:
          # NOTE: Patched
-         cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[
-             self.double_stream_blocks[-1].block._block_id
-         ]
+         cur_llama31_encoder_hidden_states = llama31_encoder_hidden_states[0]
          encoder_attention_mask_ones = torch.ones(
              (
                  batch_size,
cache_dit/caching/patch_functors/functor_hunyuan_dit.py
@@ -5,7 +5,7 @@ from diffusers.models.transformers.hunyuan_transformer_2d import (
      HunyuanDiTBlock,
      Transformer2DModelOutput,
  )
- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
      PatchFunctor,
  )
  from cache_dit.logger import init_logger