cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (32)
  1. cache_dit/__init__.py +9 -4
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +16 -3
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +121 -563
  8. cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
  11. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
  12. cache_dit/cache_factory/cache_blocks/utils.py +23 -0
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
  15. cache_dit/cache_factory/cache_interface.py +24 -16
  16. cache_dit/cache_factory/forward_pattern.py +45 -24
  17. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  18. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  19. cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
  20. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
  21. cache_dit/quantize/quantize_ao.py +19 -4
  22. cache_dit/quantize/quantize_interface.py +2 -2
  23. cache_dit/utils.py +19 -15
  24. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
  25. cache_dit-0.2.27.dist-info/RECORD +47 -0
  26. cache_dit-0.2.25.dist-info/RECORD +0 -36
  27. /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
  28. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  29. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  30. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -11,6 +11,9 @@ from cache_dit.cache_factory import block_range
11
11
  from cache_dit.cache_factory import CacheType
12
12
  from cache_dit.cache_factory import BlockAdapter
13
13
  from cache_dit.cache_factory import ForwardPattern
14
+ from cache_dit.cache_factory import PatchFunctor
15
+ from cache_dit.cache_factory import supported_pipelines
16
+ from cache_dit.cache_factory import get_adapter
14
17
  from cache_dit.compile import set_compile_configs
15
18
  from cache_dit.quantize import quantize
16
19
  from cache_dit.utils import summary
@@ -20,7 +23,9 @@ from cache_dit.logger import init_logger
20
23
NONE = CacheType.NONE
DBCache = CacheType.DBCache

# Short module-level aliases for every ForwardPattern member.
Pattern_0 = ForwardPattern.Pattern_0
Pattern_1 = ForwardPattern.Pattern_1
Pattern_2 = ForwardPattern.Pattern_2
Pattern_3 = ForwardPattern.Pattern_3
Pattern_4 = ForwardPattern.Pattern_4
Pattern_5 = ForwardPattern.Pattern_5

# Backward-compatible aliases: cache_dit 0.2.25 exported these names, so
# removing them would break code that references `cache_dit.Forward_Pattern_N`.
Forward_Pattern_0 = Pattern_0
Forward_Pattern_1 = Pattern_1
Forward_Pattern_2 = Pattern_2
Forward_Pattern_3 = Pattern_3
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
# NOTE(review): this hunk bumps the package version from 0.2.25 to 0.2.27.
# The file looks tool-generated (setuptools-scm style) — confirm before hand-editing.
31
- __version__ = version = '0.2.25'
32
- __version_tuple__ = version_tuple = (0, 2, 25)
31
+ __version__ = version = '0.2.27'
32
+ __version_tuple__ = version_tuple = (0, 2, 27)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -1,8 +1,21 @@
1
- from cache_dit.cache_factory.forward_pattern import ForwardPattern
2
1
  from cache_dit.cache_factory.cache_types import CacheType
3
2
  from cache_dit.cache_factory.cache_types import cache_type
4
3
  from cache_dit.cache_factory.cache_types import block_range
5
- from cache_dit.cache_factory.cache_adapters import BlockAdapter
6
- from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
4
+
5
+ from cache_dit.cache_factory.forward_pattern import ForwardPattern
6
+
7
+ from cache_dit.cache_factory.patch_functors import PatchFunctor
8
+
9
+ from cache_dit.cache_factory.block_adapters import BlockAdapter
10
+ from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
11
+
12
+ from cache_dit.cache_factory.cache_contexts import CachedContext
13
+ from cache_dit.cache_factory.cache_blocks import CachedBlocks
14
+
15
+ from cache_dit.cache_factory.cache_adapters import CachedAdapter
16
+
7
17
  from cache_dit.cache_factory.cache_interface import enable_cache
18
+ from cache_dit.cache_factory.cache_interface import supported_pipelines
19
+ from cache_dit.cache_factory.cache_interface import get_adapter
20
+
8
21
  from cache_dit.cache_factory.utils import load_options
