cache-dit 0.2.26__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (28) hide show
  1. cache_dit/__init__.py +7 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +15 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +120 -911
  8. cache_dit/cache_factory/cache_blocks/__init__.py +7 -9
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +46 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +98 -79
  12. cache_dit/cache_factory/cache_blocks/utils.py +13 -9
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +89 -55
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +0 -0
  16. cache_dit/cache_factory/cache_interface.py +21 -18
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +19 -16
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/METADATA +42 -12
  22. cache_dit-0.2.27.dist-info/RECORD +47 -0
  23. cache_dit-0.2.26.dist-info/RECORD +0 -42
  24. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  25. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
@@ -1,865 +1,35 @@
1
1
  import torch
2
2
 
3
- import inspect
4
3
  import unittest
5
4
  import functools
6
- import dataclasses
7
5
 
8
- from typing import Any, Tuple, List, Optional
6
+ from typing import Dict
9
7
  from contextlib import ExitStack
10
8
  from diffusers import DiffusionPipeline
11
9
  from cache_dit.cache_factory import CacheType
12
- from cache_dit.cache_factory import cache_context
10
+ from cache_dit.cache_factory import CachedContext
13
11
  from cache_dit.cache_factory import ForwardPattern
14
- from cache_dit.cache_factory.patch_functors import PatchFunctor
15
- from cache_dit.cache_factory.cache_blocks import (
16
- DBCachedBlocks,
17
- )
12
+ from cache_dit.cache_factory import BlockAdapter
13
+ from cache_dit.cache_factory import BlockAdapterRegistry
14
+ from cache_dit.cache_factory import CachedBlocks
18
15
 
19
16
  from cache_dit.logger import init_logger
20
17
 
21
18
  logger = init_logger(__name__)
22
19
 
23
20
 
