cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. cache_dit/__init__.py +8 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +17 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +555 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +262 -938
  8. cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
  12. cache_dit/cache_factory/cache_blocks/utils.py +16 -10
  13. cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
  14. cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
  16. cache_dit/cache_factory/cache_interface.py +31 -31
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +26 -26
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
  22. cache_dit-0.2.28.dist-info/RECORD +47 -0
  23. cache_dit/cache_factory/cache_context.py +0 -1155
  24. cache_dit-0.2.26.dist-info/RECORD +0 -42
  25. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -10,9 +10,11 @@ from cache_dit.cache_factory import cache_type
10
10
  from cache_dit.cache_factory import block_range
11
11
  from cache_dit.cache_factory import CacheType
12
12
  from cache_dit.cache_factory import BlockAdapter
13
+ from cache_dit.cache_factory import ParamsModifier
13
14
  from cache_dit.cache_factory import ForwardPattern
14
15
  from cache_dit.cache_factory import PatchFunctor
15
16
  from cache_dit.cache_factory import supported_pipelines
17
+ from cache_dit.cache_factory import get_adapter
16
18
  from cache_dit.compile import set_compile_configs
17
19
  from cache_dit.quantize import quantize
18
20
  from cache_dit.utils import summary
@@ -22,9 +24,9 @@ from cache_dit.logger import init_logger
22
24
  NONE = CacheType.NONE
23
25
  DBCache = CacheType.DBCache
24
26
 
25
- Forward_Pattern_0 = ForwardPattern.Pattern_0
26
- Forward_Pattern_1 = ForwardPattern.Pattern_1
27
- Forward_Pattern_2 = ForwardPattern.Pattern_2
28
- Forward_Pattern_3 = ForwardPattern.Pattern_3
29
- Forward_Pattern_4 = ForwardPattern.Pattern_4
30
- Forward_Pattern_5 = ForwardPattern.Pattern_5
27
+ Pattern_0 = ForwardPattern.Pattern_0
28
+ Pattern_1 = ForwardPattern.Pattern_1
29
+ Pattern_2 = ForwardPattern.Pattern_2
30
+ Pattern_3 = ForwardPattern.Pattern_3
31
+ Pattern_4 = ForwardPattern.Pattern_4
32
+ Pattern_5 = ForwardPattern.Pattern_5
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.26'
32
- __version_tuple__ = version_tuple = (0, 2, 26)
31
+ __version__ = version = '0.2.28'
32
+ __version_tuple__ = version_tuple = (0, 2, 28)
33
33
 
34
34
  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -1,10 +1,23 @@
1
- from cache_dit.cache_factory.forward_pattern import ForwardPattern
2
1
  from cache_dit.cache_factory.cache_types import CacheType
3
2
  from cache_dit.cache_factory.cache_types import cache_type
4
3
  from cache_dit.cache_factory.cache_types import block_range
5
- from cache_dit.cache_factory.cache_adapters import BlockAdapter
6
- from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
4
+
5
+ from cache_dit.cache_factory.forward_pattern import ForwardPattern
6
+
7
+ from cache_dit.cache_factory.patch_functors import PatchFunctor
8
+
9
+ from cache_dit.cache_factory.block_adapters import BlockAdapter
10
+ from cache_dit.cache_factory.block_adapters import ParamsModifier
11
+ from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
12
+
13
+ from cache_dit.cache_factory.cache_contexts import CachedContext
14
+ from cache_dit.cache_factory.cache_contexts import CachedContextManager
15
+ from cache_dit.cache_factory.cache_blocks import CachedBlocks
16
+
17
+ from cache_dit.cache_factory.cache_adapters import CachedAdapter
18
+
7
19
  from cache_dit.cache_factory.cache_interface import enable_cache
8
20
  from cache_dit.cache_factory.cache_interface import supported_pipelines
9
- from cache_dit.cache_factory.patch_functors import PatchFunctor
21
+ from cache_dit.cache_factory.cache_interface import get_adapter
22
+
10
23
  from cache_dit.cache_factory.utils import load_options
