cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108) hide show
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -4,32 +4,50 @@ except ImportError:
4
4
  __version__ = "unknown version"
5
5
  version_tuple = (0, 0, "unknown version")
6
6
 
7
- from cache_dit.utils import summary
8
- from cache_dit.utils import strify
7
+
9
8
  from cache_dit.utils import disable_print
10
9
  from cache_dit.logger import init_logger
11
- from cache_dit.cache_factory import load_options
12
- from cache_dit.cache_factory import enable_cache
13
- from cache_dit.cache_factory import disable_cache
14
- from cache_dit.cache_factory import cache_type
15
- from cache_dit.cache_factory import block_range
16
- from cache_dit.cache_factory import CacheType
17
- from cache_dit.cache_factory import BlockAdapter
18
- from cache_dit.cache_factory import ParamsModifier
19
- from cache_dit.cache_factory import ForwardPattern
20
- from cache_dit.cache_factory import PatchFunctor
21
- from cache_dit.cache_factory import BasicCacheConfig
22
- from cache_dit.cache_factory import CalibratorConfig
23
- from cache_dit.cache_factory import TaylorSeerCalibratorConfig
24
- from cache_dit.cache_factory import FoCaCalibratorConfig
25
- from cache_dit.cache_factory import supported_pipelines
26
- from cache_dit.cache_factory import get_adapter
10
+ from cache_dit.caching import load_options
11
+ from cache_dit.caching import enable_cache
12
+ from cache_dit.caching import disable_cache
13
+ from cache_dit.caching import cache_type
14
+ from cache_dit.caching import block_range
15
+ from cache_dit.caching import CacheType
16
+ from cache_dit.caching import BlockAdapter
17
+ from cache_dit.caching import ParamsModifier
18
+ from cache_dit.caching import ForwardPattern
19
+ from cache_dit.caching import PatchFunctor
20
+ from cache_dit.caching import BasicCacheConfig
21
+ from cache_dit.caching import DBCacheConfig
22
+ from cache_dit.caching import DBPruneConfig
23
+ from cache_dit.caching import CalibratorConfig
24
+ from cache_dit.caching import TaylorSeerCalibratorConfig
25
+ from cache_dit.caching import FoCaCalibratorConfig
26
+ from cache_dit.caching import supported_pipelines
27
+ from cache_dit.caching import get_adapter
27
28
  from cache_dit.compile import set_compile_configs
28
- from cache_dit.quantize import quantize
29
+ from cache_dit.parallelism import ParallelismBackend
30
+ from cache_dit.parallelism import ParallelismConfig
31
+ from cache_dit.summary import supported_matrix
32
+ from cache_dit.summary import summary
33
+ from cache_dit.summary import strify
34
+
35
+ try:
36
+ from cache_dit.quantize import quantize
37
+ except ImportError as e: # noqa: F841
38
+ err_msg = str(e)
39
+
40
+ def quantize(*args, **kwargs):
41
+ raise ImportError(
42
+ "Quantization requires additional dependencies. "
43
+ "Please install cache-dit[quantization] or cache-dit[all] "
44
+ f"to use this feature. Error message: {err_msg}"
45
+ )
29
46
 
30
47
 
31
48
  NONE = CacheType.NONE
32
49
  DBCache = CacheType.DBCache
50
+ DBPrune = CacheType.DBPrune
33
51
 
34
52
  Pattern_0 = ForwardPattern.Pattern_0