24
- @dataclasses.dataclass
25
- class BlockAdapter:
26
- pipe: DiffusionPipeline | Any = None
27
- transformer: torch.nn.Module = None
28
- blocks: torch.nn.ModuleList = None
29
- # transformer_blocks, blocks, etc.
30
- blocks_name: str = None
31
- dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
32
- # patch functor: Flux, etc.
33
- patch_functor: Optional[PatchFunctor] = None
34
- # flags to control auto block adapter
35
- auto: bool = False
36
- allow_prefixes: List[str] = dataclasses.field(
37
- default_factory=lambda: [
38
- "transformer",
39
- "single_transformer",
40
- "blocks",
41
- "layers",
42
- "single_stream_blocks",
43
- "double_stream_blocks",
44
- ]
45
- )
46
- check_prefixes: bool = True
47
- allow_suffixes: List[str] = dataclasses.field(
48
- default_factory=lambda: ["TransformerBlock"]
49
- )
50
- check_suffixes: bool = False
51
- blocks_policy: str = dataclasses.field(
52
- default="max", metadata={"allowed_values": ["max", "min"]}
53
- )
54
-
55
- def __post_init__(self):
56
- assert any((self.pipe is not None, self.transformer is not None))
57
- self.patchify()
58
-
59
- def patchify(self, *args, **kwargs):
60
- # Process some specificial cases, specific for transformers
61
- # that has different forward patterns between single_transformer_blocks
62
- # and transformer_blocks , such as Flux (diffusers < 0.35.0).
63
- if self.patch_functor is not None:
64
- if self.transformer is not None:
65
- self.patch_functor.apply(self.transformer, *args, **kwargs)
66
- else:
67
- assert hasattr(self.pipe, "transformer")
68
- self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
69
-
70
- @staticmethod
71
- def auto_block_adapter(
72
- adapter: "BlockAdapter",
73
- forward_pattern: Optional[ForwardPattern] = None,
74
- ) -> "BlockAdapter":
75
- assert adapter.auto, (
76
- "Please manually set `auto` to True, or, manually "
77
- "set all the transformer blocks configuration."
78
- )
79
- assert adapter.pipe is not None, "adapter.pipe can not be None."
80
- pipe = adapter.pipe
81
-
82
- assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
83
-
84
- transformer = pipe.transformer
85
-
86
- # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
87
- blocks, blocks_name = BlockAdapter.find_blocks(
88
- transformer=transformer,
89
- allow_prefixes=adapter.allow_prefixes,
90
- allow_suffixes=adapter.allow_suffixes,
91
- check_prefixes=adapter.check_prefixes,
92
- check_suffixes=adapter.check_suffixes,
93
- blocks_policy=adapter.blocks_policy,
94
- forward_pattern=forward_pattern,
95
- )
96
-
97
- return BlockAdapter(
98
- pipe=pipe,
99
- transformer=transformer,
100
- blocks=blocks,
101
- blocks_name=blocks_name,
102
- )
103
-
104
- @staticmethod
105
- def check_block_adapter(adapter: "BlockAdapter") -> bool:
106
- if (
107
- # NOTE: pipe may not need to be DiffusionPipeline?
108
- # isinstance(adapter.pipe, DiffusionPipeline)
109
- adapter.pipe is not None
110
- and adapter.transformer is not None
111
- and adapter.blocks is not None
112
- and adapter.blocks_name is not None
113
- and isinstance(adapter.blocks, torch.nn.ModuleList)
114
- ):
115
- return True
116
-
117
- logger.warning("Check block adapter failed!")
118
- return False
119
-
120
- @staticmethod
121
- def find_blocks(
122
- transformer: torch.nn.Module,
123
- allow_prefixes: List[str] = [
124
- "transformer",
125
- "single_transformer",
126
- "blocks",
127
- "layers",
128
- ],
129
- allow_suffixes: List[str] = [
130
- "TransformerBlock",
131
- ],
132
- check_prefixes: bool = True,
133
- check_suffixes: bool = False,
134
- **kwargs,
135
- ) -> Tuple[torch.nn.ModuleList, str]:
136
- # Check prefixes
137
- if check_prefixes:
138
- blocks_names = []
139
- for attr_name in dir(transformer):
140
- for prefix in allow_prefixes:
141
- if attr_name.startswith(prefix):
142
- blocks_names.append(attr_name)
143
- else:
144
- blocks_names = dir(transformer)
145
-
146
- # Check ModuleList
147
- valid_names = []
148
- valid_count = []
149
- forward_pattern = kwargs.get("forward_pattern", None)
150
- for blocks_name in blocks_names:
151
- if blocks := getattr(transformer, blocks_name, None):
152
- if isinstance(blocks, torch.nn.ModuleList):
153
- block = blocks[0]
154
- block_cls_name = block.__class__.__name__
155
- # Check suffixes
156
- if isinstance(block, torch.nn.Module) and (
157
- any(
158
- (
159
- block_cls_name.endswith(allow_suffix)
160
- for allow_suffix in allow_suffixes
161
- )
162
- )
163
- or (not check_suffixes)
164
- ):
165
- # May check forward pattern
166
- if forward_pattern is not None:
167
- if BlockAdapter.match_blocks_pattern(
168
- blocks,
169
- forward_pattern,
170
- logging=False,
171
- ):
172
- valid_names.append(blocks_name)
173
- valid_count.append(len(blocks))
174
- else:
175
- valid_names.append(blocks_name)
176
- valid_count.append(len(blocks))
177
-
178
- if not valid_names:
179
- raise ValueError(
180
- "Auto selected transformer blocks failed, please set it manually."
181
- )
182
-
183
- final_name = valid_names[0]
184
- final_count = valid_count[0]
185
- block_policy = kwargs.get("blocks_policy", "max")
186
-
187
- for blocks_name, count in zip(valid_names, valid_count):
188
- blocks = getattr(transformer, blocks_name)
189
- logger.info(
190
- f"Auto selected transformer blocks: {blocks_name}, "
191
- f"class: {blocks[0].__class__.__name__}, "
192
- f"num blocks: {count}"
193
- )
194
- if block_policy == "max":
195
- if final_count < count:
196
- final_count = count
197
- final_name = blocks_name
198
- else:
199
- if final_count > count:
200
- final_count = count
201
- final_name = blocks_name
202
-
203
- final_blocks = getattr(transformer, final_name)
204
-
205
- logger.info(
206
- f"Final selected transformer blocks: {final_name}, "
207
- f"class: {final_blocks[0].__class__.__name__}, "
208
- f"num blocks: {final_count}, block_policy: {block_policy}."
209
- )
210
-
211
- return final_blocks, final_name
212
-
213
- @staticmethod
214
- def match_block_pattern(
215
- block: torch.nn.Module,
216
- forward_pattern: ForwardPattern,
217
- ) -> bool:
218
- assert (
219
- forward_pattern.Supported
220
- and forward_pattern in ForwardPattern.supported_patterns()
221
- ), f"Pattern {forward_pattern} is not support now!"
222
-
223
- forward_parameters = set(
224
- inspect.signature(block.forward).parameters.keys()
225
- )
226
- num_outputs = str(
227
- inspect.signature(block.forward).return_annotation
228
- ).count("torch.Tensor")
229
-
230
- in_matched = True
231
- out_matched = True
232
- if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
233
- # output pattern not match
234
- out_matched = False
235
-
236
- for required_param in forward_pattern.In:
237
- if required_param not in forward_parameters:
238
- in_matched = False
239
-
240
- return in_matched and out_matched
241
-
242
- @staticmethod
243
- def match_blocks_pattern(
244
- transformer_blocks: torch.nn.ModuleList,
245
- forward_pattern: ForwardPattern,
246
- logging: bool = True,
247
- ) -> bool:
248
- assert (
249
- forward_pattern.Supported
250
- and forward_pattern in ForwardPattern.supported_patterns()
251
- ), f"Pattern {forward_pattern} is not support now!"
252
-
253
- assert isinstance(transformer_blocks, torch.nn.ModuleList)
254
-
255
- pattern_matched_states = []
256
- for block in transformer_blocks:
257
- pattern_matched_states.append(
258
- BlockAdapter.match_block_pattern(
259
- block,
260
- forward_pattern,
261
- )
262
- )
263
-
264
- pattern_matched = all(pattern_matched_states) # all block match
265
- if pattern_matched and logging:
266
- block_cls_name = transformer_blocks[0].__class__.__name__
267
- logger.info(
268
- f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
269
- f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
270
- )
271
-
272
- return pattern_matched
273
-
274
-
275
- @dataclasses.dataclass
276
- class UnifiedCacheParams:
277
- block_adapter: BlockAdapter = None
278
- forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
279
-
280
-
281
- class UnifiedCacheAdapter:
282
- _supported_pipelines = [
283
- "Flux",
284
- "Mochi",
285
- "CogVideoX",
286
- "Wan",
287
- "HunyuanVideo",
288
- "QwenImage",
289
- "LTXVideo",
290
- "Allegro",
291
- "CogView3Plus",
292
- "CogView4",
293
- "Cosmos",
294
- "EasyAnimate",
295
- "SkyReelsV2",
296
- "SD3",
297
- "ConsisID",
298
- "DiT",
299
- "Amused",
300
- "Bria",
301
- "HunyuanDiT",
302
- "HunyuanDiTPAG",
303
- "Lumina",
304
- "Lumina2",
305
- "OmniGen",
306
- "PixArt",
307
- "Sana",
308
- "ShapE",
309
- "StableAudio",
310
- "VisualCloze",
311
- "AuraFlow",
312
- "Chroma",
313
- "HiDream",
314
- ]
21
+ # Unified Cached Adapter
22
+ class CachedAdapter:
315
23
 