@@ -0,0 +1,538 @@
1
+ from cache_dit.cache_factory.forward_pattern import ForwardPattern
2
+ from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
3
+ from cache_dit.cache_factory.block_adapters.block_registers import (
4
+ BlockAdapterRegistry,
5
+ )
6
+
7
+
8
@BlockAdapterRegistry.register("Flux")
def flux_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Flux pipeline.

    The cached sequence is ``transformer_blocks + single_transformer_blocks``.
    It is registered under ``transformer_blocks``, with the single blocks
    listed as dummy blocks, and is patched through ``FluxPatchFunctor``.

    Args:
        pipe: Pipeline whose ``transformer`` must be a FluxTransformer2DModel.
        **kwargs: Only ``disable_patch`` (default ``False``) is consumed.
    """
    from diffusers import FluxTransformer2DModel
    from cache_dit.cache_factory.patch_functors import FluxPatchFunctor

    transformer = pipe.transformer
    assert isinstance(transformer, FluxTransformer2DModel)

    combined_blocks = (
        transformer.transformer_blocks + transformer.single_transformer_blocks
    )
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=combined_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=["single_transformer_blocks"],
        patch_functor=FluxPatchFunctor(),
        forward_pattern=ForwardPattern.Pattern_1,
        disable_patch=kwargs.pop("disable_patch", False),
    )
28
+
29
+
30
@BlockAdapterRegistry.register("Mochi")
def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Mochi pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``. ``kwargs`` is ignored.
    """
    from diffusers import MochiTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, MochiTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
43
+
44
+
45
@BlockAdapterRegistry.register("CogVideoX")
def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a CogVideoX pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``. ``kwargs`` is ignored.
    """
    from diffusers import CogVideoXTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, CogVideoXTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
58
+
59
+
60
@BlockAdapterRegistry.register("Wan")
def wan_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Wan 2.1 / Wan 2.2 pipeline.

    If the pipeline has a truthy ``transformer_2`` attribute (Wan 2.2), that
    transformer's ``blocks`` are cached. Otherwise (Wan 2.1) the blocks of
    ``pipe.transformer`` are cached. Both cases use
    ``ForwardPattern.Pattern_2`` and ``has_separate_cfg=True``.
    ``kwargs`` is ignored.
    """
    from diffusers import (
        WanTransformer3DModel,
        WanVACETransformer3DModel,
    )

    assert isinstance(
        pipe.transformer,
        (WanTransformer3DModel, WanVACETransformer3DModel),
    )
    # Wan 2.2: cache the low-noise transformer (transformer_2) when present.
    if getattr(pipe, "transformer_2", None):
        target = pipe.transformer_2
    else:
        target = pipe.transformer

    return BlockAdapter(
        pipe=pipe,
        transformer=target,
        blocks=target.blocks,
        blocks_name="blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
        has_separate_cfg=True,
    )
93
+
94
+
95
@BlockAdapterRegistry.register("HunyuanVideo")
def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a HunyuanVideo pipeline.

    Caches ``transformer_blocks + single_transformer_blocks`` (registered as
    ``transformer_blocks``, with the single blocks as dummy blocks) using
    ``ForwardPattern.Pattern_0``. ``kwargs`` is ignored.

    The checked transformer is passed explicitly as ``transformer=``, the
    same way every other single-transformer adapter in this registry does.
    """
    from diffusers import HunyuanVideoTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, HunyuanVideoTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        # Previously omitted; pass the same transformer we validated above.
        transformer=transformer,
        blocks=(
            transformer.transformer_blocks
            + transformer.single_transformer_blocks
        ),
        blocks_name="transformer_blocks",
        dummy_blocks_names=["single_transformer_blocks"],
        forward_pattern=ForwardPattern.Pattern_0,
    )
110
+
111
+
112
@BlockAdapterRegistry.register("QwenImage")
def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a QwenImage pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_1`` and ``has_separate_cfg=True``.
    ``kwargs`` is ignored.
    """
    from diffusers import QwenImageTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, QwenImageTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_1,
        has_separate_cfg=True,
    )
