cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. cache_dit/__init__.py +8 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +17 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +555 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +262 -938
  8. cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
  12. cache_dit/cache_factory/cache_blocks/utils.py +16 -10
  13. cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
  14. cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
  16. cache_dit/cache_factory/cache_interface.py +31 -31
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +26 -26
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
  22. cache_dit-0.2.28.dist-info/RECORD +47 -0
  23. cache_dit/cache_factory/cache_context.py +0 -1155
  24. cache_dit-0.2.26.dist-info/RECORD +0 -42
  25. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
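
The diff below covers cache_dit/cache_factory/cache_adapters.py (+262 -938): the monolithic UnifiedCacheAdapter / UnifiedCacheParams pair from 0.2.26 is replaced by a slimmer CachedAdapter that looks up per-pipeline configuration in BlockAdapterRegistry and drives caching through BlockAdapter, CachedContextManager and CachedBlocks (now split out into the new block_adapters/ and cache_contexts/ packages). A minimal usage sketch follows; it assumes the public cache_dit.enable_cache() helper from cache_interface.py (not shown in this diff) and an example Flux checkpoint, so treat the exact names as illustrative rather than taken from the diff itself.

import torch
from diffusers import DiffusionPipeline

import cache_dit

# Load any pipeline whose class is registered with BlockAdapterRegistry
# (Flux pipelines were pre-defined in 0.2.26 and remain supported).
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)

# enable_cache() appears to route to CachedAdapter.apply(pipe=...), which
# resolves a pre-defined BlockAdapter via BlockAdapterRegistry.get_adapter(),
# wraps pipe.__call__ in a per-pipeline CachedContextManager, and mocks the
# transformer blocks with CachedBlocks (see the diff below).
pipe = cache_dit.enable_cache(pipe)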
@@ -1,865 +1,36 @@
 import torch
 
-import inspect
 import unittest
 import functools
-import dataclasses
 
-from typing import Any, Tuple, List, Optional
 from contextlib import ExitStack
+from typing import Dict, List, Tuple, Any
+
 from diffusers import DiffusionPipeline
+
 from cache_dit.cache_factory import CacheType
-from cache_dit.cache_factory import cache_context
-from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.patch_functors import PatchFunctor
-from cache_dit.cache_factory.cache_blocks import (
-    DBCachedBlocks,
-)
+from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import ParamsModifier
+from cache_dit.cache_factory import BlockAdapterRegistry
+from cache_dit.cache_factory import CachedContextManager
+from cache_dit.cache_factory import CachedBlocks
 
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-@dataclasses.dataclass
-class BlockAdapter:
-    pipe: DiffusionPipeline | Any = None
-    transformer: torch.nn.Module = None
-    blocks: torch.nn.ModuleList = None
-    # transformer_blocks, blocks, etc.
-    blocks_name: str = None
-    dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
-    # patch functor: Flux, etc.
-    patch_functor: Optional[PatchFunctor] = None
-    # flags to control auto block adapter
-    auto: bool = False
-    allow_prefixes: List[str] = dataclasses.field(
-        default_factory=lambda: [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-            "single_stream_blocks",
-            "double_stream_blocks",
-        ]
-    )
-    check_prefixes: bool = True
-    allow_suffixes: List[str] = dataclasses.field(
-        default_factory=lambda: ["TransformerBlock"]
-    )
-    check_suffixes: bool = False
-    blocks_policy: str = dataclasses.field(
-        default="max", metadata={"allowed_values": ["max", "min"]}
-    )
-
-    def __post_init__(self):
-        assert any((self.pipe is not None, self.transformer is not None))
-        self.patchify()
-
-    def patchify(self, *args, **kwargs):
-        # Process some specificial cases, specific for transformers
-        # that has different forward patterns between single_transformer_blocks
-        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None:
-            if self.transformer is not None:
-                self.patch_functor.apply(self.transformer, *args, **kwargs)
-            else:
-                assert hasattr(self.pipe, "transformer")
-                self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
-
-    @staticmethod
-    def auto_block_adapter(
-        adapter: "BlockAdapter",
-        forward_pattern: Optional[ForwardPattern] = None,
-    ) -> "BlockAdapter":
-        assert adapter.auto, (
-            "Please manually set `auto` to True, or, manually "
-            "set all the transformer blocks configuration."
-        )
-        assert adapter.pipe is not None, "adapter.pipe can not be None."
-        pipe = adapter.pipe
-
-        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
-
-        transformer = pipe.transformer
-
-        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.find_blocks(
-            transformer=transformer,
-            allow_prefixes=adapter.allow_prefixes,
-            allow_suffixes=adapter.allow_suffixes,
-            check_prefixes=adapter.check_prefixes,
-            check_suffixes=adapter.check_suffixes,
-            blocks_policy=adapter.blocks_policy,
-            forward_pattern=forward_pattern,
-        )
-
-        return BlockAdapter(
-            pipe=pipe,
-            transformer=transformer,
-            blocks=blocks,
-            blocks_name=blocks_name,
-        )
-
-    @staticmethod
-    def check_block_adapter(adapter: "BlockAdapter") -> bool:
-        if (
-            # NOTE: pipe may not need to be DiffusionPipeline?
-            # isinstance(adapter.pipe, DiffusionPipeline)
-            adapter.pipe is not None
-            and adapter.transformer is not None
-            and adapter.blocks is not None
-            and adapter.blocks_name is not None
-            and isinstance(adapter.blocks, torch.nn.ModuleList)
-        ):
-            return True
-
-        logger.warning("Check block adapter failed!")
-        return False
-
-    @staticmethod
-    def find_blocks(
-        transformer: torch.nn.Module,
-        allow_prefixes: List[str] = [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-        ],
-        allow_suffixes: List[str] = [
-            "TransformerBlock",
-        ],
-        check_prefixes: bool = True,
-        check_suffixes: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.nn.ModuleList, str]:
-        # Check prefixes
-        if check_prefixes:
-            blocks_names = []
-            for attr_name in dir(transformer):
-                for prefix in allow_prefixes:
-                    if attr_name.startswith(prefix):
-                        blocks_names.append(attr_name)
-        else:
-            blocks_names = dir(transformer)
-
-        # Check ModuleList
-        valid_names = []
-        valid_count = []
-        forward_pattern = kwargs.get("forward_pattern", None)
-        for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
-                if isinstance(blocks, torch.nn.ModuleList):
-                    block = blocks[0]
-                    block_cls_name = block.__class__.__name__
-                    # Check suffixes
-                    if isinstance(block, torch.nn.Module) and (
-                        any(
-                            (
-                                block_cls_name.endswith(allow_suffix)
-                                for allow_suffix in allow_suffixes
-                            )
-                        )
-                        or (not check_suffixes)
-                    ):
-                        # May check forward pattern
-                        if forward_pattern is not None:
-                            if BlockAdapter.match_blocks_pattern(
-                                blocks,
-                                forward_pattern,
-                                logging=False,
-                            ):
-                                valid_names.append(blocks_name)
-                                valid_count.append(len(blocks))
-                        else:
-                            valid_names.append(blocks_name)
-                            valid_count.append(len(blocks))
-
-        if not valid_names:
-            raise ValueError(
-                "Auto selected transformer blocks failed, please set it manually."
-            )
-
-        final_name = valid_names[0]
-        final_count = valid_count[0]
-        block_policy = kwargs.get("blocks_policy", "max")
-
-        for blocks_name, count in zip(valid_names, valid_count):
-            blocks = getattr(transformer, blocks_name)
-            logger.info(
-                f"Auto selected transformer blocks: {blocks_name}, "
-                f"class: {blocks[0].__class__.__name__}, "
-                f"num blocks: {count}"
-            )
-            if block_policy == "max":
-                if final_count < count:
-                    final_count = count
-                    final_name = blocks_name
-            else:
-                if final_count > count:
-                    final_count = count
-                    final_name = blocks_name
-
-        final_blocks = getattr(transformer, final_name)
-
-        logger.info(
-            f"Final selected transformer blocks: {final_name}, "
-            f"class: {final_blocks[0].__class__.__name__}, "
-            f"num blocks: {final_count}, block_policy: {block_policy}."
-        )
-
-        return final_blocks, final_name
-
-    @staticmethod
-    def match_block_pattern(
-        block: torch.nn.Module,
-        forward_pattern: ForwardPattern,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        forward_parameters = set(
-            inspect.signature(block.forward).parameters.keys()
-        )
-        num_outputs = str(
-            inspect.signature(block.forward).return_annotation
-        ).count("torch.Tensor")
-
-        in_matched = True
-        out_matched = True
-        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-            # output pattern not match
-            out_matched = False
-
-        for required_param in forward_pattern.In:
-            if required_param not in forward_parameters:
-                in_matched = False
-
-        return in_matched and out_matched
-
-    @staticmethod
-    def match_blocks_pattern(
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern,
-        logging: bool = True,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        assert isinstance(transformer_blocks, torch.nn.ModuleList)
-
-        pattern_matched_states = []
-        for block in transformer_blocks:
-            pattern_matched_states.append(
-                BlockAdapter.match_block_pattern(
-                    block,
-                    forward_pattern,
-                )
-            )
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched and logging:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
-
-
-@dataclasses.dataclass
-class UnifiedCacheParams:
-    block_adapter: BlockAdapter = None
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
-
-
-class UnifiedCacheAdapter:
-    _supported_pipelines = [
-        "Flux",
-        "Mochi",
-        "CogVideoX",
-        "Wan",
-        "HunyuanVideo",
-        "QwenImage",
-        "LTXVideo",
-        "Allegro",
-        "CogView3Plus",
-        "CogView4",
-        "Cosmos",
-        "EasyAnimate",
-        "SkyReelsV2",
-        "SD3",
-        "ConsisID",
-        "DiT",
-        "Amused",
-        "Bria",
-        "HunyuanDiT",
-        "HunyuanDiTPAG",
-        "Lumina",
-        "Lumina2",
-        "OmniGen",
-        "PixArt",
-        "Sana",
-        "ShapE",
-        "StableAudio",
-        "VisualCloze",
-        "AuraFlow",
-        "Chroma",
-        "HiDream",
-    ]
+# Unified Cached Adapter
+class CachedAdapter:
 
     def __call__(self, *args, **kwargs):
         return self.apply(*args, **kwargs)
 
-    @classmethod
-    def supported_pipelines(cls) -> Tuple[int, List[str]]:
-        return len(cls._supported_pipelines), [
-            p + "*" for p in cls._supported_pipelines
-        ]
-
-    @classmethod
-    def is_supported(cls, pipe: DiffusionPipeline) -> bool:
-        pipe_cls_name: str = pipe.__class__.__name__
-        for prefix in cls._supported_pipelines:
-            if pipe_cls_name.startswith(prefix):
-                return True
-        return False
-
-    @classmethod
-    def get_params(cls, pipe: DiffusionPipeline) -> UnifiedCacheParams:
-        pipe_cls_name: str = pipe.__class__.__name__
-
-        if pipe_cls_name.startswith("Flux"):
-            from diffusers import FluxTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
-
-            assert isinstance(pipe.transformer, FluxTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=FluxPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("Mochi"):
-            from diffusers import MochiTransformer3DModel
-
-            assert isinstance(pipe.transformer, MochiTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("CogVideoX"):
-            from diffusers import CogVideoXTransformer3DModel
-
-            assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("Wan"):
-            from diffusers import (
-                WanTransformer3DModel,
-                WanVACETransformer3DModel,
-            )
-
-            assert isinstance(
-                pipe.transformer,
-                (WanTransformer3DModel, WanVACETransformer3DModel),
-            )
-            if getattr(pipe, "transformer_2", None):
-                # Wan 2.2, cache for low-noise transformer
-                assert isinstance(
-                    pipe.transformer_2,
-                    (WanTransformer3DModel, WanVACETransformer3DModel),
-                )
-                return UnifiedCacheParams(
-                    block_adapter=BlockAdapter(
-                        pipe=pipe,
-                        transformer=pipe.transformer_2,
-                        blocks=pipe.transformer_2.blocks,
-                        blocks_name="blocks",
-                        dummy_blocks_names=[],
-                    ),
-                    forward_pattern=ForwardPattern.Pattern_2,
-                )
-            else:
-                # Wan 2.1
-                return UnifiedCacheParams(
-                    block_adapter=BlockAdapter(
-                        pipe=pipe,
-                        transformer=pipe.transformer,
-                        blocks=pipe.transformer.blocks,
-                        blocks_name="blocks",
-                        dummy_blocks_names=[],
-                    ),
-                    forward_pattern=ForwardPattern.Pattern_2,
-                )
-
-        elif pipe_cls_name.startswith("HunyuanVideo"):
-            from diffusers import HunyuanVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("QwenImage"):
-            from diffusers import QwenImageTransformer2DModel
-
-            assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("LTXVideo"):
-            from diffusers import LTXVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("Allegro"):
-            from diffusers import AllegroTransformer3DModel
-
-            assert isinstance(pipe.transformer, AllegroTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("CogView3Plus"):
-            from diffusers import CogView3PlusTransformer2DModel
-
-            assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("CogView4"):
-            from diffusers import CogView4Transformer2DModel
-
-            assert isinstance(pipe.transformer, CogView4Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("Cosmos"):
-            from diffusers import CosmosTransformer3DModel
-
-            assert isinstance(pipe.transformer, CosmosTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("EasyAnimate"):
-            from diffusers import EasyAnimateTransformer3DModel
-
-            assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("SkyReelsV2"):
-            from diffusers import SkyReelsV2Transformer3DModel
-
-            assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("SD3"):
-            from diffusers import SD3Transformer2DModel
-
-            assert isinstance(pipe.transformer, SD3Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("ConsisID"):
-            from diffusers import ConsisIDTransformer3DModel
-
-            assert isinstance(pipe.transformer, ConsisIDTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("DiT"):
-            from diffusers import DiTTransformer2DModel
-
-            assert isinstance(pipe.transformer, DiTTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Amused"):
-            from diffusers import UVit2DModel
-
-            assert isinstance(pipe.transformer, UVit2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_layers,
-                    blocks_name="transformer_layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Bria"):
-            from diffusers import BriaTransformer2DModel
-
-            assert isinstance(pipe.transformer, BriaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("HunyuanDiT"):
-            from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
-
-            assert isinstance(
-                pipe.transformer,
-                (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
-            )
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("HunyuanDiTPAG"):
-            from diffusers import HunyuanDiT2DModel
-
-            assert isinstance(pipe.transformer, HunyuanDiT2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Lumina"):
-            from diffusers import LuminaNextDiT2DModel
-
-            assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Lumina2"):
-            from diffusers import Lumina2Transformer2DModel
-
-            assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("OmniGen"):
-            from diffusers import OmniGenTransformer2DModel
-
-            assert isinstance(pipe.transformer, OmniGenTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("PixArt"):
-            from diffusers import PixArtTransformer2DModel
-
-            assert isinstance(pipe.transformer, PixArtTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Sana"):
-            from diffusers import SanaTransformer2DModel
-
-            assert isinstance(pipe.transformer, SanaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("ShapE"):
-            from diffusers import PriorTransformer
-
-            assert isinstance(pipe.prior, PriorTransformer)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.prior,
-                    blocks=pipe.prior.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("StableAudio"):
-            from diffusers import StableAudioDiTModel
-
-            assert isinstance(pipe.transformer, StableAudioDiTModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("VisualCloze"):
-            from diffusers import FluxTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
-
-            assert isinstance(pipe.transformer, FluxTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=FluxPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("AuraFlow"):
-            from diffusers import AuraFlowTransformer2DModel
-
-            assert isinstance(pipe.transformer, AuraFlowTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    # Only support caching single_transformer_blocks for AuraFlow now.
-                    # TODO: Support AuraFlowPatchFunctor.
-                    blocks=pipe.transformer.single_transformer_blocks,
-                    blocks_name="single_transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Chroma"):
-            from diffusers import ChromaTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import (
-                ChromaPatchFunctor,
-            )
-
-            assert isinstance(pipe.transformer, ChromaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=ChromaPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("HiDream"):
-            from diffusers import HiDreamImageTransformer2DModel
-
-            assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    # Only support caching single_stream_blocks for HiDream now.
-                    # TODO: Support HiDreamPatchFunctor.
-                    blocks=pipe.transformer.single_stream_blocks,
-                    blocks_name="single_stream_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        else:
-            raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
-
     @classmethod
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
         block_adapter: BlockAdapter = None,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
@@ -867,15 +38,14 @@ class UnifiedCacheAdapter:
         ), "pipe or block_adapter can not both None!"
 
         if pipe is not None:
-            if cls.is_supported(pipe):
+            if BlockAdapterRegistry.is_supported(pipe):
                 logger.info(
                     f"{pipe.__class__.__name__} is officially supported by cache-dit. "
                     "Use it's pre-defined BlockAdapter directly!"
                 )
-                params = cls.get_params(pipe)
+                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
-                    params.block_adapter,
-                    forward_pattern=params.forward_pattern,
+                    block_adapter,
                     **cache_context_kwargs,
                 )
             else:
@@ -889,7 +59,6 @@ class UnifiedCacheAdapter:
             )
             return cls.cachify(
                 block_adapter,
-                forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
 
@@ -897,74 +66,78 @@ class UnifiedCacheAdapter:
     def cachify(
         cls,
         block_adapter: BlockAdapter,
-        *,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
        **cache_context_kwargs,
    ) -> DiffusionPipeline:
 
         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
                 block_adapter,
-                forward_pattern,
             )
 
         if BlockAdapter.check_block_adapter(block_adapter):
-            # Apply cache on pipeline: wrap cache context
+
+            # 0. Must normalize block_adapter before apply cache
+            block_adapter = BlockAdapter.normalize(block_adapter)
+            if BlockAdapter.is_cached(block_adapter):
+                return block_adapter.pipe
+
+            # 1. Apply cache on pipeline: wrap cache context, must
+            # call create_context before mock_blocks.
             cls.create_context(
-                block_adapter.pipe,
+                block_adapter,
                 **cache_context_kwargs,
             )
-            # Apply cache on transformer: mock cached transformer blocks
+
+            # 2. Apply cache on transformer: mock cached blocks
             cls.mock_blocks(
                 block_adapter,
-                forward_pattern=forward_pattern,
-            )
-            cls.patch_params(
-                block_adapter,
-                forward_pattern=forward_pattern,
-                **cache_context_kwargs,
             )
+
         return block_adapter.pipe
 
     @classmethod
     def patch_params(
         cls,
         block_adapter: BlockAdapter,
-        forward_pattern: ForwardPattern = None,
-        **cache_context_kwargs,
+        contexts_kwargs: List[Dict],
    ):
-        block_adapter.transformer._forward_pattern = forward_pattern
-        block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
-        block_adapter.pipe.__class__._cache_context_kwargs = (
-            cache_context_kwargs
-        )
+        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
 
-    @classmethod
-    def has_separate_cfg(
-        cls,
-        pipe_or_transformer: DiffusionPipeline | Any,
-    ) -> bool:
-        cls_name = pipe_or_transformer.__class__.__name__
-        if cls_name.startswith("QwenImage"):
-            return True
-        elif cls_name.startswith("Wan"):
-            return True
-        elif cls_name.startswith("CogView4"):
-            return True
-        elif cls_name.startswith("Cosmos"):
-            return True
-        elif cls_name.startswith("SkyReelsV2"):
-            return True
-        elif cls_name.startswith("Chroma"):
-            return True
-        return False
+        params_shift = 0
+        for i in range(len(block_adapter.transformer)):
+
+            block_adapter.transformer[i]._forward_pattern = (
+                block_adapter.forward_pattern
+            )
+            block_adapter.transformer[i]._has_separate_cfg = (
+                block_adapter.has_separate_cfg
+            )
+            block_adapter.transformer[i]._cache_context_kwargs = (
+                contexts_kwargs[params_shift]
+            )
+
+            blocks = block_adapter.blocks[i]
+            for j in range(len(blocks)):
+                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                blocks[j]._cache_context_kwargs = contexts_kwargs[
+                    params_shift + j
+                ]
+
+            params_shift += len(blocks)
 
     @classmethod
     def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
-        if not cache_context_kwargs["do_separate_cfg"]:
+        if not cache_context_kwargs["enable_spearate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-            cache_context_kwargs["do_separate_cfg"] = cls.has_separate_cfg(pipe)
+            cache_context_kwargs["enable_spearate_cfg"] = (
+                BlockAdapterRegistry.has_separate_cfg(pipe)
+            )
+            logger.info(
+                f"Use default 'enable_spearate_cfg': "
+                f"{cache_context_kwargs['enable_spearate_cfg']}, "
+                f"Pipeline: {pipe.__class__.__name__}."
+            )
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
@@ -976,95 +149,246 @@ class UnifiedCacheAdapter:
     @classmethod
     def create_context(
         cls,
-        pipe: DiffusionPipeline,
+        block_adapter: BlockAdapter,
        **cache_context_kwargs,
    ) -> DiffusionPipeline:
-        if getattr(pipe, "_is_cached", False):
-            return pipe
+
+        BlockAdapter.assert_normalized(block_adapter)
+
+        if BlockAdapter.is_cached(block_adapter.pipe):
+            return block_adapter.pipe
 
         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            pipe,
+            block_adapter.pipe,
            **cache_context_kwargs,
        )
         # Apply cache on pipeline: wrap cache context
-        cache_kwargs, _ = cache_context.collect_cache_kwargs(
-            default_attrs={},
-            **cache_context_kwargs,
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
+
+        # Each Pipeline should have it's own context manager instance.
+        # Different transformers (Wan2.2, etc) should shared the same
+        # cache manager but with different cache context (according
+        # to their unique instance id).
+        cache_manager = CachedContextManager(
+            name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
        )
-        original_call = pipe.__class__.__call__
+        block_adapter.pipe._cache_manager = cache_manager  # instance level
+
+        flatten_contexts, contexts_kwargs = cls.modify_context_params(
+            block_adapter, cache_manager, **cache_context_kwargs
+        )
+
+        original_call = block_adapter.pipe.__class__.__call__
 
         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
-            with cache_context.cache_context(
-                cache_context.create_cache_context(
-                    **cache_kwargs,
-                )
-            ):
-                return original_call(self, *args, **kwargs)
+            with ExitStack() as stack:
+                # cache context will be reset for each pipe inference
+                for context_name, context_kwargs in zip(
+                    flatten_contexts, contexts_kwargs
+                ):
+                    stack.enter_context(
+                        cache_manager.enter_context(
+                            cache_manager.reset_context(
+                                context_name,
+                                **context_kwargs,
+                            ),
+                        )
+                    )
+                outputs = original_call(self, *args, **kwargs)
+                cls.patch_stats(block_adapter)
+                return outputs
+
+        block_adapter.pipe.__class__.__call__ = new_call
+        block_adapter.pipe.__class__._original_call = original_call
+        block_adapter.pipe.__class__._is_cached = True
 
-        pipe.__class__.__call__ = new_call
-        pipe.__class__._is_cached = True
-        return pipe
+        cls.patch_params(block_adapter, contexts_kwargs)
+
+        return block_adapter.pipe
 
     @classmethod
-    def mock_blocks(
+    def modify_context_params(
        cls,
        block_adapter: BlockAdapter,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-    ) -> torch.nn.Module:
+        cache_manager: CachedContextManager,
+        **cache_context_kwargs,
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:
 
-        if getattr(block_adapter.transformer, "_is_cached", False):
-            return block_adapter.transformer
+        flatten_contexts = BlockAdapter.flatten(
+            block_adapter.unique_blocks_name
+        )
+        contexts_kwargs = [
+            cache_context_kwargs.copy()
+            for _ in range(
+                len(flatten_contexts),
+            )
+        ]
 
-        # Check block forward pattern matching
-        assert BlockAdapter.match_blocks_pattern(
-            block_adapter.blocks,
-            forward_pattern=forward_pattern,
-        ), (
-            "No block forward pattern matched, "
-            f"supported lists: {ForwardPattern.supported_patterns()}"
+        for i in range(len(contexts_kwargs)):
+            contexts_kwargs[i]["name"] = flatten_contexts[i]
+
+        if block_adapter.params_modifiers is None:
+            return flatten_contexts, contexts_kwargs
+
+        flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
+            block_adapter.params_modifiers,
        )
 
-        # Apply cache on transformer: mock cached transformer blocks
-        cached_blocks = torch.nn.ModuleList(
-            [
-                DBCachedBlocks(
-                    block_adapter.blocks,
-                    transformer=block_adapter.transformer,
-                    forward_pattern=forward_pattern,
-                )
-            ]
+        for i in range(
+            min(len(contexts_kwargs), len(flatten_modifiers)),
+        ):
+            contexts_kwargs[i].update(
+                flatten_modifiers[i]._context_kwargs,
+            )
+            contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
+                default_attrs={}, **contexts_kwargs[i]
+            )
+
+        return flatten_contexts, contexts_kwargs
+
+    @classmethod
+    def patch_stats(
+        cls,
+        block_adapter: BlockAdapter,
+    ):
+        from cache_dit.cache_factory.cache_blocks.utils import (
+            patch_cached_stats,
        )
+
+        cache_manager = block_adapter.pipe._cache_manager
+
+        for i in range(len(block_adapter.transformer)):
+            patch_cached_stats(
+                block_adapter.transformer[i],
+                cache_context=block_adapter.unique_blocks_name[i][-1],
+                cache_manager=cache_manager,
+            )
+            for blocks, unique_name in zip(
+                block_adapter.blocks[i],
+                block_adapter.unique_blocks_name[i],
+            ):
+                patch_cached_stats(
+                    blocks,
+                    cache_context=unique_name,
+                    cache_manager=cache_manager,
+                )
+
+    @classmethod
+    def mock_blocks(
+        cls,
+        block_adapter: BlockAdapter,
+    ) -> List[torch.nn.Module]:
+
+        BlockAdapter.assert_normalized(block_adapter)
+
+        if BlockAdapter.is_cached(block_adapter.transformer):
+            return block_adapter.transformer
+
+        # Apply cache on transformer: mock cached transformer blocks
+        for (
+            cached_blocks,
+            transformer,
+            blocks_name,
+            unique_blocks_name,
+            dummy_blocks_names,
+        ) in zip(
+            cls.collect_cached_blocks(block_adapter),
+            block_adapter.transformer,
+            block_adapter.blocks_name,
+            block_adapter.unique_blocks_name,
+            block_adapter.dummy_blocks_names,
+        ):
+            cls.mock_transformer(
+                cached_blocks,
+                transformer,
+                blocks_name,
+                unique_blocks_name,
+                dummy_blocks_names,
+            )
+
+        return block_adapter.transformer
+
+    @classmethod
+    def mock_transformer(
+        cls,
+        cached_blocks: Dict[str, torch.nn.ModuleList],
+        transformer: torch.nn.Module,
+        blocks_name: List[str],
+        unique_blocks_name: List[str],
+        dummy_blocks_names: List[str],
+    ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()
 
-        original_forward = block_adapter.transformer.forward
+        original_forward = transformer.forward
 
-        assert isinstance(block_adapter.dummy_blocks_names, list)
+        assert isinstance(dummy_blocks_names, list)
 
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
-                stack.enter_context(
-                    unittest.mock.patch.object(
-                        self,
-                        block_adapter.blocks_name,
-                        cached_blocks,
+                for name, context_name in zip(
+                    blocks_name,
+                    unique_blocks_name,
+                ):
+                    stack.enter_context(
+                        unittest.mock.patch.object(
+                            self, name, cached_blocks[context_name]
+                        )
                    )
-                )
-                for dummy_name in block_adapter.dummy_blocks_names:
+                for dummy_name in dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self,
-                            dummy_name,
-                            dummy_blocks,
+                            self, dummy_name, dummy_blocks
                        )
                    )
                 return original_forward(*args, **kwargs)
 
-        block_adapter.transformer.forward = new_forward.__get__(
-            block_adapter.transformer
+        transformer.forward = new_forward.__get__(transformer)
+        transformer._original_forward = original_forward
+        transformer._is_cached = True
+
+        return transformer
+
+    @classmethod
+    def collect_cached_blocks(
+        cls,
+        block_adapter: BlockAdapter,
+    ) -> List[Dict[str, torch.nn.ModuleList]]:
+
+        BlockAdapter.assert_normalized(block_adapter)
+
+        total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
+        assert hasattr(block_adapter.pipe, "_cache_manager")
+        assert isinstance(
+            block_adapter.pipe._cache_manager, CachedContextManager
        )
-        block_adapter.transformer._is_cached = True
 
-        return block_adapter.transformer
+        for i in range(len(block_adapter.transformer)):
+
+            cached_blocks_bind_context = {}
+            for j in range(len(block_adapter.blocks[i])):
+                cached_blocks_bind_context[
+                    block_adapter.unique_blocks_name[i][j]
+                ] = torch.nn.ModuleList(
+                    [
+                        CachedBlocks(
+                            # 0. Transformer blocks configuration
+                            block_adapter.blocks[i][j],
+                            transformer=block_adapter.transformer[i],
+                            forward_pattern=block_adapter.forward_pattern[i][j],
+                            check_num_outputs=block_adapter.check_num_outputs,
+                            # 1. Cache context configuration
+                            cache_prefix=block_adapter.blocks_name[i][j],
+                            cache_context=block_adapter.unique_blocks_name[i][
+                                j
+                            ],
+                            cache_manager=block_adapter.pipe._cache_manager,
+                        )
+                    ]
+                )
+
+            total_cached_blocks.append(cached_blocks_bind_context)
+
+        return total_cached_blocks
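
The mock_transformer() path above swaps a transformer's block list only for the duration of a single forward call, restoring the real modules afterwards. A self-contained illustration of that technique with a toy module (not part of the diff; the class names below are invented stand-ins for the real CachedBlocks):

import unittest.mock
from contextlib import ExitStack

import torch


class TinyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyTransformer()
cached_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])  # stand-in for CachedBlocks
original_forward = model.forward


def new_forward(self, *args, **kwargs):
    # Temporarily replace `blocks` with the cached version, run the original
    # forward, then let ExitStack restore the real blocks automatically.
    with ExitStack() as stack:
        stack.enter_context(
            unittest.mock.patch.object(self, "blocks", cached_blocks)
        )
        return original_forward(*args, **kwargs)


model.forward = new_forward.__get__(model)
print(model(torch.randn(1, 4)).shape)  # torch.Size([1, 4])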