cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -4,32 +4,50 @@ except ImportError:
     __version__ = "unknown version"
     version_tuple = (0, 0, "unknown version")
 
-from cache_dit.utils import summary
-from cache_dit.utils import strify
+
 from cache_dit.utils import disable_print
 from cache_dit.logger import init_logger
-from cache_dit.cache_factory import load_options
-from cache_dit.cache_factory import enable_cache
-from cache_dit.cache_factory import disable_cache
-from cache_dit.cache_factory import cache_type
-from cache_dit.cache_factory import block_range
-from cache_dit.cache_factory import CacheType
-from cache_dit.cache_factory import BlockAdapter
-from cache_dit.cache_factory import ParamsModifier
-from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory import PatchFunctor
-from cache_dit.cache_factory import BasicCacheConfig
-from cache_dit.cache_factory import CalibratorConfig
-from cache_dit.cache_factory import TaylorSeerCalibratorConfig
-from cache_dit.cache_factory import FoCaCalibratorConfig
-from cache_dit.cache_factory import supported_pipelines
-from cache_dit.cache_factory import get_adapter
+from cache_dit.caching import load_options
+from cache_dit.caching import enable_cache
+from cache_dit.caching import disable_cache
+from cache_dit.caching import cache_type
+from cache_dit.caching import block_range
+from cache_dit.caching import CacheType
+from cache_dit.caching import BlockAdapter
+from cache_dit.caching import ParamsModifier
+from cache_dit.caching import ForwardPattern
+from cache_dit.caching import PatchFunctor
+from cache_dit.caching import BasicCacheConfig
+from cache_dit.caching import DBCacheConfig
+from cache_dit.caching import DBPruneConfig
+from cache_dit.caching import CalibratorConfig
+from cache_dit.caching import TaylorSeerCalibratorConfig
+from cache_dit.caching import FoCaCalibratorConfig
+from cache_dit.caching import supported_pipelines
+from cache_dit.caching import get_adapter
 from cache_dit.compile import set_compile_configs
-from cache_dit.quantize import quantize
+from cache_dit.parallelism import ParallelismBackend
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.summary import supported_matrix
+from cache_dit.summary import summary
+from cache_dit.summary import strify
+
+try:
+    from cache_dit.quantize import quantize
+except ImportError as e:  # noqa: F841
+    err_msg = str(e)
+
+    def quantize(*args, **kwargs):
+        raise ImportError(
+            "Quantization requires additional dependencies. "
+            "Please install cache-dit[quantization] or cache-dit[all] "
+            f"to use this feature. Error message: {err_msg}"
+        )
 
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
+DBPrune = CacheType.DBPrune
 
 Pattern_0 = ForwardPattern.Pattern_0
 Pattern_1 = ForwardPattern.Pattern_1
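
The guarded quantize import above means `import cache_dit` now succeeds even when the optional quantization dependencies are missing, and summary/strify move from cache_dit.utils to the new cache_dit.summary module. A minimal usage sketch of the 1.0.14 top-level surface, assuming `pipe` is an already-loaded diffusers pipeline:

    import cache_dit

    cache_dit.enable_cache(pipe)      # caching API, re-exported from cache_dit.caching
    stats = cache_dit.summary(pipe)   # moved from cache_dit.utils to cache_dit.summary
    print(cache_dit.strify(stats))
    cache_dit.disable_cache(pipe)
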
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.0.3'
-__version_tuple__ = version_tuple = (1, 0, 3)
+__version__ = version = '1.0.14'
+__version_tuple__ = version_tuple = (1, 0, 14)
 
 __commit_id__ = commit_id = None
cache_dit/caching/__init__.py ADDED
@@ -0,0 +1,36 @@
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.cache_types import cache_type
+from cache_dit.caching.cache_types import block_range
+
+from cache_dit.caching.forward_pattern import ForwardPattern
+from cache_dit.caching.params_modifier import ParamsModifier
+from cache_dit.caching.patch_functors import PatchFunctor
+
+from cache_dit.caching.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters import BlockAdapterRegistry
+from cache_dit.caching.block_adapters import FakeDiffusionPipeline
+
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import DBCacheConfig
+from cache_dit.caching.cache_contexts import CachedContext
+from cache_dit.caching.cache_contexts import CachedContextManager
+from cache_dit.caching.cache_contexts import DBPruneConfig
+from cache_dit.caching.cache_contexts import PrunedContext
+from cache_dit.caching.cache_contexts import PrunedContextManager
+from cache_dit.caching.cache_contexts import ContextManager
+from cache_dit.caching.cache_contexts import CalibratorConfig
+from cache_dit.caching.cache_contexts import TaylorSeerCalibratorConfig
+from cache_dit.caching.cache_contexts import FoCaCalibratorConfig
+
+from cache_dit.caching.cache_blocks import CachedBlocks
+from cache_dit.caching.cache_blocks import PrunedBlocks
+from cache_dit.caching.cache_blocks import UnifiedBlocks
+
+from cache_dit.caching.cache_adapters import CachedAdapter
+
+from cache_dit.caching.cache_interface import enable_cache
+from cache_dit.caching.cache_interface import disable_cache
+from cache_dit.caching.cache_interface import supported_pipelines
+from cache_dit.caching.cache_interface import get_adapter
+
+from cache_dit.caching.utils import load_options
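
For downstream code, the cache_factory → caching rename is mechanical; a before/after sketch using only import paths that appear in this diff:

    # cache-dit <= 1.0.3
    # from cache_dit.cache_factory import BlockAdapter, ForwardPattern

    # cache-dit 1.0.14, per the new module above
    from cache_dit.caching import BlockAdapter, ForwardPattern
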
cache_dit/{cache_factory → caching}/block_adapters/__init__.py RENAMED
@@ -1,7 +1,10 @@
-from cache_dit.cache_factory.forward_pattern import ForwardPattern
-from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
-from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
-from cache_dit.cache_factory.block_adapters.block_registers import (
+from cache_dit.caching.forward_pattern import ForwardPattern
+from cache_dit.caching.block_adapters.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters.block_adapters import (
+    FakeDiffusionPipeline,
+)
+from cache_dit.caching.block_adapters.block_adapters import ParamsModifier
+from cache_dit.caching.block_adapters.block_registers import (
     BlockAdapterRegistry,
 )
 
@@ -12,7 +15,10 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
     from cache_dit.utils import is_diffusers_at_least_0_3_5
 
     assert isinstance(pipe.transformer, FluxTransformer2DModel)
-    if is_diffusers_at_least_0_3_5():
+    transformer_cls_name: str = pipe.transformer.__class__.__name__
+    if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith(
+        "Nunchaku"
+    ):
         return BlockAdapter(
             pipe=pipe,
             transformer=pipe.transformer,
@@ -24,6 +30,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_1,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
     else:
@@ -38,6 +45,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_3,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
 
@@ -52,6 +60,7 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -66,6 +75,7 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -101,6 +111,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_2,
                 ForwardPattern.Pattern_2,
             ],
+            check_forward_pattern=True,
             has_separate_cfg=True,
             **kwargs,
         )
@@ -111,6 +122,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
             transformer=pipe.transformer,
             blocks=pipe.transformer.blocks,
             forward_pattern=ForwardPattern.Pattern_2,
+            check_forward_pattern=True,
             has_separate_cfg=True,
             **kwargs,
         )
@@ -132,6 +144,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_0,
         ],
+        check_forward_pattern=True,
         # The type hint in diffusers is wrong
         check_num_outputs=False,
         **kwargs,
@@ -146,7 +159,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
 
     pipe_cls_name: str = pipe.__class__.__name__
     if pipe_cls_name.startswith("QwenImageControlNet"):
-        from cache_dit.cache_factory.patch_functors import (
+        from cache_dit.caching.patch_functors import (
             QwenImageControlNetPatchFunctor,
         )
 
@@ -156,6 +169,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
             blocks=pipe.transformer.transformer_blocks,
             forward_pattern=ForwardPattern.Pattern_1,
             patch_functor=QwenImageControlNetPatchFunctor(),
+            check_forward_pattern=True,
             has_separate_cfg=True,
         )
     else:
@@ -164,6 +178,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
             transformer=pipe.transformer,
             blocks=pipe.transformer.transformer_blocks,
             forward_pattern=ForwardPattern.Pattern_1,
+            check_forward_pattern=True,
             has_separate_cfg=True,
             **kwargs,
         )
@@ -179,6 +194,7 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -193,6 +209,7 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -207,6 +224,7 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -221,6 +239,7 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -236,6 +255,7 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -251,6 +271,7 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -268,6 +289,7 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
         # encoder_hidden_states will never change in the blocks
         # forward loop.
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -283,6 +305,7 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_1,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -297,6 +320,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -304,7 +328,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("DiT")
 def dit_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import DiTTransformer2DModel
-    from cache_dit.cache_factory.patch_functors import DiTPatchFunctor
+    from cache_dit.caching.patch_functors import DiTPatchFunctor
 
     assert isinstance(pipe.transformer, DiTTransformer2DModel)
     return BlockAdapter(
@@ -313,6 +337,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=DiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -327,6 +352,7 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -347,6 +373,7 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_0,
         ],
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -364,6 +391,7 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -378,6 +406,7 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -392,6 +421,7 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -406,6 +436,7 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -420,6 +451,7 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -442,6 +474,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_1,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
     else:
@@ -456,6 +489,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_3,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
 
@@ -470,6 +504,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.single_transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -477,7 +512,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("Chroma")
 def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import ChromaTransformer2DModel
-    from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor
+    from cache_dit.caching.patch_functors import ChromaPatchFunctor
 
     assert isinstance(pipe.transformer, ChromaTransformer2DModel)
     return BlockAdapter(
@@ -492,6 +527,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_3,
             ForwardPattern.Pattern_3,
         patch_functor=ChromaPatchFunctor(),
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -507,6 +543,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.prior,
         blocks=pipe.prior.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -519,7 +556,7 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
     from diffusers import HiDreamImageTransformer2DModel
-    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
+    from cache_dit.caching.patch_functors import HiDreamPatchFunctor
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
     return BlockAdapter(
@@ -544,7 +581,7 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("HunyuanDiT")
 def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
-    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
+    from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(
         pipe.transformer,
@@ -556,6 +593,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=HunyuanDiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -563,7 +601,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("HunyuanDiTPAG")
 def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import HunyuanDiT2DModel
-    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
+    from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(pipe.transformer, HunyuanDiT2DModel)
     return BlockAdapter(
@@ -572,5 +610,82 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=HunyuanDiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
+
+
+@BlockAdapterRegistry.register("Kandinsky5")
+def kandinsky5_adapter(pipe, **kwargs) -> BlockAdapter:
+    try:
+        from diffusers import Kandinsky5Transformer3DModel
+
+        assert isinstance(pipe.transformer, Kandinsky5Transformer3DModel)
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.visual_transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_3,  # or Pattern_2
+            has_separate_cfg=True,
+            check_forward_pattern=False,
+            check_num_outputs=False,
+            **kwargs,
+        )
+    except ImportError:
+        raise ImportError(
+            "Kandinsky5Transformer3DModel is not available in the current diffusers version. "
+            "Please upgrade diffusers>=0.36.dev0 to use this adapter."
+        )
+
+
+@BlockAdapterRegistry.register("PRX")
+def prx_adapter(pipe, **kwargs) -> BlockAdapter:
+    try:
+        from diffusers import PRXTransformer2DModel
+
+        assert isinstance(pipe.transformer, PRXTransformer2DModel)
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.blocks,
+            forward_pattern=ForwardPattern.Pattern_3,
+            check_forward_pattern=True,
+            check_num_outputs=False,
+            **kwargs,
+        )
+    except ImportError:
+        raise ImportError(
+            "PRXTransformer2DModel is not available in the current diffusers version. "
+            "Please upgrade diffusers>=0.36.dev0 to use this adapter."
+        )
+
+
+@BlockAdapterRegistry.register("HunyuanImage")
+def hunyuan_image_adapter(pipe, **kwargs) -> BlockAdapter:
+    try:
+        from diffusers import HunyuanImageTransformer2DModel
+
+        assert isinstance(pipe.transformer, HunyuanImageTransformer2DModel)
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=[
+                pipe.transformer.transformer_blocks,
+                pipe.transformer.single_transformer_blocks,
+            ],
+            forward_pattern=[
+                ForwardPattern.Pattern_0,
+                ForwardPattern.Pattern_0,
+            ],
+            # set `has_separate_cfg` as True to enable separate cfg caching
+            # since in hyimage-2.1 the `guider_state` contains 2 input batches.
+            # The cfg is `enabled` by default in AdaptiveProjectedMixGuidance.
+            has_separate_cfg=True,
+            check_forward_pattern=True,
+            **kwargs,
+        )
+    except ImportError:
+        raise ImportError(
+            "HunyuanImageTransformer2DModel is not available in the current diffusers version. "
+            "Please upgrade diffusers>=0.36.dev0 to use this adapter."
+        )
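
The new adapters above all follow the same registration shape, so third-party models can hook in the same way. A sketch, assuming a hypothetical pipeline whose transformer exposes `transformer_blocks` (the adapter name "MyDiT" and the pattern choice are illustrative, not from this release):

    from cache_dit.caching import BlockAdapter, BlockAdapterRegistry, ForwardPattern

    @BlockAdapterRegistry.register("MyDiT")
    def my_dit_adapter(pipe, **kwargs) -> BlockAdapter:
        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=pipe.transformer.transformer_blocks,
            forward_pattern=ForwardPattern.Pattern_3,
            # Non-diffusers transformers now default to check_forward_pattern=False;
            # set it explicitly if the blocks follow a known pattern.
            check_forward_pattern=False,
            **kwargs,
        )
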
cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py RENAMED
@@ -6,22 +6,32 @@ from collections.abc import Iterable
 
 from typing import Any, Tuple, List, Optional, Union
 
-from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.patch_functors import PatchFunctor
-from cache_dit.cache_factory.forward_pattern import ForwardPattern
-from cache_dit.cache_factory.params_modifier import ParamsModifier
+from diffusers import DiffusionPipeline, ModelMixin
+from cache_dit.caching.patch_functors import PatchFunctor
+from cache_dit.caching.forward_pattern import ForwardPattern
+from cache_dit.caching.params_modifier import ParamsModifier
 
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
+class FakeDiffusionPipeline:
+    # A placeholder for pipelines when pipe is None.
+    def __init__(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+    ):
+        self.transformer = transformer  # Reference only
+
+
 @dataclasses.dataclass
 class BlockAdapter:
 
     # Transformer configurations.
     pipe: Union[
         DiffusionPipeline,
+        FakeDiffusionPipeline,
         Any,
     ] = None
 
@@ -73,7 +83,7 @@ class BlockAdapter:
         ]
     ] = None
 
-    check_forward_pattern: bool = True
+    check_forward_pattern: Optional[bool] = None
     check_num_outputs: bool = False
 
     # Pipeline Level Flags
@@ -110,12 +120,43 @@ class BlockAdapter:
     def __post_init__(self):
        if self.skip_post_init:
            return
+
+        self.maybe_fake_pipe()
         if any((self.pipe is not None, self.transformer is not None)):
             self.maybe_fill_attrs()
             self.maybe_patchify()
             self.maybe_skip_checks()
 
+    def maybe_fake_pipe(self):
+        if self.pipe is None:
+            self.pipe = FakeDiffusionPipeline()
+            logger.warning("pipe is None, use FakeDiffusionPipeline instead.")
+
     def maybe_skip_checks(self):
+        if self.check_forward_pattern is None:
+            if self.transformer is not None:
+                if self.nested_depth(self.transformer) == 0:
+                    transformer = self.transformer
+                elif self.nested_depth(self.transformer) == 1:
+                    transformer = self.transformer[0]
+                else:
+                    raise ValueError(
+                        "transformer nested depth can't more than 1, "
+                        f"current is: {self.nested_depth(self.transformer)}"
+                    )
+                if transformer.__module__.startswith("diffusers"):
+                    self.check_forward_pattern = True
+                    logger.info(
+                        f"Found transformer from diffusers: {transformer.__module__} "
+                        "enable check_forward_pattern by default."
+                    )
+                else:
+                    self.check_forward_pattern = False
+                    logger.info(
+                        f"Found transformer NOT from diffusers: {transformer.__module__} "
+                        "disable check_forward_pattern by default."
+                    )
+
         if getattr(self.transformer, "_hf_hook", None) is not None:
             logger.warning("_hf_hook is not None, force skip pattern check!")
             self.check_forward_pattern = False
@@ -208,7 +249,10 @@ class BlockAdapter:
         if self.transformer is not None:
             self.patch_functor.apply(self.transformer, *args, **kwargs)
         else:
-            assert hasattr(self.pipe, "transformer")
+            assert hasattr(self.pipe, "transformer"), (
+                "pipe.transformer can not be None when patch_functor "
+                "is provided and transformer is None."
+            )
             self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
 
     @staticmethod
@@ -224,6 +268,10 @@ class BlockAdapter:
             adapter.forward_pattern is not None
         ), "adapter.forward_pattern can not be None."
         pipe = adapter.pipe
+        if isinstance(pipe, FakeDiffusionPipeline):
+            raise ValueError(
+                "Can not auto block adapter for FakeDiffusionPipeline."
+            )
 
         assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
 
@@ -489,6 +537,7 @@ class BlockAdapter:
     @staticmethod
     def normalize(
         adapter: "BlockAdapter",
+        unique: bool = True,
     ) -> "BlockAdapter":
 
         if getattr(adapter, "_is_normalized", False):
@@ -523,7 +572,10 @@ class BlockAdapter:
         adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
         adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
         adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
-        BlockAdapter.unique(adapter)
+        # Some times, the cache_config will be None.
+        # So we do not perform unique check here.
+        if unique:
+            BlockAdapter.unique(adapter)
 
         adapter._is_normalized = True
 
@@ -571,6 +623,10 @@ class BlockAdapter:
         if not getattr(adapter, "_is_normalized", False):
             raise RuntimeError("block_adapter must be normailzed.")
 
+    @classmethod
+    def is_normalized(cls, adapter: "BlockAdapter") -> bool:
+        return getattr(adapter, "_is_normalized", False)
+
     @classmethod
     def is_cached(cls, adapter: Any) -> bool:
         if isinstance(adapter, cls):
@@ -592,6 +648,21 @@ class BlockAdapter:
         else:
             return getattr(adapter, "_is_cached", False)
 
+    @classmethod
+    def is_parallelized(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return getattr(adapter.transformer[0], "_is_parallelized", False)
+        elif isinstance(adapter, DiffusionPipeline):
+            return getattr(adapter.transformer, "_is_parallelized", False)
+        elif isinstance(adapter, torch.nn.Module):
+            return getattr(adapter, "_is_parallelized", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_parallelized", False)
+        else:
+            return getattr(adapter, "_is_parallelized", False)
+
     @classmethod
     def nested_depth(cls, obj: Any):
         # str: 0; List[str]: 1; List[List[str]]: 2
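
Taken together, the block_adapters.py changes allow transformer-only usage: passing pipe=None now substitutes a FakeDiffusionPipeline placeholder instead of failing. A hedged sketch, assuming `my_transformer` is an already-loaded diffusers transformer module (not verified against every `__post_init__` check):

    from cache_dit.caching import BlockAdapter, ForwardPattern

    adapter = BlockAdapter(
        pipe=None,  # maybe_fake_pipe() swaps in FakeDiffusionPipeline()
        transformer=my_transformer,
        blocks=my_transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_0,
        # check_forward_pattern=None now auto-resolves: True for transformers
        # defined in diffusers, False otherwise (see maybe_skip_checks above).
    )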