126
+
127
+
128
@BlockAdapterRegistry.register("LTXVideo")
def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap an LTXVideo pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_2``. ``kwargs`` is ignored.
    """
    from diffusers import LTXVideoTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, LTXVideoTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
    )
141
+
142
+
143
@BlockAdapterRegistry.register("Allegro")
def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap an Allegro pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_2``. ``kwargs`` is ignored.
    """
    from diffusers import AllegroTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, AllegroTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
    )
156
+
157
+
158
@BlockAdapterRegistry.register("CogView3Plus")
def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a CogView3Plus pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``. ``kwargs`` is ignored.
    """
    from diffusers import CogView3PlusTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, CogView3PlusTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
171
+
172
+
173
@BlockAdapterRegistry.register("CogView4")
def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a CogView4 pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0`` and ``has_separate_cfg=True``.
    ``kwargs`` is ignored.
    """
    from diffusers import CogView4Transformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, CogView4Transformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
        has_separate_cfg=True,
    )
187
+
188
+
189
@BlockAdapterRegistry.register("Cosmos")
def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Cosmos pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_2`` and ``has_separate_cfg=True``.
    ``kwargs`` is ignored.
    """
    from diffusers import CosmosTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, CosmosTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
        has_separate_cfg=True,
    )
203
+
204
+
205
@BlockAdapterRegistry.register("EasyAnimate")
def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap an EasyAnimate pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``. ``kwargs`` is ignored.
    """
    from diffusers import EasyAnimateTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, EasyAnimateTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
218
+
219
+
220
@BlockAdapterRegistry.register("SkyReelsV2")
def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a SkyReelsV2 pipeline.

    Caches ``pipe.transformer.blocks`` with ``ForwardPattern.Pattern_2``
    and ``has_separate_cfg=True``. ``kwargs`` is ignored.
    """
    from diffusers import SkyReelsV2Transformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, SkyReelsV2Transformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.blocks,
        blocks_name="blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_2,
        has_separate_cfg=True,
    )
234
+
235
+
236
@BlockAdapterRegistry.register("SD3")
def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Stable Diffusion 3 pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_1``. ``kwargs`` is ignored.
    """
    from diffusers import SD3Transformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, SD3Transformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_1,
    )
249
+
250
+
251
@BlockAdapterRegistry.register("ConsisID")
def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a ConsisID pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_0``. ``kwargs`` is ignored.
    """
    from diffusers import ConsisIDTransformer3DModel

    transformer = pipe.transformer
    assert isinstance(transformer, ConsisIDTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_0,
    )
264
+
265
+
266
@BlockAdapterRegistry.register("DiT")
def dit_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a DiT pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_3``. ``kwargs`` is ignored.
    """
    from diffusers import DiTTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, DiTTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
279
+
280
+
281
@BlockAdapterRegistry.register("Amused")
def amused_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap an aMUSEd pipeline.

    Caches ``pipe.transformer.transformer_layers`` (a UVit2DModel) with
    ``ForwardPattern.Pattern_3``. ``kwargs`` is ignored.
    """
    from diffusers import UVit2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, UVit2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_layers,
        blocks_name="transformer_layers",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
294
+
295
+
296
@BlockAdapterRegistry.register("Bria")
def bria_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Bria pipeline.

    Caches ``transformer_blocks + single_transformer_blocks`` (registered as
    ``transformer_blocks``, with the single blocks as dummy blocks) using
    ``ForwardPattern.Pattern_0``. ``kwargs`` is ignored.
    """
    from diffusers import BriaTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, BriaTransformer2DModel)

    combined_blocks = (
        transformer.transformer_blocks + transformer.single_transformer_blocks
    )
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=combined_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=["single_transformer_blocks"],
        forward_pattern=ForwardPattern.Pattern_0,
    )