316
24
  def __call__(self, *args, **kwargs):
317
25
  return self.apply(*args, **kwargs)
318
26
 
319
- @classmethod
320
- def supported_pipelines(cls) -> Tuple[int, List[str]]:
321
- return len(cls._supported_pipelines), [
322
- p + "*" for p in cls._supported_pipelines
323
- ]
324
-
325
- @classmethod
326
- def is_supported(cls, pipe: DiffusionPipeline) -> bool:
327
- pipe_cls_name: str = pipe.__class__.__name__
328
- for prefix in cls._supported_pipelines:
329
- if pipe_cls_name.startswith(prefix):
330
- return True
331
- return False
332
-
333
- @classmethod
334
- def get_params(cls, pipe: DiffusionPipeline) -> UnifiedCacheParams:
335
- pipe_cls_name: str = pipe.__class__.__name__
336
-
337
- if pipe_cls_name.startswith("Flux"):
338
- from diffusers import FluxTransformer2DModel
339
- from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
340
-
341
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
342
- return UnifiedCacheParams(
343
- block_adapter=BlockAdapter(
344
- pipe=pipe,
345
- transformer=pipe.transformer,
346
- blocks=(
347
- pipe.transformer.transformer_blocks
348
- + pipe.transformer.single_transformer_blocks
349
- ),
350
- blocks_name="transformer_blocks",
351
- dummy_blocks_names=["single_transformer_blocks"],
352
- patch_functor=FluxPatchFunctor(),
353
- ),
354
- forward_pattern=ForwardPattern.Pattern_1,
355
- )
356
-
357
- elif pipe_cls_name.startswith("Mochi"):
358
- from diffusers import MochiTransformer3DModel
359
-
360
- assert isinstance(pipe.transformer, MochiTransformer3DModel)
361
- return UnifiedCacheParams(
362
- block_adapter=BlockAdapter(
363
- pipe=pipe,
364
- transformer=pipe.transformer,
365
- blocks=pipe.transformer.transformer_blocks,
366
- blocks_name="transformer_blocks",
367
- dummy_blocks_names=[],
368
- ),
369
- forward_pattern=ForwardPattern.Pattern_0,
370
- )
371
-
372
- elif pipe_cls_name.startswith("CogVideoX"):
373
- from diffusers import CogVideoXTransformer3DModel
374
-
375
- assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
376
- return UnifiedCacheParams(
377
- block_adapter=BlockAdapter(
378
- pipe=pipe,
379
- transformer=pipe.transformer,
380
- blocks=pipe.transformer.transformer_blocks,
381
- blocks_name="transformer_blocks",
382
- dummy_blocks_names=[],
383
- ),
384
- forward_pattern=ForwardPattern.Pattern_0,
385
- )
386
-
387
- elif pipe_cls_name.startswith("Wan"):
388
- from diffusers import (
389
- WanTransformer3DModel,
390
- WanVACETransformer3DModel,
391
- )
392
-
393
- assert isinstance(
394
- pipe.transformer,
395
- (WanTransformer3DModel, WanVACETransformer3DModel),
396
- )
397
- if getattr(pipe, "transformer_2", None):
398
- # Wan 2.2, cache for low-noise transformer
399
- assert isinstance(
400
- pipe.transformer_2,
401
- (WanTransformer3DModel, WanVACETransformer3DModel),
402
- )
403
- return UnifiedCacheParams(
404
- block_adapter=BlockAdapter(
405
- pipe=pipe,
406
- transformer=pipe.transformer_2,
407
- blocks=pipe.transformer_2.blocks,
408
- blocks_name="blocks",
409
- dummy_blocks_names=[],
410
- ),
411
- forward_pattern=ForwardPattern.Pattern_2,
412
- )
413
- else:
414
- # Wan 2.1
415
- return UnifiedCacheParams(
416
- block_adapter=BlockAdapter(
417
- pipe=pipe,
418
- transformer=pipe.transformer,
419
- blocks=pipe.transformer.blocks,
420
- blocks_name="blocks",
421
- dummy_blocks_names=[],
422
- ),
423
- forward_pattern=ForwardPattern.Pattern_2,
424
- )
425
-
426
- elif pipe_cls_name.startswith("HunyuanVideo"):
427
- from diffusers import HunyuanVideoTransformer3DModel
428
-
429
- assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
430
- return UnifiedCacheParams(
431
- block_adapter=BlockAdapter(
432
- pipe=pipe,
433
- blocks=(
434
- pipe.transformer.transformer_blocks
435
- + pipe.transformer.single_transformer_blocks
436
- ),
437
- blocks_name="transformer_blocks",
438
- dummy_blocks_names=["single_transformer_blocks"],
439
- ),
440
- forward_pattern=ForwardPattern.Pattern_0,
441
- )
442
-
443
- elif pipe_cls_name.startswith("QwenImage"):
444
- from diffusers import QwenImageTransformer2DModel
445
-
446
- assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
447
- return UnifiedCacheParams(
448
- block_adapter=BlockAdapter(
449
- pipe=pipe,
450
- transformer=pipe.transformer,
451
- blocks=pipe.transformer.transformer_blocks,
452
- blocks_name="transformer_blocks",
453
- dummy_blocks_names=[],
454
- ),
455
- forward_pattern=ForwardPattern.Pattern_1,
456
- )
457
-
458
- elif pipe_cls_name.startswith("LTXVideo"):
459
- from diffusers import LTXVideoTransformer3DModel
460
-
461
- assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
462
- return UnifiedCacheParams(
463
- block_adapter=BlockAdapter(
464
- pipe=pipe,
465
- transformer=pipe.transformer,
466
- blocks=pipe.transformer.transformer_blocks,
467
- blocks_name="transformer_blocks",
468
- dummy_blocks_names=[],
469
- ),
470
- forward_pattern=ForwardPattern.Pattern_2,
471
- )
472
-
473
- elif pipe_cls_name.startswith("Allegro"):
474
- from diffusers import AllegroTransformer3DModel
475
-
476
- assert isinstance(pipe.transformer, AllegroTransformer3DModel)
477
- return UnifiedCacheParams(
478
- block_adapter=BlockAdapter(
479
- pipe=pipe,
480
- transformer=pipe.transformer,
481
- blocks=pipe.transformer.transformer_blocks,
482
- blocks_name="transformer_blocks",
483
- dummy_blocks_names=[],
484
- ),
485
- forward_pattern=ForwardPattern.Pattern_2,
486
- )
487
-
488
- elif pipe_cls_name.startswith("CogView3Plus"):
489
- from diffusers import CogView3PlusTransformer2DModel
490
-
491
- assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
492
- return UnifiedCacheParams(
493
- block_adapter=BlockAdapter(
494
- pipe=pipe,
495
- transformer=pipe.transformer,
496
- blocks=pipe.transformer.transformer_blocks,
497
- blocks_name="transformer_blocks",
498
- dummy_blocks_names=[],
499
- ),
500
- forward_pattern=ForwardPattern.Pattern_0,
501
- )
502
-
503
- elif pipe_cls_name.startswith("CogView4"):
504
- from diffusers import CogView4Transformer2DModel
505
-
506
- assert isinstance(pipe.transformer, CogView4Transformer2DModel)
507
- return UnifiedCacheParams(
508
- block_adapter=BlockAdapter(
509
- pipe=pipe,
510
- transformer=pipe.transformer,
511
- blocks=pipe.transformer.transformer_blocks,
512
- blocks_name="transformer_blocks",
513
- dummy_blocks_names=[],
514
- ),
515
- forward_pattern=ForwardPattern.Pattern_0,
516
- )
517
-
518
- elif pipe_cls_name.startswith("Cosmos"):
519
- from diffusers import CosmosTransformer3DModel
520
-
521
- assert isinstance(pipe.transformer, CosmosTransformer3DModel)
522
- return UnifiedCacheParams(
523
- block_adapter=BlockAdapter(
524
- pipe=pipe,
525
- transformer=pipe.transformer,
526
- blocks=pipe.transformer.transformer_blocks,
527
- blocks_name="transformer_blocks",
528
- dummy_blocks_names=[],
529
- ),
530
- forward_pattern=ForwardPattern.Pattern_2,
531
- )
532
-
533
- elif pipe_cls_name.startswith("EasyAnimate"):
534
- from diffusers import EasyAnimateTransformer3DModel
535
-
536
- assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
537
- return UnifiedCacheParams(
538
- block_adapter=BlockAdapter(
539
- pipe=pipe,
540
- transformer=pipe.transformer,
541
- blocks=pipe.transformer.transformer_blocks,
542
- blocks_name="transformer_blocks",
543
- dummy_blocks_names=[],
544
- ),
545
- forward_pattern=ForwardPattern.Pattern_0,
546
- )
547
-
548
- elif pipe_cls_name.startswith("SkyReelsV2"):
549
- from diffusers import SkyReelsV2Transformer3DModel
550
-
551
- assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
552
- return UnifiedCacheParams(
553
- block_adapter=BlockAdapter(
554
- pipe=pipe,
555
- transformer=pipe.transformer,
556
- blocks=pipe.transformer.blocks,
557
- blocks_name="blocks",
558
- dummy_blocks_names=[],
559
- ),
560
- forward_pattern=ForwardPattern.Pattern_2,
561
- )
562
- elif pipe_cls_name.startswith("SD3"):
563
- from diffusers import SD3Transformer2DModel
564
-
565
- assert isinstance(pipe.transformer, SD3Transformer2DModel)
566
- return UnifiedCacheParams(
567
- block_adapter=BlockAdapter(
568
- pipe=pipe,
569
- transformer=pipe.transformer,
570
- blocks=pipe.transformer.transformer_blocks,
571
- blocks_name="transformer_blocks",
572
- dummy_blocks_names=[],
573
- ),
574
- forward_pattern=ForwardPattern.Pattern_1,
575
- )
576
-
577
- elif pipe_cls_name.startswith("ConsisID"):
578
- from diffusers import ConsisIDTransformer3DModel
579
-
580
- assert isinstance(pipe.transformer, ConsisIDTransformer3DModel)
581
- return UnifiedCacheParams(
582
- block_adapter=BlockAdapter(
583
- pipe=pipe,
584
- transformer=pipe.transformer,
585
- blocks=pipe.transformer.transformer_blocks,
586
- blocks_name="transformer_blocks",
587
- dummy_blocks_names=[],
588
- ),
589
- forward_pattern=ForwardPattern.Pattern_0,
590
- )
591
-
592
- elif pipe_cls_name.startswith("DiT"):
593
- from diffusers import DiTTransformer2DModel
594
-
595
- assert isinstance(pipe.transformer, DiTTransformer2DModel)
596
- return UnifiedCacheParams(
597
- block_adapter=BlockAdapter(
598
- pipe=pipe,
599
- transformer=pipe.transformer,
600
- blocks=pipe.transformer.transformer_blocks,
601
- blocks_name="transformer_blocks",
602
- dummy_blocks_names=[],
603
- ),
604
- forward_pattern=ForwardPattern.Pattern_3,
605
- )
606
-
607
- elif pipe_cls_name.startswith("Amused"):
608
- from diffusers import UVit2DModel
609
-
610
- assert isinstance(pipe.transformer, UVit2DModel)
611
- return UnifiedCacheParams(
612
- block_adapter=BlockAdapter(
613
- pipe=pipe,
614
- transformer=pipe.transformer,
615
- blocks=pipe.transformer.transformer_layers,
616
- blocks_name="transformer_layers",
617
- dummy_blocks_names=[],
618
- ),
619
- forward_pattern=ForwardPattern.Pattern_3,
620
- )
621
-
622
- elif pipe_cls_name.startswith("Bria"):
623
- from diffusers import BriaTransformer2DModel
624
-
625
- assert isinstance(pipe.transformer, BriaTransformer2DModel)
626
- return UnifiedCacheParams(
627
- block_adapter=BlockAdapter(
628
- pipe=pipe,
629
- transformer=pipe.transformer,
630
- blocks=(
631
- pipe.transformer.transformer_blocks
632
- + pipe.transformer.single_transformer_blocks
633
- ),
634
- blocks_name="transformer_blocks",
635
- dummy_blocks_names=["single_transformer_blocks"],
636
- ),
637
- forward_pattern=ForwardPattern.Pattern_0,
638
- )
639
-
640
- elif pipe_cls_name.startswith("HunyuanDiT"):
641
- from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
642
-
643
- assert isinstance(
644
- pipe.transformer,
645
- (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
646
- )
647
- return UnifiedCacheParams(
648
- block_adapter=BlockAdapter(
649
- pipe=pipe,
650
- transformer=pipe.transformer,
651
- blocks=pipe.transformer.blocks,
652
- blocks_name="blocks",
653
- dummy_blocks_names=[],
654
- ),
655
- forward_pattern=ForwardPattern.Pattern_3,
656
- )
657
-
658
- elif pipe_cls_name.startswith("HunyuanDiTPAG"):
659
- from diffusers import HunyuanDiT2DModel
660
-
661
- assert isinstance(pipe.transformer, HunyuanDiT2DModel)
662
- return UnifiedCacheParams(
663
- block_adapter=BlockAdapter(
664
- pipe=pipe,
665
- transformer=pipe.transformer,
666
- blocks=pipe.transformer.blocks,
667
- blocks_name="blocks",
668
- dummy_blocks_names=[],
669
- ),
670
- forward_pattern=ForwardPattern.Pattern_3,
671
- )
672
-
673
- elif pipe_cls_name.startswith("Lumina"):
674
- from diffusers import LuminaNextDiT2DModel
675
-
676
- assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
677
- return UnifiedCacheParams(
678
- block_adapter=BlockAdapter(
679
- pipe=pipe,
680
- transformer=pipe.transformer,
681
- blocks=pipe.transformer.layers,
682
- blocks_name="layers",
683
- dummy_blocks_names=[],
684
- ),
685
- forward_pattern=ForwardPattern.Pattern_3,
686
- )
687
-
688
- elif pipe_cls_name.startswith("Lumina2"):
689
- from diffusers import Lumina2Transformer2DModel
690
-
691
- assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
692
- return UnifiedCacheParams(
693
- block_adapter=BlockAdapter(
694
- pipe=pipe,
695
- transformer=pipe.transformer,
696
- blocks=pipe.transformer.layers,
697
- blocks_name="layers",
698
- dummy_blocks_names=[],
699
- ),
700
- forward_pattern=ForwardPattern.Pattern_3,
701
- )
702
-
703
- elif pipe_cls_name.startswith("OmniGen"):
704
- from diffusers import OmniGenTransformer2DModel
705
-
706
- assert isinstance(pipe.transformer, OmniGenTransformer2DModel)
707
- return UnifiedCacheParams(
708
- block_adapter=BlockAdapter(
709
- pipe=pipe,
710
- transformer=pipe.transformer,
711
- blocks=pipe.transformer.layers,
712
- blocks_name="layers",
713
- dummy_blocks_names=[],
714
- ),
715
- forward_pattern=ForwardPattern.Pattern_3,
716
- )
717
-
718
- elif pipe_cls_name.startswith("PixArt"):
719
- from diffusers import PixArtTransformer2DModel
720
-
721
- assert isinstance(pipe.transformer, PixArtTransformer2DModel)
722
- return UnifiedCacheParams(
723
- block_adapter=BlockAdapter(
724
- pipe=pipe,
725
- transformer=pipe.transformer,
726
- blocks=pipe.transformer.transformer_blocks,
727
- blocks_name="transformer_blocks",
728
- dummy_blocks_names=[],
729
- ),
730
- forward_pattern=ForwardPattern.Pattern_3,
731
- )
732
-
733
- elif pipe_cls_name.startswith("Sana"):
734
- from diffusers import SanaTransformer2DModel
735
-
736
- assert isinstance(pipe.transformer, SanaTransformer2DModel)
737
- return UnifiedCacheParams(
738
- block_adapter=BlockAdapter(
739
- pipe=pipe,
740
- transformer=pipe.transformer,
741
- blocks=pipe.transformer.transformer_blocks,
742
- blocks_name="transformer_blocks",
743
- dummy_blocks_names=[],
744
- ),
745
- forward_pattern=ForwardPattern.Pattern_3,
746
- )
747
-
748
- elif pipe_cls_name.startswith("ShapE"):
749
- from diffusers import PriorTransformer
750
-
751
- assert isinstance(pipe.prior, PriorTransformer)
752
- return UnifiedCacheParams(
753
- block_adapter=BlockAdapter(
754
- pipe=pipe,
755
- transformer=pipe.prior,
756
- blocks=pipe.prior.transformer_blocks,
757
- blocks_name="transformer_blocks",
758
- dummy_blocks_names=[],
759
- ),
760
- forward_pattern=ForwardPattern.Pattern_3,
761
- )
762
-
763
- elif pipe_cls_name.startswith("StableAudio"):
764
- from diffusers import StableAudioDiTModel
765
-
766
- assert isinstance(pipe.transformer, StableAudioDiTModel)
767
- return UnifiedCacheParams(
768
- block_adapter=BlockAdapter(
769
- pipe=pipe,
770
- transformer=pipe.transformer,
771
- blocks=pipe.transformer.transformer_blocks,
772
- blocks_name="transformer_blocks",
773
- dummy_blocks_names=[],
774
- ),
775
- forward_pattern=ForwardPattern.Pattern_3,
776
- )
777
-
778
- elif pipe_cls_name.startswith("VisualCloze"):
779
- from diffusers import FluxTransformer2DModel
780
- from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
781
-
782
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
783
- return UnifiedCacheParams(
784
- block_adapter=BlockAdapter(
785
- pipe=pipe,
786
- transformer=pipe.transformer,
787
- blocks=(
788
- pipe.transformer.transformer_blocks
789
- + pipe.transformer.single_transformer_blocks
790
- ),
791
- blocks_name="transformer_blocks",
792
- dummy_blocks_names=["single_transformer_blocks"],
793
- patch_functor=FluxPatchFunctor(),
794
- ),
795
- forward_pattern=ForwardPattern.Pattern_1,
796
- )
797
-
798
- elif pipe_cls_name.startswith("AuraFlow"):
799
- from diffusers import AuraFlowTransformer2DModel
800
-
801
- assert isinstance(pipe.transformer, AuraFlowTransformer2DModel)
802
- return UnifiedCacheParams(
803
- block_adapter=BlockAdapter(
804
- pipe=pipe,
805
- transformer=pipe.transformer,
806
- # Only support caching single_transformer_blocks for AuraFlow now.
807
- # TODO: Support AuraFlowPatchFunctor.
808
- blocks=pipe.transformer.single_transformer_blocks,
809
- blocks_name="single_transformer_blocks",
810
- dummy_blocks_names=[],
811
- ),
812
- forward_pattern=ForwardPattern.Pattern_3,
813
- )
814
-
815
- elif pipe_cls_name.startswith("Chroma"):
816
- from diffusers import ChromaTransformer2DModel
817
- from cache_dit.cache_factory.patch_functors import (
818
- ChromaPatchFunctor,
819
- )
820
-
821
- assert isinstance(pipe.transformer, ChromaTransformer2DModel)
822
- return UnifiedCacheParams(
823
- block_adapter=BlockAdapter(
824
- pipe=pipe,
825
- transformer=pipe.transformer,
826
- blocks=(
827
- pipe.transformer.transformer_blocks
828
- + pipe.transformer.single_transformer_blocks
829
- ),
830
- blocks_name="transformer_blocks",
831
- dummy_blocks_names=["single_transformer_blocks"],
832
- patch_functor=ChromaPatchFunctor(),
833
- ),
834
- forward_pattern=ForwardPattern.Pattern_1,
835
- )
836
-
837
- elif pipe_cls_name.startswith("HiDream"):
838
- from diffusers import HiDreamImageTransformer2DModel
839
-
840
- assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
841
- return UnifiedCacheParams(
842
- block_adapter=BlockAdapter(
843
- pipe=pipe,
844
- transformer=pipe.transformer,
845
- # Only support caching single_stream_blocks for HiDream now.
846
- # TODO: Support HiDreamPatchFunctor.
847
- blocks=pipe.transformer.single_stream_blocks,
848
- blocks_name="single_stream_blocks",
849
- dummy_blocks_names=[],
850
- ),
851
- forward_pattern=ForwardPattern.Pattern_3,
852
- )
853
-
854
- else:
855
- raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
856
-
857
27
  @classmethod
858
28
  def apply(
859
29
  cls,
860
30
  pipe: DiffusionPipeline = None,
861
31
  block_adapter: BlockAdapter = None,
862
- forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
32
+ # forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
863
33
  **cache_context_kwargs,
864
34
  ) -> DiffusionPipeline:
865
35
  assert (
@@ -867,15 +37,14 @@ class UnifiedCacheAdapter:
867
37
  ), "pipe or block_adapter can not both None!"
868
38
 
869
39
  if pipe is not None:
870
- if cls.is_supported(pipe):
40
+ if BlockAdapterRegistry.is_supported(pipe):
871
41
  logger.info(
872
42
  f"{pipe.__class__.__name__} is officially supported by cache-dit. "
873
43
  "Use it's pre-defined BlockAdapter directly!"
874
44
  )
875
- params = cls.get_params(pipe)
45
+ block_adapter = BlockAdapterRegistry.get_adapter(pipe)
876
46
  return cls.cachify(
877
- params.block_adapter,
878
- forward_pattern=params.forward_pattern,
47
+ block_adapter,
879
48
  **cache_context_kwargs,
880
49
  )
881
50
  else:
@@ -889,7 +58,6 @@ class UnifiedCacheAdapter:
889
58
  )