35
53
  Pattern_1 = ForwardPattern.Pattern_1
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.3.2'
32
- __version_tuple__ = version_tuple = (0, 3, 2)
31
+ __version__ = version = '1.0.14'
32
+ __version_tuple__ = version_tuple = (1, 0, 14)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -0,0 +1,36 @@
1
+ from cache_dit.caching.cache_types import CacheType
2
+ from cache_dit.caching.cache_types import cache_type
3
+ from cache_dit.caching.cache_types import block_range
4
+
5
+ from cache_dit.caching.forward_pattern import ForwardPattern
6
+ from cache_dit.caching.params_modifier import ParamsModifier
7
+ from cache_dit.caching.patch_functors import PatchFunctor
8
+
9
+ from cache_dit.caching.block_adapters import BlockAdapter
10
+ from cache_dit.caching.block_adapters import BlockAdapterRegistry
11
+ from cache_dit.caching.block_adapters import FakeDiffusionPipeline
12
+
13
+ from cache_dit.caching.cache_contexts import BasicCacheConfig
14
+ from cache_dit.caching.cache_contexts import DBCacheConfig
15
+ from cache_dit.caching.cache_contexts import CachedContext
16
+ from cache_dit.caching.cache_contexts import CachedContextManager
17
+ from cache_dit.caching.cache_contexts import DBPruneConfig
18
+ from cache_dit.caching.cache_contexts import PrunedContext
19
+ from cache_dit.caching.cache_contexts import PrunedContextManager
20
+ from cache_dit.caching.cache_contexts import ContextManager
21
+ from cache_dit.caching.cache_contexts import CalibratorConfig
22
+ from cache_dit.caching.cache_contexts import TaylorSeerCalibratorConfig
23
+ from cache_dit.caching.cache_contexts import FoCaCalibratorConfig
24
+
25
+ from cache_dit.caching.cache_blocks import CachedBlocks
26
+ from cache_dit.caching.cache_blocks import PrunedBlocks
27
+ from cache_dit.caching.cache_blocks import UnifiedBlocks
28
+
29
+ from cache_dit.caching.cache_adapters import CachedAdapter
30
+
31
+ from cache_dit.caching.cache_interface import enable_cache
32
+ from cache_dit.caching.cache_interface import disable_cache
33
+ from cache_dit.caching.cache_interface import supported_pipelines
34
+ from cache_dit.caching.cache_interface import get_adapter
35
+
36
+ from cache_dit.caching.utils import load_options
@@ -1,7 +1,10 @@
1
- from cache_dit.cache_factory.forward_pattern import ForwardPattern
2
- from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
3
- from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
4
- from cache_dit.cache_factory.block_adapters.block_registers import (
1
+ from cache_dit.caching.forward_pattern import ForwardPattern
2
+ from cache_dit.caching.block_adapters.block_adapters import BlockAdapter
3
+ from cache_dit.caching.block_adapters.block_adapters import (
4
+ FakeDiffusionPipeline,
5
+ )
6
+ from cache_dit.caching.block_adapters.block_adapters import ParamsModifier
7
+ from cache_dit.caching.block_adapters.block_registers import (
5
8
  BlockAdapterRegistry,
6
9
  )
7
10
 
@@ -12,7 +15,10 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
12
15
  from cache_dit.utils import is_diffusers_at_least_0_3_5
13
16
 
14
17
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
15
- if is_diffusers_at_least_0_3_5():
18
+ transformer_cls_name: str = pipe.transformer.__class__.__name__
19
+ if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith(
20
+ "Nunchaku"
21
+ ):
16
22
  return BlockAdapter(
17
23
  pipe=pipe,
18
24
  transformer=pipe.transformer,
@@ -24,6 +30,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
24
30
  ForwardPattern.Pattern_1,
25
31
  ForwardPattern.Pattern_1,
26
32
  ],
33
+ check_forward_pattern=True,
27
34
  **kwargs,
28
35
  )
29
36
  else:
@@ -38,6 +45,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
38
45
  ForwardPattern.Pattern_1,
39
46
  ForwardPattern.Pattern_3,
40
47
  ],
48
+ check_forward_pattern=True,
41
49
  **kwargs,
42
50
  )
43
51
 
@@ -52,6 +60,7 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
52
60
  transformer=pipe.transformer,
53
61
  blocks=pipe.transformer.transformer_blocks,
54
62
  forward_pattern=ForwardPattern.Pattern_0,
63
+ check_forward_pattern=True,
55
64
  **kwargs,
