cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,17 @@
+ import torch
  from typing import Any, Tuple, List, Union, Optional
- from diffusers import DiffusionPipeline
- from cache_dit.cache_factory.cache_types import CacheType
- from cache_dit.cache_factory.block_adapters import BlockAdapter
- from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
- from cache_dit.cache_factory.cache_adapters import CachedAdapter
- from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
- from cache_dit.cache_factory.cache_contexts import CalibratorConfig
- from cache_dit.cache_factory.params_modifier import ParamsModifier
+ from diffusers import DiffusionPipeline, ModelMixin
+ from cache_dit.caching.cache_types import CacheType
+ from cache_dit.caching.block_adapters import BlockAdapter
+ from cache_dit.caching.block_adapters import BlockAdapterRegistry
+ from cache_dit.caching.cache_adapters import CachedAdapter
+ from cache_dit.caching.cache_contexts import BasicCacheConfig
+ from cache_dit.caching.cache_contexts import DBCacheConfig
+ from cache_dit.caching.cache_contexts import DBPruneConfig
+ from cache_dit.caching.cache_contexts import CalibratorConfig
+ from cache_dit.caching.params_modifier import ParamsModifier
+ from cache_dit.parallelism import ParallelismConfig
+ from cache_dit.parallelism import enable_parallelism

  from cache_dit.logger import init_logger
@@ -18,9 +23,18 @@ def enable_cache(
  pipe_or_adapter: Union[
  DiffusionPipeline,
  BlockAdapter,
+ # Transformer-only
+ torch.nn.Module,
+ ModelMixin,
  ],
- # Basic DBCache config: BasicCacheConfig
- cache_config: BasicCacheConfig = BasicCacheConfig(),
+ # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
+ cache_config: Optional[
+ Union[
+ BasicCacheConfig,
+ DBCacheConfig,
+ DBPruneConfig,
+ ]
+ ] = None,
  # Calibrator config: TaylorSeerCalibratorConfig, etc.
  calibrator_config: Optional[CalibratorConfig] = None,
  # Modify cache context params for specific blocks.
@@ -31,10 +45,15 @@ def enable_cache(
  List[List[ParamsModifier]],
  ]
  ] = None,
+ # Config for Parallelism
+ parallelism_config: Optional[ParallelismConfig] = None,
  # Other cache context kwargs: Deprecated cache kwargs
  **kwargs,
  ) -> Union[
  DiffusionPipeline,
+ # Transformer-only
+ torch.nn.Module,
+ ModelMixin,
  BlockAdapter,
  ]:
  r"""
@@ -64,10 +83,9 @@ def enable_cache(
  with minimal code changes.

  Args:
- pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
+ pipe_or_adapter (`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
  The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
- For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
- for the usgae of BlockAdapter.
+ For example: cache_dit.enable_cache(FluxPipeline(...)).

  cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
  Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params listed belows:
@@ -107,6 +125,10 @@ def enable_cache(
  Whether to compute separate difference values for CFG and non-CFG steps, default is True.
  If False, we will use the computed difference from the current non-CFG transformer step
  for the current CFG step.
+ num_inference_steps (`int`, *optional*, defaults to None):
+ num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+ for better caching performance. For example, we will refresh the cache once the
+ executed steps exceed num_inference_steps if num_inference_steps is provided.

  calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
  Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
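The `cache_config` parameter now accepts any of the config classes listed in the signature above. As a quick illustration, a minimal sketch of constructing one (the field names are taken from the deprecated kwargs handled later in this diff and from the num_inference_steps entry above; the exact constructor signature and the values shown are assumptions for illustration):

    from cache_dit.caching.cache_contexts import BasicCacheConfig

    # Hypothetical values; only the field names are grounded in this diff.
    cache_config = BasicCacheConfig(
        Fn_compute_blocks=8,
        Bn_compute_blocks=0,
        max_warmup_steps=4,
        residual_diff_threshold=0.08,
        num_inference_steps=28,
    )

DBCacheConfig and DBPruneConfig are accepted in the same position, per the Union annotation in the signature hunk.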
@@ -121,8 +143,29 @@ def enable_cache(
  **kwargs: (`dict`, *optional*, defaults to {}):
  The same as 'kwargs' param in cache_dit.enable_cache() interface.

+ parallelism_config (`ParallelismConfig`, *optional*, defaults to None):
+ Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
+ parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
+ for more details of ParallelismConfig.
+ backend: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+ Parallelism backend, currently only NATIVE_DIFFUSER and NVTIVE_PYTORCH are supported.
+ For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
+ only NATIVE_PYTORCH backend is supported.
+ ulysses_size: (`int`, *optional*, defaults to None):
+ The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+ This setting is only valid when backend is NATIVE_DIFFUSER.
+ ring_size: (`int`, *optional*, defaults to None):
+ The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+ This setting is only valid when backend is NATIVE_DIFFUSER.
+ tp_size: (`int`, *optional*, defaults to None):
+ The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+ This setting is only valid when backend is NATIVE_PYTORCH.
+ parallel_kwargs: (`dict`, *optional*, defaults to {}):
+ Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
+ it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.
+
  kwargs (`dict`, *optional*, defaults to {})
- Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
+ Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_context.py
  for more details.

  Examples:
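For reference, a hedged sketch of a ParallelismConfig built from the fields documented above (the constructor signature is assumed from this field list; the sizes are illustrative):

    from cache_dit.parallelism import ParallelismBackend, ParallelismConfig

    # Context parallelism via the diffusers-native backend (assumed constructor).
    parallelism_config = ParallelismConfig(
        backend=ParallelismBackend.NATIVE_DIFFUSER,
        ulysses_size=2,      # Ulysses-style context parallelism across 2 ranks
        parallel_kwargs={},  # e.g. cp_plan / attention_backend, per the docstring
    )

Tensor parallelism would instead set backend=ParallelismBackend.NATIVE_PYTORCH together with tp_size, per the docstring above.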
@@ -135,15 +178,29 @@ def enable_cache(
  >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
  >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
  """
+ # Precheck for compatibility of different configurations
+ if cache_config is None:
+ if parallelism_config is None:
+ # Set default cache config only when parallelism is not enabled
+ logger.info("cache_config is None, using default DBCacheConfig")
+ cache_config = DBCacheConfig()
+ else:
+ # Allow empty cache_config when parallelism is enabled
+ logger.warning(
+ "Parallelism is enabled and cache_config is None. Please manually "
+ "set cache_config to avoid potential compatibility issues. "
+ "Otherwise, cache will not be enabled."
+ )
+
  # Collect cache context kwargs
- cache_context_kwargs = {}
- if (cache_type := cache_context_kwargs.pop("cache_type", None)) is not None:
+ context_kwargs = {}
+ if (cache_type := context_kwargs.get("cache_type", None)) is not None:
  if cache_type == CacheType.NONE:
  return pipe_or_adapter

- # WARNING: Deprecated cache config params. These parameters are now retained
+ # NOTE: Deprecated cache config params. These parameters are now retained
  # for backward compatibility but will be removed in the future.
- deprecated_cache_kwargs = {
+ deprecated_kwargs = {
  "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
  "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
  "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -159,27 +216,27 @@ def enable_cache(
  ),
  }

- deprecated_cache_kwargs = {
- k: v for k, v in deprecated_cache_kwargs.items() if v is not None
+ deprecated_kwargs = {
+ k: v for k, v in deprecated_kwargs.items() if v is not None
  }

- if deprecated_cache_kwargs:
+ if deprecated_kwargs:
  logger.warning(
  "Manually settup DBCache context without BasicCacheConfig is "
  "deprecated and will be removed in the future, please use "
  "`cache_config` parameter instead!"
  )
  if cache_config is not None:
- cache_config.update(**deprecated_cache_kwargs)
+ cache_config.update(**deprecated_kwargs)
  else:
- cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
+ cache_config = BasicCacheConfig(**deprecated_kwargs)

  if cache_config is not None:
- cache_context_kwargs["cache_config"] = cache_config
+ context_kwargs["cache_config"] = cache_config

- # WARNING: Deprecated taylorseer params. These parameters are now retained
+ # NOTE: Deprecated taylorseer params. These parameters are now retained
  # for backward compatibility but will be removed in the future.
- if (
+ if cache_config is not None and (
  kwargs.get("enable_taylorseer", None) is not None
  or kwargs.get("enable_encoder_taylorseer", None) is not None
  ):
@@ -188,7 +245,7 @@ def enable_cache(
  "deprecated and will be removed in the future, please use "
  "`calibrator_config` parameter instead!"
  )
- from cache_dit.cache_factory.cache_contexts.calibrators import (
+ from cache_dit.caching.cache_contexts.calibrators import (
  TaylorSeerCalibratorConfig,
  )
@@ -202,23 +259,79 @@ def enable_cache(
  )

  if calibrator_config is not None:
- cache_context_kwargs["calibrator_config"] = calibrator_config
+ context_kwargs["calibrator_config"] = calibrator_config

  if params_modifiers is not None:
- cache_context_kwargs["params_modifiers"] = params_modifiers
+ context_kwargs["params_modifiers"] = params_modifiers

- if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
- return CachedAdapter.apply(
+ if cache_config is not None:
+ if isinstance(
  pipe_or_adapter,
- **cache_context_kwargs,
- )
+ (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin),
+ ):
+ pipe_or_adapter = CachedAdapter.apply(
+ pipe_or_adapter,
+ **context_kwargs,
+ )
+ else:
+ raise ValueError(
+ f"type: {type(pipe_or_adapter)} is not valid, "
+ "Please pass DiffusionPipeline or BlockAdapter"
+ "for the 1's position param: pipe_or_adapter"
+ )
  else:
- raise ValueError(
- f"type: {type(pipe_or_adapter)} is not valid, "
- "Please pass DiffusionPipeline or BlockAdapter"
- "for the 1's position param: pipe_or_adapter"
+ logger.warning(
+ "cache_config is None, skip enabling cache for "
+ f"{pipe_or_adapter.__class__.__name__}."
  )

+ # NOTE: Users should always enable parallelism after applying
+ # cache to avoid hooks conflict.
+ if parallelism_config is not None:
+ assert isinstance(
+ parallelism_config, ParallelismConfig
+ ), "parallelism_config should be of type ParallelismConfig."
+
+ transformers = []
+ if isinstance(pipe_or_adapter, DiffusionPipeline):
+ adapter = BlockAdapterRegistry.get_adapter(pipe_or_adapter)
+ if adapter is None:
+ assert hasattr(pipe_or_adapter, "transformer"), (
+ "The given DiffusionPipeline does not have "
+ "a 'transformer' attribute, cannot enable "
+ "parallelism."
+ )
+ transformers = [pipe_or_adapter.transformer]
+ else:
+ adapter = BlockAdapter.normalize(adapter, unique=False)
+ transformers = BlockAdapter.flatten(adapter.transformer)
+ else:
+ if not BlockAdapter.is_normalized(pipe_or_adapter):
+ pipe_or_adapter = BlockAdapter.normalize(
+ pipe_or_adapter, unique=False
+ )
+ transformers = BlockAdapter.flatten(pipe_or_adapter.transformer)
+
+ if len(transformers) == 0:
+ logger.warning(
+ "No transformer is detected in the "
+ "BlockAdapter, skip enabling parallelism."
+ )
+ return pipe_or_adapter
+
+ if len(transformers) > 1:
+ logger.warning(
+ "Multiple transformers are detected in the "
+ "BlockAdapter, all transfomers will be "
+ "enabled for parallelism."
+ )
+ for i, transformer in enumerate(transformers):
+ # Enable parallelism for the transformer inplace
+ transformers[i] = enable_parallelism(
+ transformer, parallelism_config
+ )
+ return pipe_or_adapter
+

  def disable_cache(
  pipe_or_adapter: Union[
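Putting the new interface together, a hedged end-to-end sketch (the model id and sizes are illustrative, and the constructor details of ParallelismConfig are assumptions; the call pattern follows the docstring examples in this diff):

    import torch
    import cache_dit
    from diffusers import FluxPipeline
    from cache_dit.caching.cache_contexts import DBCacheConfig
    from cache_dit.parallelism import ParallelismBackend, ParallelismConfig

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )

    # Caching is applied first; the optional parallelism config is enabled
    # afterwards, mirroring the ordering enforced inside enable_cache above.
    cache_dit.enable_cache(
        pipe,
        cache_config=DBCacheConfig(),          # or DBPruneConfig() / BasicCacheConfig()
        parallelism_config=ParallelismConfig(  # assumed constructor, illustrative size
            backend=ParallelismBackend.NATIVE_DIFFUSER,
            ulysses_size=2,
        ),
    )

    # A bare transformer (torch.nn.Module / ModelMixin) is now also accepted:
    # cache_dit.enable_cache(pipe.transformer, cache_config=DBCacheConfig())

    # ... run inference ...
    stats = cache_dit.summary(pipe)  # cache acceleration stats, per the docstring example
    cache_dit.disable_cache(pipe)    # restore the original pipeline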
@@ -6,7 +6,8 @@ logger = init_logger(__name__)

  class CacheType(Enum):
  NONE = "NONE"
- DBCache = "Dual_Block_Cache"
+ DBCache = "DBCache" # "Dual_Block_Cache"
+ DBPrune = "DBPrune" # "Dynamic_Block_Prune"

  @staticmethod
  def type(type_hint: "CacheType | str") -> "CacheType":
@@ -14,6 +15,9 @@ class CacheType(Enum):
  return type_hint
  return cache_type(type_hint)

+ def __str__(self) -> str:
+ return self.value
+

  def cache_type(type_hint: "CacheType | str") -> "CacheType":
  if type_hint is None:
@@ -21,7 +25,6 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":

  if isinstance(type_hint, CacheType):
  return type_hint
-
  elif type_hint.upper() in (
  "DUAL_BLOCK_CACHE",
  "DB_CACHE",
@@ -29,6 +32,20 @@
  "DB",
  ):
  return CacheType.DBCache
+ elif type_hint.upper() in (
+ "DYNAMIC_BLOCK_PRUNE",
+ "DB_PRUNE",
+ "DBPRUNE",
+ "DBP",
+ ):
+ return CacheType.DBPrune
+ elif type_hint.upper() in (
+ "NONE",
+ "NO_CACHE",
+ "NOCACHE",
+ "NC",
+ ):
+ return CacheType.NONE
  return CacheType.NONE
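The added branches make the string aliases resolve cleanly; a small sketch exercising them (import path follows the module shown in this diff):

    from cache_dit.caching.cache_types import CacheType, cache_type

    assert cache_type("DB_CACHE") is CacheType.DBCache  # "DUAL_BLOCK_CACHE", "DB", ... also match
    assert cache_type("DBP") is CacheType.DBPrune       # new Dynamic_Block_Prune aliases
    assert cache_type("no_cache") is CacheType.NONE     # matching is case-insensitive via .upper()
    assert str(CacheType.DBPrune) == "DBPrune"          # __str__ now returns the enum value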
@@ -20,33 +20,33 @@ class ForwardPattern(Enum):

  Pattern_0 = (
  True, # Return_H_First
- False, # Return_H_Only
- False, # Forward_H_only
+ False, # Return_H_Only
+ False, # Forward_H_only
  ("hidden_states", "encoder_hidden_states"), # In
  ("hidden_states", "encoder_hidden_states"), # Out
  True, # Supported
  )

  Pattern_1 = (
- False, # Return_H_First
- False, # Return_H_Only
- False, # Forward_H_only
+ False, # Return_H_First
+ False, # Return_H_Only
+ False, # Forward_H_only
  ("hidden_states", "encoder_hidden_states"), # In
  ("encoder_hidden_states", "hidden_states"), # Out
  True, # Supported
  )

  Pattern_2 = (
- False, # Return_H_First
+ False, # Return_H_First
  True, # Return_H_Only
- False, # Forward_H_only
+ False, # Forward_H_only
  ("hidden_states", "encoder_hidden_states"), # In
- ("hidden_states",), # Out
+ ("hidden_states",), # Out
  True, # Supported
  )

  Pattern_3 = (
- False, # Return_H_First
+ False, # Return_H_First
  True, # Return_H_Only
  True, # Forward_H_only
  ("hidden_states",), # In
@@ -56,18 +56,18 @@ class ForwardPattern(Enum):
  Pattern_4 = (
  True, # Return_H_First
- False, # Return_H_Only
+ False, # Return_H_Only
  True, # Forward_H_only
- ("hidden_states",), # In
+ ("hidden_states",), # In
  ("hidden_states", "encoder_hidden_states"), # Out
  True, # Supported
  )

  Pattern_5 = (
- False, # Return_H_First
- False, # Return_H_Only
+ False, # Return_H_First
+ False, # Return_H_Only
  True, # Forward_H_only
- ("hidden_states",), # In
+ ("hidden_states",), # In
  ("encoder_hidden_states", "hidden_states"), # Out
  True, # Supported
  )
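These patterns are plain tuples whose fields follow the inline comments; a small sketch of reading one (assuming the tuple is exposed unchanged as the enum member's .value):

    from cache_dit.caching.forward_pattern import ForwardPattern

    # Field order per the comments above:
    # (Return_H_First, Return_H_Only, Forward_H_only, In, Out, Supported)
    (return_h_first, return_h_only, forward_h_only,
     inputs, outputs, supported) = ForwardPattern.Pattern_1.value

    print(inputs)   # ('hidden_states', 'encoder_hidden_states')
    print(outputs)  # ('encoder_hidden_states', 'hidden_states')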
@@ -1,7 +1,7 @@
  from typing import Optional

- from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
- from cache_dit.cache_factory.cache_contexts import CalibratorConfig
+ from cache_dit.caching.cache_contexts import BasicCacheConfig
+ from cache_dit.caching.cache_contexts import CalibratorConfig

  from cache_dit.logger import init_logger

@@ -11,7 +11,7 @@ logger = init_logger(__name__)
  class ParamsModifier:
  def __init__(
  self,
- # Basic DBCache config: BasicCacheConfig
+ # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
  cache_config: BasicCacheConfig = None,
  # Calibrator config: TaylorSeerCalibratorConfig, etc.
  calibrator_config: Optional[CalibratorConfig] = None,
@@ -22,7 +22,7 @@ class ParamsModifier:

  # WARNING: Deprecated cache config params. These parameters are now retained
  # for backward compatibility but will be removed in the future.
- deprecated_cache_kwargs = {
+ deprecated_kwargs = {
  "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
  "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
  "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -40,20 +40,20 @@ class ParamsModifier:
  ),
  }

- deprecated_cache_kwargs = {
- k: v for k, v in deprecated_cache_kwargs.items() if v is not None
+ deprecated_kwargs = {
+ k: v for k, v in deprecated_kwargs.items() if v is not None
  }

- if deprecated_cache_kwargs:
+ if deprecated_kwargs:
  logger.warning(
  "Manually settup DBCache context without BasicCacheConfig is "
  "deprecated and will be removed in the future, please use "
  "`cache_config` parameter instead!"
  )
  if cache_config is not None:
- cache_config.update(**deprecated_cache_kwargs)
+ cache_config.update(**deprecated_kwargs)
  else:
- cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
+ cache_config = BasicCacheConfig(**deprecated_kwargs)

  if cache_config is not None:
  self._context_kwargs["cache_config"] = cache_config
@@ -68,7 +68,7 @@
  "deprecated and will be removed in the future, please use "
  "`calibrator_config` parameter instead!"
  )
- from cache_dit.cache_factory.cache_contexts.calibrators import (
+ from cache_dit.caching.cache_contexts.calibrators import (
  TaylorSeerCalibratorConfig,
  )
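ParamsModifier mirrors the enable_cache config arguments so individual blocks can override the global cache settings; a hedged sketch (the override value is illustrative, and how modifiers are matched to blocks is defined by the BlockAdapter, which is not shown in this diff):

    from cache_dit.caching.cache_contexts import BasicCacheConfig
    from cache_dit.caching.params_modifier import ParamsModifier

    # Per-block override; accepted shapes are ParamsModifier, List[ParamsModifier]
    # or List[List[ParamsModifier]], per the enable_cache signature above.
    modifier = ParamsModifier(
        cache_config=BasicCacheConfig(Fn_compute_blocks=4),  # hypothetical override
    )
    # cache_dit.enable_cache(adapter, params_modifiers=[modifier], ...)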
@@ -0,0 +1,15 @@
+ from cache_dit.caching.patch_functors.functor_base import PatchFunctor
+ from cache_dit.caching.patch_functors.functor_dit import DiTPatchFunctor
+ from cache_dit.caching.patch_functors.functor_flux import FluxPatchFunctor
+ from cache_dit.caching.patch_functors.functor_chroma import (
+ ChromaPatchFunctor,
+ )
+ from cache_dit.caching.patch_functors.functor_hidream import (
+ HiDreamPatchFunctor,
+ )
+ from cache_dit.caching.patch_functors.functor_hunyuan_dit import (
+ HunyuanDiTPatchFunctor,
+ )
+ from cache_dit.caching.patch_functors.functor_qwen_image_controlnet import (
+ QwenImageControlNetPatchFunctor,
+ )
@@ -13,7 +13,7 @@ from diffusers.utils import (
  unscale_lora_layers,
  )

- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
  PatchFunctor,
  )
  from cache_dit.logger import init_logger
@@ -6,7 +6,7 @@ from diffusers.models.transformers.dit_transformer_2d import (
  DiTTransformer2DModel,
  Transformer2DModelOutput,
  )
- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
  PatchFunctor,
  )
  from cache_dit.logger import init_logger
@@ -14,7 +14,7 @@ from diffusers.utils import (
  unscale_lora_layers,
  )

- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
  PatchFunctor,
  )
  from cache_dit.logger import init_logger
@@ -13,7 +13,7 @@ from diffusers.utils import (
  scale_lora_layers,
  unscale_lora_layers,
  )
- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
  PatchFunctor,
  )
  from cache_dit.logger import init_logger
@@ -5,7 +5,7 @@ from diffusers.models.transformers.hunyuan_transformer_2d import (
  HunyuanDiTBlock,
  Transformer2DModelOutput,
  )
- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
  PatchFunctor,
  )
  from cache_dit.logger import init_logger
@@ -11,7 +11,7 @@ from diffusers.utils import (
  scale_lora_layers,
  unscale_lora_layers,
  )
- from cache_dit.cache_factory.patch_functors.functor_base import (
+ from cache_dit.caching.patch_functors.functor_base import (
  PatchFunctor,
  )
  from cache_dit.logger import init_logger
@@ -7,10 +7,6 @@ def load_cache_options_from_yaml(yaml_file_path):
  kwargs: dict = yaml.safe_load(f)

  required_keys = [
- "max_warmup_steps",
- "max_cached_steps",
- "Fn_compute_blocks",
- "Bn_compute_blocks",
  "residual_diff_threshold",
  ]
  for key in required_keys:
@@ -21,7 +17,7 @@ def load_cache_options_from_yaml(yaml_file_path):

  cache_context_kwargs = {}
  if kwargs.get("enable_taylorseer", False):
- from cache_dit.cache_factory.cache_contexts.calibrators import (
+ from cache_dit.caching.cache_contexts.calibrators import (
  TaylorSeerCalibratorConfig,
  )

@@ -38,10 +34,25 @@
  )
  )

- from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+ if "cache_type" not in kwargs:
+ from cache_dit.caching.cache_contexts import BasicCacheConfig

- cache_context_kwargs["cache_config"] = BasicCacheConfig()
- cache_context_kwargs["cache_config"].update(**kwargs)
+ cache_context_kwargs["cache_config"] = BasicCacheConfig()
+ cache_context_kwargs["cache_config"].update(**kwargs)
+ else:
+ cache_type = kwargs.pop("cache_type")
+ if cache_type == "DBCache":
+ from cache_dit.caching.cache_contexts import DBCacheConfig
+
+ cache_context_kwargs["cache_config"] = DBCacheConfig()
+ cache_context_kwargs["cache_config"].update(**kwargs)
+ elif cache_type == "DBPrune":
+ from cache_dit.caching.cache_contexts import DBPruneConfig
+
+ cache_context_kwargs["cache_config"] = DBPruneConfig()
+ cache_context_kwargs["cache_config"].update(**kwargs)
+ else:
+ raise ValueError(f"Unsupported cache_type: {cache_type}.")

  return cache_context_kwargs
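The loader now dispatches on an optional cache_type key; a hedged sketch of driving it (the loader's import path is assumed from this diff's file list, and the values are illustrative):

    # Contents of a hypothetical cache_options.yaml:
    #   cache_type: DBPrune            # "DBCache" selects DBCacheConfig; omit for BasicCacheConfig
    #   residual_diff_threshold: 0.08  # the only required key after this change
    #   enable_taylorseer: true        # optional, attaches a TaylorSeerCalibratorConfig
    from cache_dit.caching.utils import load_cache_options_from_yaml  # assumed module path

    cache_context_kwargs = load_cache_options_from_yaml("cache_options.yaml")
    # -> {"cache_config": DBPruneConfig(...), "calibrator_config": ...}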
@@ -1,3 +1,14 @@
+ try:
+ import ImageReward
+ import lpips
+ import skimage
+ import scipy
+ except ImportError:
+ raise ImportError(
+ "Metrics functionality requires the 'metrics' extra dependencies. "
+ "Install with:\npip install cache-dit[metrics]"
+ )
+
  from cache_dit.metrics.metrics import compute_psnr
  from cache_dit.metrics.metrics import compute_ssim
  from cache_dit.metrics.metrics import compute_mse
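With the guard above, importing the metrics package without the optional dependencies now fails fast with an actionable message; a small sketch (the function names come from the imports shown here, their signatures are not part of this diff):

    try:
        from cache_dit.metrics import compute_psnr, compute_ssim, compute_mse
    except ImportError as err:
        # Raised when the 'metrics' extra is missing:
        #   pip install cache-dit[metrics]
        print(err)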
@@ -0,0 +1,3 @@
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from cache_dit.parallelism.parallel_interface import enable_parallelism
@@ -0,0 +1,6 @@
+ from cache_dit.parallelism.backends.native_diffusers.context_parallelism import (
+ ContextParallelismPlannerRegister,
+ )
+ from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
+ maybe_enable_parallelism,
+ )