cache_dit/cache_factory/block_adapters/__init__.py ADDED
@@ -0,0 +1,555 @@
1
+ from cache_dit.cache_factory.forward_pattern import ForwardPattern
2
+ from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
3
+ from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
4
+ from cache_dit.cache_factory.block_adapters.block_registers import (
5
+ BlockAdapterRegistry,
6
+ )
7
+
8
+
9
+ @BlockAdapterRegistry.register("Flux")
10
+ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
11
+ from diffusers import FluxTransformer2DModel
12
+ from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
13
+
14
+ assert isinstance(pipe.transformer, FluxTransformer2DModel)
15
+
16
+ return BlockAdapter(
17
+ pipe=pipe,
18
+ transformer=pipe.transformer,
19
+ blocks=(
20
+ pipe.transformer.transformer_blocks
21
+ + pipe.transformer.single_transformer_blocks
22
+ ),
23
+ blocks_name="transformer_blocks",
24
+ dummy_blocks_names=["single_transformer_blocks"],
25
+ patch_functor=FluxPatchFunctor(),
26
+ forward_pattern=ForwardPattern.Pattern_1,
27
+ disable_patch=kwargs.pop("disable_patch", False),
28
+ )
29
+
30
+
31
+ @BlockAdapterRegistry.register("Mochi")
32
+ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
33
+ from diffusers import MochiTransformer3DModel
34
+
35
+ assert isinstance(pipe.transformer, MochiTransformer3DModel)
36
+ return BlockAdapter(
37
+ pipe=pipe,
38
+ transformer=pipe.transformer,
39
+ blocks=pipe.transformer.transformer_blocks,
40
+ blocks_name="transformer_blocks",
41
+ dummy_blocks_names=[],
42
+ forward_pattern=ForwardPattern.Pattern_0,
43
+ )
44
+
45
+
46
+ @BlockAdapterRegistry.register("CogVideoX")
47
+ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
48
+ from diffusers import CogVideoXTransformer3DModel
49
+
50
+ assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
51
+ return BlockAdapter(
52
+ pipe=pipe,
53
+ transformer=pipe.transformer,
54
+ blocks=pipe.transformer.transformer_blocks,
55
+ blocks_name="transformer_blocks",
56
+ dummy_blocks_names=[],
57
+ forward_pattern=ForwardPattern.Pattern_0,
58
+ )
59
+
60
+
61
+ @BlockAdapterRegistry.register("Wan")
62
+ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
63
+ from diffusers import (
64
+ WanTransformer3DModel,
65
+ WanVACETransformer3DModel,
66
+ )
67
+
68
+ assert isinstance(
69
+ pipe.transformer,
70
+ (WanTransformer3DModel, WanVACETransformer3DModel),
71
+ )
72
+ if getattr(pipe, "transformer_2", None):
73
+ assert isinstance(
74
+ pipe.transformer_2,
75
+ (WanTransformer3DModel, WanVACETransformer3DModel),
76
+ )
77
+ # Wan 2.2 MoE
78
+ return BlockAdapter(
79
+ pipe=pipe,
80
+ transformer=[
81
+ pipe.transformer,
82
+ pipe.transformer_2,
83
+ ],
84
+ blocks=[
85
+ pipe.transformer.blocks,
86
+ pipe.transformer_2.blocks,
87
+ ],
88
+ blocks_name=[
89
+ "blocks",
90
+ "blocks",
91
+ ],
92
+ forward_pattern=[
93
+ ForwardPattern.Pattern_2,
94
+ ForwardPattern.Pattern_2,
95
+ ],
96
+ dummy_blocks_names=[],
97
+ has_separate_cfg=True,
98
+ )
99
+ else:
100
+ # Wan 2.1
101
+ return BlockAdapter(
102
+ pipe=pipe,
103
+ transformer=pipe.transformer,
104
+ blocks=pipe.transformer.blocks,
105
+ blocks_name="blocks",
106
+ dummy_blocks_names=[],
107
+ forward_pattern=ForwardPattern.Pattern_2,
108
+ has_separate_cfg=True,
109
+ )
110
+
111
+
112
+ @BlockAdapterRegistry.register("HunyuanVideo")
113
+ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
114
+ from diffusers import HunyuanVideoTransformer3DModel
115
+
116
+ assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
117
+ return BlockAdapter(
118
+ pipe=pipe,
119
+ blocks=(
120
+ pipe.transformer.transformer_blocks
121
+ + pipe.transformer.single_transformer_blocks
122
+ ),
123
+ blocks_name="transformer_blocks",
124
+ dummy_blocks_names=["single_transformer_blocks"],
125
+ forward_pattern=ForwardPattern.Pattern_0,
126
+ )
127
+
128
+
129
+ @BlockAdapterRegistry.register("QwenImage")
130
+ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
131
+ from diffusers import QwenImageTransformer2DModel
132
+
133
+ assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
134
+ return BlockAdapter(
135
+ pipe=pipe,
136
+ transformer=pipe.transformer,
137
+ blocks=pipe.transformer.transformer_blocks,
138
+ blocks_name="transformer_blocks",
139
+ dummy_blocks_names=[],
140
+ forward_pattern=ForwardPattern.Pattern_1,
141
+ has_separate_cfg=True,
142
+ )
143
+
144
+
145
+ @BlockAdapterRegistry.register("LTXVideo")
146
+ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
147
+ from diffusers import LTXVideoTransformer3DModel
148
+
149
+ assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
150
+ return BlockAdapter(
151
+ pipe=pipe,
152
+ transformer=pipe.transformer,
153
+ blocks=pipe.transformer.transformer_blocks,
154
+ blocks_name="transformer_blocks",
155
+ dummy_blocks_names=[],
156
+ forward_pattern=ForwardPattern.Pattern_2,
157
+ )
158
+
159
+
160
+ @BlockAdapterRegistry.register("Allegro")
161
+ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
162
+ from diffusers import AllegroTransformer3DModel
163
+
164
+ assert isinstance(pipe.transformer, AllegroTransformer3DModel)
165
+ return BlockAdapter(
166
+ pipe=pipe,
167
+ transformer=pipe.transformer,
168
+ blocks=pipe.transformer.transformer_blocks,
169
+ blocks_name="transformer_blocks",
170
+ dummy_blocks_names=[],
171
+ forward_pattern=ForwardPattern.Pattern_2,
172
+ )
173
+
174
+
175
+ @BlockAdapterRegistry.register("CogView3Plus")
176
+ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
177
+ from diffusers import CogView3PlusTransformer2DModel
178
+
179
+ assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
180
+ return BlockAdapter(
181
+ pipe=pipe,
182
+ transformer=pipe.transformer,
183
+ blocks=pipe.transformer.transformer_blocks,
184
+ blocks_name="transformer_blocks",
185
+ dummy_blocks_names=[],
186
+ forward_pattern=ForwardPattern.Pattern_0,
187
+ )
188
+
189
+
190
+ @BlockAdapterRegistry.register("CogView4")
191
+ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
192
+ from diffusers import CogView4Transformer2DModel
193
+
194
+ assert isinstance(pipe.transformer, CogView4Transformer2DModel)
195
+ return BlockAdapter(
196
+ pipe=pipe,
197
+ transformer=pipe.transformer,
198
+ blocks=pipe.transformer.transformer_blocks,
199
+ blocks_name="transformer_blocks",
200
+ dummy_blocks_names=[],
201
+ forward_pattern=ForwardPattern.Pattern_0,
202
+ has_separate_cfg=True,
203
+ )
204
+
205
+
206
+ @BlockAdapterRegistry.register("Cosmos")
207
+ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
208
+ from diffusers import CosmosTransformer3DModel
209
+
210
+ assert isinstance(pipe.transformer, CosmosTransformer3DModel)
211
+ return BlockAdapter(
212
+ pipe=pipe,
213
+ transformer=pipe.transformer,
214
+ blocks=pipe.transformer.transformer_blocks,
215
+ blocks_name="transformer_blocks",
216
+ dummy_blocks_names=[],
217
+ forward_pattern=ForwardPattern.Pattern_2,
218
+ has_separate_cfg=True,
219
+ )
220
+
221
+
222
+ @BlockAdapterRegistry.register("EasyAnimate")
223
+ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
224
+ from diffusers import EasyAnimateTransformer3DModel
225
+
226
+ assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
227
+ return BlockAdapter(
228
+ pipe=pipe,
229
+ transformer=pipe.transformer,
230
+ blocks=pipe.transformer.transformer_blocks,
231
+ blocks_name="transformer_blocks",
232
+ dummy_blocks_names=[],
233
+ forward_pattern=ForwardPattern.Pattern_0,
234
+ )
235
+
236
+
237
+ @BlockAdapterRegistry.register("SkyReelsV2")
238
+ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
239
+ from diffusers import SkyReelsV2Transformer3DModel
240
+
241
+ assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
242
+ return BlockAdapter(
243
+ pipe=pipe,
244
+ transformer=pipe.transformer,
245
+ blocks=pipe.transformer.blocks,
246
+ blocks_name="blocks",
247
+ dummy_blocks_names=[],
248
+ forward_pattern=ForwardPattern.Pattern_2,
249
+ has_separate_cfg=True,
250
+ )
251
+
252
+
253
+ @BlockAdapterRegistry.register("SD3")
254
+ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
255
+ from diffusers import SD3Transformer2DModel
256
+
257
+ assert isinstance(pipe.transformer, SD3Transformer2DModel)
258
+ return BlockAdapter(
259
+ pipe=pipe,
260
+ transformer=pipe.transformer,
261
+ blocks=pipe.transformer.transformer_blocks,
262
+ blocks_name="transformer_blocks",
263
+ dummy_blocks_names=[],
264
+ forward_pattern=ForwardPattern.Pattern_1,
265
+ )
266
+
267
+
268
+ @BlockAdapterRegistry.register("ConsisID")
269
+ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
270
+ from diffusers import ConsisIDTransformer3DModel
271
+
272
+ assert isinstance(pipe.transformer, ConsisIDTransformer3DModel)
273
+ return BlockAdapter(
274
+ pipe=pipe,
275
+ transformer=pipe.transformer,
276
+ blocks=pipe.transformer.transformer_blocks,
277
+ blocks_name="transformer_blocks",
278
+ dummy_blocks_names=[],
279
+ forward_pattern=ForwardPattern.Pattern_0,
280
+ )
281
+
282
+
283
+ @BlockAdapterRegistry.register("DiT")
284
+ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
285
+ from diffusers import DiTTransformer2DModel
286
+
287
+ assert isinstance(pipe.transformer, DiTTransformer2DModel)
288
+ return BlockAdapter(
289
+ pipe=pipe,
290
+ transformer=pipe.transformer,
291
+ blocks=pipe.transformer.transformer_blocks,
292
+ blocks_name="transformer_blocks",
293
+ dummy_blocks_names=[],
294
+ forward_pattern=ForwardPattern.Pattern_3,
295
+ )
296
+
297
+
298
+ @BlockAdapterRegistry.register("Amused")
299
+ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
300
+ from diffusers import UVit2DModel
301
+
302
+ assert isinstance(pipe.transformer, UVit2DModel)
303
+ return BlockAdapter(
304
+ pipe=pipe,
305
+ transformer=pipe.transformer,
306
+ blocks=pipe.transformer.transformer_layers,
307
+ blocks_name="transformer_layers",
308
+ dummy_blocks_names=[],
309
+ forward_pattern=ForwardPattern.Pattern_3,
310
+ )
311
+
312
+
313
+ @BlockAdapterRegistry.register("Bria")
314
+ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
315
+ from diffusers import BriaTransformer2DModel
316
+
317
+ assert isinstance(pipe.transformer, BriaTransformer2DModel)
318
+ return BlockAdapter(
319
+ pipe=pipe,
320
+ transformer=pipe.transformer,
321
+ blocks=(
322
+ pipe.transformer.transformer_blocks
323
+ + pipe.transformer.single_transformer_blocks
324
+ ),
325
+ blocks_name="transformer_blocks",
326
+ dummy_blocks_names=["single_transformer_blocks"],
327
+ forward_pattern=ForwardPattern.Pattern_0,
328
+ )
329
+
330
+
331
+ @BlockAdapterRegistry.register("HunyuanDiT")
332
+ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
333
+ from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
334
+
335
+ assert isinstance(
336
+ pipe.transformer,
337
+ (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
338
+ )
339
+ return BlockAdapter(
340
+ pipe=pipe,
341
+ transformer=pipe.transformer,
342
+ blocks=pipe.transformer.blocks,
343
+ blocks_name="blocks",
344
+ dummy_blocks_names=[],
345
+ forward_pattern=ForwardPattern.Pattern_3,
346
+ )
347
+
348
+
349
+ @BlockAdapterRegistry.register("HunyuanDiTPAG")
350
+ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
351
+ from diffusers import HunyuanDiT2DModel
352
+
353
+ assert isinstance(pipe.transformer, HunyuanDiT2DModel)
354
+ return BlockAdapter(
355
+ pipe=pipe,
356
+ transformer=pipe.transformer,
357
+ blocks=pipe.transformer.blocks,
358
+ blocks_name="blocks",
359
+ dummy_blocks_names=[],
360
+ forward_pattern=ForwardPattern.Pattern_3,
361
+ )
362
+
363
+
364
+ @BlockAdapterRegistry.register("Lumina")
365
+ def lumina_adapter(pipe) -> BlockAdapter:
366
+ from diffusers import LuminaNextDiT2DModel
367
+
368
+ assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
369
+ return BlockAdapter(
370
+ pipe=pipe,
371
+ transformer=pipe.transformer,
372
+ blocks=pipe.transformer.layers,
373
+ blocks_name="layers",
374
+ dummy_blocks_names=[],
375
+ forward_pattern=ForwardPattern.Pattern_3,
376
+ )
377
+
378
+
379
+ @BlockAdapterRegistry.register("Lumina2")
380
+ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
381
+ from diffusers import Lumina2Transformer2DModel
382
+
383
+ assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
384
+ return BlockAdapter(
385
+ pipe=pipe,
386
+ transformer=pipe.transformer,
387
+ blocks=pipe.transformer.layers,
388
+ blocks_name="layers",
389
+ dummy_blocks_names=[],
390
+ forward_pattern=ForwardPattern.Pattern_3,
391
+ )
392
+
393
+
394
+ @BlockAdapterRegistry.register("OmniGen")
395
+ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
396
+ from diffusers import OmniGenTransformer2DModel
397
+
398
+ assert isinstance(pipe.transformer, OmniGenTransformer2DModel)
399
+ return BlockAdapter(
400
+ pipe=pipe,
401
+ transformer=pipe.transformer,
402
+ blocks=pipe.transformer.layers,
403
+ blocks_name="layers",
404
+ dummy_blocks_names=[],
405
+ forward_pattern=ForwardPattern.Pattern_3,
406
+ )
407
+
408
+
409
+ @BlockAdapterRegistry.register("PixArt")
410
+ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
411
+ from diffusers import PixArtTransformer2DModel
412
+
413
+ assert isinstance(pipe.transformer, PixArtTransformer2DModel)
414
+ return BlockAdapter(
415
+ pipe=pipe,
416
+ transformer=pipe.transformer,
417
+ blocks=pipe.transformer.transformer_blocks,
418
+ blocks_name="transformer_blocks",
419
+ dummy_blocks_names=[],
420
+ forward_pattern=ForwardPattern.Pattern_3,
421
+ )
422
+
423
+
424
+ @BlockAdapterRegistry.register("Sana")
425
+ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
426
+ from diffusers import SanaTransformer2DModel
427
+
428
+ assert isinstance(pipe.transformer, SanaTransformer2DModel)
429
+ return BlockAdapter(
430
+ pipe=pipe,
431
+ transformer=pipe.transformer,
432
+ blocks=pipe.transformer.transformer_blocks,
433
+ blocks_name="transformer_blocks",
434
+ dummy_blocks_names=[],
435
+ forward_pattern=ForwardPattern.Pattern_3,
436
+ )
437
+
438
+
439
+ @BlockAdapterRegistry.register("ShapE")
440
+ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
441
+ from diffusers import PriorTransformer
442
+
443
+ assert isinstance(pipe.prior, PriorTransformer)
444
+ return BlockAdapter(
445
+ pipe=pipe,
446
+ transformer=pipe.prior,
447
+ blocks=pipe.prior.transformer_blocks,
448
+ blocks_name="transformer_blocks",
449
+ dummy_blocks_names=[],
450
+ forward_pattern=ForwardPattern.Pattern_3,
451
+ )
452
+
453
+
454
+ @BlockAdapterRegistry.register("StableAudio")
455
+ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
456
+ from diffusers import StableAudioDiTModel
457
+
458
+ assert isinstance(pipe.transformer, StableAudioDiTModel)
459
+ return BlockAdapter(
460
+ pipe=pipe,
461
+ transformer=pipe.transformer,
462
+ blocks=pipe.transformer.transformer_blocks,
463
+ blocks_name="transformer_blocks",
464
+ dummy_blocks_names=[],
465
+ forward_pattern=ForwardPattern.Pattern_3,
466
+ )
467
+
468
+
469
+ @BlockAdapterRegistry.register("VisualCloze")
470
+ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
471
+ from diffusers import FluxTransformer2DModel
472
+ from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
473
+
474
+ assert isinstance(pipe.transformer, FluxTransformer2DModel)
475
+ return BlockAdapter(
476
+ pipe=pipe,
477
+ transformer=pipe.transformer,
478
+ blocks=(
479
+ pipe.transformer.transformer_blocks
480
+ + pipe.transformer.single_transformer_blocks
481
+ ),
482
+ blocks_name="transformer_blocks",
483
+ dummy_blocks_names=["single_transformer_blocks"],
484
+ patch_functor=FluxPatchFunctor(),
485
+ forward_pattern=ForwardPattern.Pattern_1,
486
+ )
487
+
488
+
489
+ @BlockAdapterRegistry.register("AuraFlow")
490
+ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
491
+ from diffusers import AuraFlowTransformer2DModel
492
+
493
+ assert isinstance(pipe.transformer, AuraFlowTransformer2DModel)
494
+ return BlockAdapter(
495
+ pipe=pipe,
496
+ transformer=pipe.transformer,
497
+ blocks=[
498
+ # Only 4 joint blocks, apply no-cache
499
+ # pipe.transformer.joint_transformer_blocks,
500
+ pipe.transformer.single_transformer_blocks,
501
+ ],
502
+ blocks_name=[
503
+ # "joint_transformer_blocks",
504
+ "single_transformer_blocks",
505
+ ],
506
+ forward_pattern=[
507
+ # ForwardPattern.Pattern_1,
508
+ ForwardPattern.Pattern_3,
509
+ ],
510
+ )
511
+
512
+
513
+ @BlockAdapterRegistry.register("Chroma")
514
+ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
515
+ from diffusers import ChromaTransformer2DModel
516
+
517
+ assert isinstance(pipe.transformer, ChromaTransformer2DModel)
518
+ return BlockAdapter(
519
+ pipe=pipe,
520
+ transformer=pipe.transformer,
521
+ blocks=[
522
+ pipe.transformer.transformer_blocks,
523
+ pipe.transformer.single_transformer_blocks,
524
+ ],
525
+ blocks_name=[
526
+ "transformer_blocks",
527
+ "single_transformer_blocks",
528
+ ],
529
+ forward_pattern=[
530
+ ForwardPattern.Pattern_1,
531
+ ForwardPattern.Pattern_3,
532
+ ],
533
+ has_separate_cfg=True,
534
+ )
535
+
536
+
537
+ @BlockAdapterRegistry.register("HiDream")
538
+ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
539
+ from diffusers import HiDreamImageTransformer2DModel
540
+
541
+ assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
542
+ return BlockAdapter(
543
+ pipe=pipe,
544
+ transformer=pipe.transformer,
545
+ blocks=[
546
+ pipe.transformer.double_stream_blocks,
547
+ pipe.transformer.single_stream_blocks,
548
+ ],
549
+ dummy_blocks_names=[],
550
+ forward_pattern=[
551
+ ForwardPattern.Pattern_4,
552
+ ForwardPattern.Pattern_3,
553
+ ],
554
+ check_num_outputs=False,
555
+ )