56
65
  )
57
66
 
@@ -66,6 +75,7 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
66
75
  transformer=pipe.transformer,
67
76
  blocks=pipe.transformer.transformer_blocks,
68
77
  forward_pattern=ForwardPattern.Pattern_0,
78
+ check_forward_pattern=True,
69
79
  **kwargs,
70
80
  )
71
81
 
@@ -101,6 +111,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
101
111
  ForwardPattern.Pattern_2,
102
112
  ForwardPattern.Pattern_2,
103
113
  ],
114
+ check_forward_pattern=True,
104
115
  has_separate_cfg=True,
105
116
  **kwargs,
106
117
  )
@@ -111,6 +122,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
111
122
  transformer=pipe.transformer,
112
123
  blocks=pipe.transformer.blocks,
113
124
  forward_pattern=ForwardPattern.Pattern_2,
125
+ check_forward_pattern=True,
114
126
  has_separate_cfg=True,
115
127
  **kwargs,
116
128
  )
@@ -132,6 +144,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
132
144
  ForwardPattern.Pattern_0,
133
145
  ForwardPattern.Pattern_0,
134
146
  ],
147
+ check_forward_pattern=True,
135
148
  # The type hint in diffusers is wrong
136
149
  check_num_outputs=False,
137
150
  **kwargs,
@@ -143,14 +156,32 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
143
156
  from diffusers import QwenImageTransformer2DModel
144
157
 
145
158
  assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
146
- return BlockAdapter(
147
- pipe=pipe,
148
- transformer=pipe.transformer,
149
- blocks=pipe.transformer.transformer_blocks,
150
- forward_pattern=ForwardPattern.Pattern_1,
151
- has_separate_cfg=True,
152
- **kwargs,
153
- )
159
+
160
+ pipe_cls_name: str = pipe.__class__.__name__
161
+ if pipe_cls_name.startswith("QwenImageControlNet"):
162
+ from cache_dit.caching.patch_functors import (
163
+ QwenImageControlNetPatchFunctor,
164
+ )
165
+
166
+ return BlockAdapter(
167
+ pipe=pipe,
168
+ transformer=pipe.transformer,
169
+ blocks=pipe.transformer.transformer_blocks,
170
+ forward_pattern=ForwardPattern.Pattern_1,
171
+ patch_functor=QwenImageControlNetPatchFunctor(),
172
+ check_forward_pattern=True,
173
+ has_separate_cfg=True,
174
+ )
175
+ else:
176
+ return BlockAdapter(
177
+ pipe=pipe,
178
+ transformer=pipe.transformer,
179
+ blocks=pipe.transformer.transformer_blocks,
180
+ forward_pattern=ForwardPattern.Pattern_1,
181
+ check_forward_pattern=True,
182
+ has_separate_cfg=True,
183
+ **kwargs,
184
+ )
154
185
 
155
186
 
156
187
  @BlockAdapterRegistry.register("LTX")
@@ -163,6 +194,7 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
163
194
  transformer=pipe.transformer,
164
195
  blocks=pipe.transformer.transformer_blocks,
165
196
  forward_pattern=ForwardPattern.Pattern_2,
197
+ check_forward_pattern=True,
166
198
  **kwargs,
167
199
  )
168
200
 
@@ -177,6 +209,7 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
177
209
  transformer=pipe.transformer,
178
210
  blocks=pipe.transformer.transformer_blocks,
179
211
  forward_pattern=ForwardPattern.Pattern_2,
212
+ check_forward_pattern=True,
180
213
  **kwargs,
181
214
  )
182
215
 
@@ -191,6 +224,7 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
191
224
  transformer=pipe.transformer,
192
225
  blocks=pipe.transformer.transformer_blocks,
193
226
  forward_pattern=ForwardPattern.Pattern_0,
227
+ check_forward_pattern=True,
194
228
  **kwargs,
195
229
  )
196
230
 