890
59
  return cls.cachify(
891
60
  block_adapter,
892
- forward_pattern=forward_pattern,
893
61
  **cache_context_kwargs,
894
62
  )
895
63
 
@@ -897,31 +65,27 @@ class UnifiedCacheAdapter:
897
65
  def cachify(
898
66
  cls,
899
67
  block_adapter: BlockAdapter,
900
- *,
901
- forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
902
68
  **cache_context_kwargs,
903
69
  ) -> DiffusionPipeline:
904
70
 
905
71
  if block_adapter.auto:
906
72
  block_adapter = BlockAdapter.auto_block_adapter(
907
73
  block_adapter,
908
- forward_pattern,
909
74
  )
910
75
 
911
76
  if BlockAdapter.check_block_adapter(block_adapter):
912
- # Apply cache on pipeline: wrap cache context
77
+ block_adapter = BlockAdapter.normalize(block_adapter)
78
+ # 0. Apply cache on pipeline: wrap cache context
913
79
  cls.create_context(
914
- block_adapter.pipe,
80
+ block_adapter,
915
81
  **cache_context_kwargs,
916
82
  )
917
- # Apply cache on transformer: mock cached transformer blocks
83
+ # 1. Apply cache on transformer: mock cached transformer blocks
918
84
  cls.mock_blocks(
919
85
  block_adapter,
920
- forward_pattern=forward_pattern,
921
86
  )