312
+
313
+
314
@BlockAdapterRegistry.register("HunyuanDiT")
def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a HunyuanDiT pipeline.

    Accepts either the plain or the ControlNet HunyuanDiT transformer and
    caches its ``blocks`` with ``ForwardPattern.Pattern_3``.
    ``kwargs`` is ignored.
    """
    from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel

    accepted_types = (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel)
    transformer = pipe.transformer
    assert isinstance(transformer, accepted_types)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.blocks,
        blocks_name="blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
330
+
331
+
332
@BlockAdapterRegistry.register("HunyuanDiTPAG")
def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a HunyuanDiT PAG pipeline.

    Caches ``pipe.transformer.blocks`` with ``ForwardPattern.Pattern_3``.
    ``kwargs`` is ignored.
    """
    from diffusers import HunyuanDiT2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, HunyuanDiT2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.blocks,
        blocks_name="blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
345
+
346
+
347
@BlockAdapterRegistry.register("Lumina")
def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Lumina (LuminaNext) pipeline.

    Caches ``pipe.transformer.layers`` with ``ForwardPattern.Pattern_3``.

    Args:
        pipe: Pipeline whose ``transformer`` must be a LuminaNextDiT2DModel.
        **kwargs: Accepted and ignored, so this adapter can be called with
            the same keyword arguments as the other registered adapters.
    """
    from diffusers import LuminaNextDiT2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, LuminaNextDiT2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.layers,
        blocks_name="layers",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
360
+
361
+
362
@BlockAdapterRegistry.register("Lumina2")
def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Lumina2 pipeline.

    Caches ``pipe.transformer.layers`` with ``ForwardPattern.Pattern_3``.
    ``kwargs`` is ignored.
    """
    from diffusers import Lumina2Transformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, Lumina2Transformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.layers,
        blocks_name="layers",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
375
+
376
+
377
@BlockAdapterRegistry.register("OmniGen")
def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap an OmniGen pipeline.

    Caches ``pipe.transformer.layers`` with ``ForwardPattern.Pattern_3``.
    ``kwargs`` is ignored.
    """
    from diffusers import OmniGenTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, OmniGenTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.layers,
        blocks_name="layers",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
390
+
391
+
392
@BlockAdapterRegistry.register("PixArt")
def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a PixArt pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_3``. ``kwargs`` is ignored.
    """
    from diffusers import PixArtTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, PixArtTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