@@ -205,6 +239,7 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
205
239
  transformer=pipe.transformer,
206
240
  blocks=pipe.transformer.transformer_blocks,
207
241
  forward_pattern=ForwardPattern.Pattern_0,
242
+ check_forward_pattern=True,
208
243
  has_separate_cfg=True,
209
244
  **kwargs,
210
245
  )
@@ -220,6 +255,7 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
220
255
  transformer=pipe.transformer,
221
256
  blocks=pipe.transformer.transformer_blocks,
222
257
  forward_pattern=ForwardPattern.Pattern_2,
258
+ check_forward_pattern=True,
223
259
  has_separate_cfg=True,
224
260
  **kwargs,
225
261
  )
@@ -235,6 +271,7 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
235
271
  transformer=pipe.transformer,
236
272
  blocks=pipe.transformer.transformer_blocks,
237
273
  forward_pattern=ForwardPattern.Pattern_0,
274
+ check_forward_pattern=True,
238
275
  **kwargs,
239
276
  )
240
277
 
@@ -252,6 +289,7 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
252
289
  # encoder_hidden_states will never change in the blocks
253
290
  # forward loop.
254
291
  forward_pattern=ForwardPattern.Pattern_3,
292
+ check_forward_pattern=True,
255
293
  has_separate_cfg=True,
256
294
  **kwargs,
257
295
  )
@@ -267,6 +305,7 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
267
305
  transformer=pipe.transformer,
268
306
  blocks=pipe.transformer.transformer_blocks,
269
307
  forward_pattern=ForwardPattern.Pattern_1,
308
+ check_forward_pattern=True,
270
309
  **kwargs,
271
310
  )
272
311
 
@@ -281,6 +320,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
281
320
  transformer=pipe.transformer,
282
321
  blocks=pipe.transformer.transformer_blocks,
283
322
  forward_pattern=ForwardPattern.Pattern_0,
323
+ check_forward_pattern=True,
284
324
  **kwargs,
285
325
  )
286
326
 
@@ -288,7 +328,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
288
328
  @BlockAdapterRegistry.register("DiT")
289
329
  def dit_adapter(pipe, **kwargs) -> BlockAdapter:
290
330
  from diffusers import DiTTransformer2DModel
291
- from cache_dit.cache_factory.patch_functors import DiTPatchFunctor
331
+ from cache_dit.caching.patch_functors import DiTPatchFunctor
292
332
 
293
333
  assert isinstance(pipe.transformer, DiTTransformer2DModel)
294
334
  return BlockAdapter(
@@ -297,6 +337,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
297
337
  blocks=pipe.transformer.transformer_blocks,
298
338
  forward_pattern=ForwardPattern.Pattern_3,
299
339
  patch_functor=DiTPatchFunctor(),
340
+ check_forward_pattern=True,
300
341
  **kwargs,
301
342
  )
302
343
 
@@ -311,6 +352,7 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
311
352
  transformer=pipe.transformer,
312
353
  blocks=pipe.transformer.transformer_layers,
313
354
  forward_pattern=ForwardPattern.Pattern_3,
355
+ check_forward_pattern=True,
314
356
  **kwargs,
315
357
  )
316
358
 
@@ -331,6 +373,7 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
331
373
  ForwardPattern.Pattern_0,
332
374
  ForwardPattern.Pattern_0,
333
375
  ],
376
+ check_forward_pattern=True,
334
377
  **kwargs,
335
378
  )
336
379
 
@@ -348,6 +391,7 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
348
391
  transformer=pipe.transformer,
349
392
  blocks=pipe.transformer.layers,
350
393
  forward_pattern=ForwardPattern.Pattern_3,
394
+ check_forward_pattern=True,
351
395
  **kwargs,
352
396
  )
353
397
 
@@ -362,6 +406,7 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
362
406
  transformer=pipe.transformer,
363
407
  blocks=pipe.transformer.layers,
364
408
  forward_pattern=ForwardPattern.Pattern_3,