922
87
  cls.patch_params(
923
88
  block_adapter,
924
- forward_pattern=forward_pattern,
925
89
  **cache_context_kwargs,
926
90
  )
927
91
  return block_adapter.pipe
@@ -930,41 +94,36 @@ class UnifiedCacheAdapter:
930
94
  def patch_params(
931
95
  cls,
932
96
  block_adapter: BlockAdapter,
933
- forward_pattern: ForwardPattern = None,
934
97
  **cache_context_kwargs,
935
98
  ):
936
- block_adapter.transformer._forward_pattern = forward_pattern
99
+ block_adapter.transformer._forward_pattern = (
100
+ block_adapter.forward_pattern
101
+ )
102
+ block_adapter.transformer._has_separate_cfg = (
103
+ block_adapter.has_separate_cfg
104
+ )
937
105
  block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
938
106
  block_adapter.pipe.__class__._cache_context_kwargs = (
939
107
  cache_context_kwargs
940
108
  )
941
-
942
- @classmethod
943
- def has_separate_cfg(
944
- cls,
945
- pipe_or_transformer: DiffusionPipeline | Any,
946
- ) -> bool:
947
- cls_name = pipe_or_transformer.__class__.__name__
948
- if cls_name.startswith("QwenImage"):
949
- return True
950
- elif cls_name.startswith("Wan"):
951
- return True
952
- elif cls_name.startswith("CogView4"):
953
- return True
954
- elif cls_name.startswith("Cosmos"):
955
- return True
956
- elif cls_name.startswith("SkyReelsV2"):
957
- return True
958
- elif cls_name.startswith("Chroma"):
959
- return True
960
- return False
109
+ for blocks, forward_pattern in zip(
110
+ block_adapter.blocks, block_adapter.forward_pattern
111
+ ):
112
+ blocks._forward_pattern = forward_pattern
113
+ blocks._cache_context_kwargs = cache_context_kwargs
961
114
 