405
+
406
+
407
@BlockAdapterRegistry.register("Sana")
def sana_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Sana pipeline.

    Caches ``pipe.transformer.transformer_blocks`` with
    ``ForwardPattern.Pattern_3``. ``kwargs`` is ignored.
    """
    from diffusers import SanaTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, SanaTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
420
+
421
+
422
@BlockAdapterRegistry.register("ShapE")
def shape_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Shap-E pipeline.

    Unlike most adapters, the cached model is ``pipe.prior`` (a
    PriorTransformer), not ``pipe.transformer``. Its ``transformer_blocks``
    are cached with ``ForwardPattern.Pattern_3``. ``kwargs`` is ignored.
    """
    from diffusers import PriorTransformer

    prior = pipe.prior
    assert isinstance(prior, PriorTransformer)
    return BlockAdapter(
        pipe=pipe,
        transformer=prior,
        blocks=prior.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
435
+
436
+
437
# NOTE(review): the function name looks like a typo of "stableaudio_adapter";
# it is kept unchanged because it is a module-level public name.
@BlockAdapterRegistry.register("StableAudio")
def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Stable Audio pipeline.

    Caches ``pipe.transformer.transformer_blocks`` (a StableAudioDiTModel)
    with ``ForwardPattern.Pattern_3``. ``kwargs`` is ignored.
    """
    from diffusers import StableAudioDiTModel

    transformer = pipe.transformer
    assert isinstance(transformer, StableAudioDiTModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        dummy_blocks_names=[],
        forward_pattern=ForwardPattern.Pattern_3,
    )
450
+
451
+
452
@BlockAdapterRegistry.register("VisualCloze")
def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a VisualCloze pipeline.

    VisualCloze runs on a FluxTransformer2DModel, so it is configured exactly
    like the Flux adapter. The cached sequence is
    ``transformer_blocks + single_transformer_blocks``, with the single blocks
    as dummy blocks, patched through ``FluxPatchFunctor`` and using
    ``ForwardPattern.Pattern_1``.

    Args:
        pipe: Pipeline whose ``transformer`` must be a FluxTransformer2DModel.
        **kwargs: Only ``disable_patch`` (default ``False``) is consumed. The
            Flux adapter honors the same flag; it was previously ignored here.
    """
    from diffusers import FluxTransformer2DModel
    from cache_dit.cache_factory.patch_functors import FluxPatchFunctor

    transformer = pipe.transformer
    assert isinstance(transformer, FluxTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=(
            transformer.transformer_blocks
            + transformer.single_transformer_blocks
        ),
        blocks_name="transformer_blocks",
        dummy_blocks_names=["single_transformer_blocks"],
        patch_functor=FluxPatchFunctor(),
        forward_pattern=ForwardPattern.Pattern_1,
        disable_patch=kwargs.pop("disable_patch", False),
    )
470
+
471
+
472
@BlockAdapterRegistry.register("AuraFlow")
def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap an AuraFlow pipeline.

    Only ``single_transformer_blocks`` are cached, with
    ``ForwardPattern.Pattern_3``. ``kwargs`` is ignored.
    """
    from diffusers import AuraFlowTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, AuraFlowTransformer2DModel)

    # joint_transformer_blocks are deliberately left uncached: the original
    # author notes there are only 4 of them ("apply no-cache").
    cached_groups = [transformer.single_transformer_blocks]
    group_names = ["single_transformer_blocks"]
    group_patterns = [ForwardPattern.Pattern_3]

    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=cached_groups,
        blocks_name=group_names,
        forward_pattern=group_patterns,
    )
494
+
495
+
496
@BlockAdapterRegistry.register("Chroma")
def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a Chroma pipeline.

    Two block groups are cached separately:
    ``transformer_blocks`` with ``Pattern_1`` and
    ``single_transformer_blocks`` with ``Pattern_3``. The adapter sets
    ``has_separate_cfg=True``. ``kwargs`` is ignored.
    """
    from diffusers import ChromaTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, ChromaTransformer2DModel)

    # Parallel lists: entry i of each list describes the same block group.
    group_names = ["transformer_blocks", "single_transformer_blocks"]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=[
            transformer.transformer_blocks,
            transformer.single_transformer_blocks,
        ],
        blocks_name=group_names,
        forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_3],
        has_separate_cfg=True,
    )
518
+
519
+
520
@BlockAdapterRegistry.register("HiDream")
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
    """Describe how cache-dit should wrap a HiDream pipeline.

    Two block groups are cached:
    ``double_stream_blocks`` with ``Pattern_4`` and
    ``single_stream_blocks`` with ``Pattern_3``. Output-count checking is
    turned off via ``check_num_outputs=False``. ``kwargs`` is ignored.
    """
    from diffusers import HiDreamImageTransformer2DModel

    transformer = pipe.transformer
    assert isinstance(transformer, HiDreamImageTransformer2DModel)

    # NOTE(review): unlike the other multi-group adapters (AuraFlow, Chroma),
    # no `blocks_name` is passed here. Confirm that BlockAdapter's default is
    # correct for ["double_stream_blocks", "single_stream_blocks"].
    stream_groups = [
        transformer.double_stream_blocks,
        transformer.single_stream_blocks,
    ]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=stream_groups,
        dummy_blocks_names=[],
        forward_pattern=[ForwardPattern.Pattern_4, ForwardPattern.Pattern_3],
        check_num_outputs=False,
    )