409
+ check_forward_pattern=True,
365
410
  **kwargs,
366
411
  )
367
412
 
@@ -376,6 +421,7 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
376
421
  transformer=pipe.transformer,
377
422
  blocks=pipe.transformer.transformer_blocks,
378
423
  forward_pattern=ForwardPattern.Pattern_3,
424
+ check_forward_pattern=True,
379
425
  **kwargs,
380
426
  )
381
427
 
@@ -390,6 +436,7 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
390
436
  transformer=pipe.transformer,
391
437
  blocks=pipe.transformer.transformer_blocks,
392
438
  forward_pattern=ForwardPattern.Pattern_3,
439
+ check_forward_pattern=True,
393
440
  **kwargs,
394
441
  )
395
442
 
@@ -404,6 +451,7 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
404
451
  transformer=pipe.transformer,
405
452
  blocks=pipe.transformer.transformer_blocks,
406
453
  forward_pattern=ForwardPattern.Pattern_3,
454
+ check_forward_pattern=True,
407
455
  **kwargs,
408
456
  )
409
457
 
@@ -426,6 +474,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
426
474
  ForwardPattern.Pattern_1,
427
475
  ForwardPattern.Pattern_1,
428
476
  ],
477
+ check_forward_pattern=True,
429
478
  **kwargs,
430
479
  )
431
480
  else:
@@ -440,6 +489,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
440
489
  ForwardPattern.Pattern_1,
441
490
  ForwardPattern.Pattern_3,
442
491
  ],
492
+ check_forward_pattern=True,
443
493
  **kwargs,
444
494
  )
445
495
 
@@ -454,6 +504,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
454
504
  transformer=pipe.transformer,
455
505
  blocks=pipe.transformer.single_transformer_blocks,
456
506
  forward_pattern=ForwardPattern.Pattern_3,
507
+ check_forward_pattern=True,
457
508
  **kwargs,
458
509
  )
459
510
 
@@ -461,7 +512,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
461
512
  @BlockAdapterRegistry.register("Chroma")
462
513
  def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
463
514
  from diffusers import ChromaTransformer2DModel
464
- from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor
515
+ from cache_dit.caching.patch_functors import ChromaPatchFunctor
465
516
 
466
517
  assert isinstance(pipe.transformer, ChromaTransformer2DModel)
467
518
  return BlockAdapter(
@@ -476,6 +527,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
476
527
  ForwardPattern.Pattern_3,
477
528
  ],
478
529
  patch_functor=ChromaPatchFunctor(),
530
+ check_forward_pattern=True,
479
531
  has_separate_cfg=True,
480
532
  **kwargs,
481
533
  )
@@ -491,6 +543,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
491
543
  transformer=pipe.prior,
492
544
  blocks=pipe.prior.transformer_blocks,
493
545
  forward_pattern=ForwardPattern.Pattern_3,
546
+ check_forward_pattern=True,
494
547
  **kwargs,
495
548
  )
496
549
 
@@ -503,7 +556,7 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
503
556
  # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
504
557
  # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
505
558
  from diffusers import HiDreamImageTransformer2DModel
506
- from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
559
+ from cache_dit.caching.patch_functors import HiDreamPatchFunctor
507
560
 
508
561
  assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