962
115
  @classmethod
963
116
  def check_context_kwargs(cls, pipe, **cache_context_kwargs):
964
117
  # Check cache_context_kwargs
965
118
  if not cache_context_kwargs["do_separate_cfg"]:
966
119
  # Check cfg for some specific case if users don't set it as True
967
- cache_context_kwargs["do_separate_cfg"] = cls.has_separate_cfg(pipe)
120
+ cache_context_kwargs["do_separate_cfg"] = (
121
+ BlockAdapterRegistry.has_separate_cfg(pipe)
122
+ )
123
+ logger.info(
124
+ f"Use default 'do_separate_cfg': {cache_context_kwargs['do_separate_cfg']}, "
125
+ f"Pipeline: {pipe.__class__.__name__}."
126
+ )
968
127
 
969
128
  if cache_type := cache_context_kwargs.pop("cache_type", None):
970
129
  assert (
@@ -976,65 +135,87 @@ class UnifiedCacheAdapter:
976
135
  @classmethod
977
136
  def create_context(
978
137
  cls,
979
- pipe: DiffusionPipeline,
138
+ block_adapter: BlockAdapter,
980
139
  **cache_context_kwargs,
981
140
  ) -> DiffusionPipeline:
982
- if getattr(pipe, "_is_cached", False):
983
- return pipe
141
+ if getattr(block_adapter.pipe, "_is_cached", False):
142
+ return block_adapter.pipe
984
143
 
985
144
  # Check cache_context_kwargs
986
145
  cache_context_kwargs = cls.check_context_kwargs(
987
- pipe,
146
+ block_adapter.pipe,
988
147
  **cache_context_kwargs,
989
148
  )
990
149
  # Apply cache on pipeline: wrap cache context
991
- cache_kwargs, _ = cache_context.collect_cache_kwargs(
150
+ cache_kwargs, _ = CachedContext.collect_cache_kwargs(
992
151
  default_attrs={},
993
152
  **cache_context_kwargs,
994
153
  )
995
- original_call = pipe.__class__.__call__
154
+ original_call = block_adapter.pipe.__class__.__call__
996
155
 
997
156
  @functools.wraps(original_call)
998
157
  def new_call(self, *args, **kwargs):
999
- with cache_context.cache_context(
1000
- cache_context.create_cache_context(
1001
- **cache_kwargs,
1002
- )
1003
- ):
1004
- return original_call(self, *args, **kwargs)
158
+ with ExitStack() as stack:
159
+ # cache context will reset for each pipe inference
160
+ for blocks_name in block_adapter.blocks_name:
161
+ stack.enter_context(
162
+ CachedContext.cache_context(
163
+ CachedContext.reset_cache_context(
164
+ blocks_name,
165
+ **cache_kwargs,
166
+ ),
167
+ )
168
+ )
169
+ outputs = original_call(self, *args, **kwargs)
170
+ cls.patch_stats(block_adapter)
171
+ return outputs
172
+
173
+ block_adapter.pipe.__class__.__call__ = new_call
174
+ block_adapter.pipe.__class__._is_cached = True
175
+ return block_adapter.pipe
1005
176
 
1006
- pipe.__class__.__call__ = new_call
1007
- pipe.__class__._is_cached = True
1008
- return pipe
177
+ @classmethod
178
+ def patch_stats(cls, block_adapter: BlockAdapter):
179
+ from cache_dit.cache_factory.cache_blocks.utils import (
180
+ patch_cached_stats,
181
+ )
182
+
183
+ patch_cached_stats(block_adapter.transformer)
184
+ for blocks, blocks_name in zip(
185
+ block_adapter.blocks, block_adapter.blocks_name
186
+ ):
187
+ patch_cached_stats(blocks, blocks_name)
1009
188
 
1010
189
  @classmethod
1011
190
  def mock_blocks(
1012
191
  cls,
1013
192
  block_adapter: BlockAdapter,
1014
- forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
1015
193
  ) -> torch.nn.Module:
1016
194
 
1017
195
  if getattr(block_adapter.transformer, "_is_cached", False):
1018
196
  return block_adapter.transformer
1019
197
 
1020
198
  # Check block forward pattern matching
1021
- assert BlockAdapter.match_blocks_pattern(
1022
- block_adapter.blocks,
1023
- forward_pattern=forward_pattern,
1024
- ), (
1025
- "No block forward pattern matched, "
1026
- f"supported lists: {ForwardPattern.supported_patterns()}"
1027
- )
199
+ block_adapter = BlockAdapter.normalize(block_adapter)
200
+ for forward_pattern, blocks in zip(
201
+ block_adapter.forward_pattern, block_adapter.blocks
202
+ ):
203
+ assert BlockAdapter.match_blocks_pattern(
204
+ blocks,
205
+ forward_pattern=forward_pattern,
206
+ check_num_outputs=block_adapter.check_num_outputs,
207
+ ), (
208
+ "No block forward pattern matched, "
209
+ f"supported lists: {ForwardPattern.supported_patterns()}"
210
+ )
1028
211
 
1029
212
  # Apply cache on transformer: mock cached transformer blocks
1030
- cached_blocks = torch.nn.ModuleList(
1031
- [
1032
- DBCachedBlocks(
1033
- block_adapter.blocks,
1034
- transformer=block_adapter.transformer,
1035
- forward_pattern=forward_pattern,
1036
- )
1037
- ]
213
+ # TODO: Use blocks_name to spearate cached context for different
214
+ # blocks list. For example, single_transformer_blocks and
215
+ # transformer_blocks should have different cached context and
216
+ # forward pattern.
217
+ cached_blocks = cls.collect_cached_blocks(
218
+ block_adapter=block_adapter,
1038
219
  )
1039
220
  dummy_blocks = torch.nn.ModuleList()
1040
221
 
@@ -1045,13 +226,14 @@ class UnifiedCacheAdapter:
1045
226
  @functools.wraps(original_forward)
1046
227
  def new_forward(self, *args, **kwargs):
1047
228
  with ExitStack() as stack:
1048
- stack.enter_context(
1049
- unittest.mock.patch.object(
1050
- self,
1051
- block_adapter.blocks_name,
1052
- cached_blocks,
229
+ for blocks_name in block_adapter.blocks_name:
230
+ stack.enter_context(
231
+ unittest.mock.patch.object(
232
+ self,
233
+ blocks_name,
234
+ cached_blocks[blocks_name],
235
+ )
1053
236
  )
1054
- )
1055
237
  for dummy_name in block_adapter.dummy_blocks_names:
1056
238
  stack.enter_context(
1057
239
  unittest.mock.patch.object(
@@ -1068,3 +250,30 @@ class UnifiedCacheAdapter:
1068
250
  block_adapter.transformer._is_cached = True
1069
251
 
1070
252
  return block_adapter.transformer
253
+
254
+ @classmethod
255
+ def collect_cached_blocks(
256
+ cls,
257
+ block_adapter: BlockAdapter,
258
+ ) -> Dict[str, torch.nn.ModuleList]:
259
+ block_adapter = BlockAdapter.normalize(block_adapter)
260
+
261
+ cached_blocks_bind_context = {}
262
+
263
+ for i in range(len(block_adapter.blocks)):
264
+ cached_blocks_bind_context[block_adapter.blocks_name[i]] = (
265
+ torch.nn.ModuleList(
266
+ [
267
+ CachedBlocks(
268
+ block_adapter.blocks[i],
269
+ block_adapter.blocks_name[i],
270
+ block_adapter.blocks_name[i], # context name
271
+ transformer=block_adapter.transformer,
272
+ forward_pattern=block_adapter.forward_pattern[i],
273
+ check_num_outputs=block_adapter.check_num_outputs,
274
+ )
275
+ ]
276
+ )
277
+ )
278
+
279
+ return cached_blocks_bind_context