509
562
  return BlockAdapter(
@@ -528,7 +581,7 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
528
581
  @BlockAdapterRegistry.register("HunyuanDiT")
529
582
  def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
530
583
  from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
531
- from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
584
+ from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
532
585
 
533
586
  assert isinstance(
534
587
  pipe.transformer,
@@ -540,6 +593,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
540
593
  blocks=pipe.transformer.blocks,
541
594
  forward_pattern=ForwardPattern.Pattern_3,
542
595
  patch_functor=HunyuanDiTPatchFunctor(),
596
+ check_forward_pattern=True,
543
597
  **kwargs,
544
598
  )
545
599
 
@@ -547,7 +601,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
547
601
  @BlockAdapterRegistry.register("HunyuanDiTPAG")
548
602
  def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
549
603
  from diffusers import HunyuanDiT2DModel
550
- from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
604
+ from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
551
605
 
552
606
  assert isinstance(pipe.transformer, HunyuanDiT2DModel)
553
607
  return BlockAdapter(
@@ -556,5 +610,82 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
556
610
  blocks=pipe.transformer.blocks,
557
611
  forward_pattern=ForwardPattern.Pattern_3,
558
612
  patch_functor=HunyuanDiTPatchFunctor(),
613
+ check_forward_pattern=True,
559
614
  **kwargs,
560
615
  )
616
+
617
+
618
+ @BlockAdapterRegistry.register("Kandinsky5")
619
+ def kandinsky5_adapter(pipe, **kwargs) -> BlockAdapter:
620
+ try:
621
+ from diffusers import Kandinsky5Transformer3DModel
622
+
623
+ assert isinstance(pipe.transformer, Kandinsky5Transformer3DModel)
624
+ return BlockAdapter(
625
+ pipe=pipe,
626
+ transformer=pipe.transformer,
627
+ blocks=pipe.transformer.visual_transformer_blocks,
628
+ forward_pattern=ForwardPattern.Pattern_3, # or Pattern_2
629
+ has_separate_cfg=True,
630
+ check_forward_pattern=False,
631
+ check_num_outputs=False,
632
+ **kwargs,
633
+ )
634
+ except ImportError:
635
+ raise ImportError(
636
+ "Kandinsky5Transformer3DModel is not available in the current diffusers version. "
637
+ "Please upgrade diffusers>=0.36.dev0 to use this adapter."
638
+ )
639
+
640
+
641
+ @BlockAdapterRegistry.register("PRX")
642
+ def prx_adapter(pipe, **kwargs) -> BlockAdapter:
643
+ try:
644
+ from diffusers import PRXTransformer2DModel
645
+
646
+ assert isinstance(pipe.transformer, PRXTransformer2DModel)
647
+ return BlockAdapter(
648
+ pipe=pipe,
649
+ transformer=pipe.transformer,
650
+ blocks=pipe.transformer.blocks,
651
+ forward_pattern=ForwardPattern.Pattern_3,
652
+ check_forward_pattern=True,
653
+ check_num_outputs=False,
654
+ **kwargs,
655
+ )
656
+ except ImportError:
657
+ raise ImportError(
658
+ "PRXTransformer2DModel is not available in the current diffusers version. "
659
+ "Please upgrade diffusers>=0.36.dev0 to use this adapter."
660
+ )
661
+
662
+
663
+ @BlockAdapterRegistry.register("HunyuanImage")
664
+ def hunyuan_image_adapter(pipe, **kwargs) -> BlockAdapter:
665
+ try:
666
+ from diffusers import HunyuanImageTransformer2DModel
667
+
668
+ assert isinstance(pipe.transformer, HunyuanImageTransformer2DModel)
669
+ return BlockAdapter(
670
+ pipe=pipe,
671
+ transformer=pipe.transformer,
672
+ blocks=[
673
+ pipe.transformer.transformer_blocks,
674
+ pipe.transformer.single_transformer_blocks,
675
+ ],
676
+ forward_pattern=[
677
+ ForwardPattern.Pattern_0,
678
+ ForwardPattern.Pattern_0,
679
+ ],
680
+ # set `has_separate_cfg` as True to enable separate cfg caching
681
+ # since in hyimage-2.1 the `guider_state` contains 2 input batches.
682
+ # The cfg is `enabled` by default in AdaptiveProjectedMixGuidance.
683
+ has_separate_cfg=True,
684
+ check_forward_pattern=True,
685
+ **kwargs,
686
+ )
687
+ except ImportError:
688
+ raise ImportError(
689
+ "HunyuanImageTransformer2DModel is not available in the current diffusers version. "
690
+ "Please upgrade diffusers>=0.36.dev0 to use this adapter."
691
+ )