cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry flags that this version of cache-dit might be problematic.

Files changed (32):
  1. cache_dit/__init__.py +9 -4
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +16 -3
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +121 -563
  8. cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
  11. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
  12. cache_dit/cache_factory/cache_blocks/utils.py +23 -0
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
  15. cache_dit/cache_factory/cache_interface.py +24 -16
  16. cache_dit/cache_factory/forward_pattern.py +45 -24
  17. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  18. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  19. cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
  20. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
  21. cache_dit/quantize/quantize_ao.py +19 -4
  22. cache_dit/quantize/quantize_interface.py +2 -2
  23. cache_dit/utils.py +19 -15
  24. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
  25. cache_dit-0.2.27.dist-info/RECORD +47 -0
  26. cache_dit-0.2.25.dist-info/RECORD +0 -36
  27. /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
  28. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  29. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  30. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
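
The bulk of this release is a refactor of cache_dit/cache_factory/cache_adapters.py: the BlockAdapter dataclass and the long per-pipeline get_params() chain move into the new block_adapters/ package behind a BlockAdapterRegistry, the Flux patching helper becomes a patch functor under patch_functors/, and the cached block implementations are split by forward pattern under cache_blocks/. The following sketch condenses the dispatch flow that the diff below introduces; the BlockAdapterRegistry.is_supported/get_adapter and CachedAdapter.cachify names are taken from the added lines, while the wrapper function itself is a hypothetical helper written only for illustration, not part of the package API:

# Hypothetical helper, for illustration only. The registry and cachify names
# come from the diff below; everything else here is an assumption.
from diffusers import DiffusionPipeline
from cache_dit.cache_factory import BlockAdapterRegistry
from cache_dit.cache_factory.cache_adapters import CachedAdapter

def enable_cache_for(pipe: DiffusionPipeline, **cache_context_kwargs) -> DiffusionPipeline:
    # 0.2.27 looks pipelines up in a registry instead of an if/elif chain;
    # the forward pattern now travels with the pre-defined BlockAdapter.
    if BlockAdapterRegistry.is_supported(pipe):
        block_adapter = BlockAdapterRegistry.get_adapter(pipe)
        return CachedAdapter.cachify(block_adapter, **cache_context_kwargs)
    raise ValueError(f"{pipe.__class__.__name__} requires a manually built BlockAdapter")
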
@@ -1,524 +1,35 @@
 import torch
 
-import inspect
 import unittest
 import functools
-import dataclasses
 
-from typing import Any, Tuple, List, Optional
+from typing import Dict
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.patch.flux import (
-    maybe_patch_flux_transformer,
-)
 from cache_dit.cache_factory import CacheType
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_blocks import (
-    cache_context,
-    DBCachedTransformerBlocks,
-)
+from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import BlockAdapterRegistry
+from cache_dit.cache_factory import CachedBlocks
+
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-@dataclasses.dataclass
-class BlockAdapter:
-    pipe: DiffusionPipeline = None
-    transformer: torch.nn.Module = None
-    blocks: torch.nn.ModuleList = None
-    # transformer_blocks, blocks, etc.
-    blocks_name: str = None
-    dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
-    # flags to control auto block adapter
-    auto: bool = False
-    allow_prefixes: List[str] = dataclasses.field(
-        default_factory=lambda: [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-        ]
-    )
-    check_prefixes: bool = True
-    allow_suffixes: List[str] = dataclasses.field(
-        default_factory=lambda: ["TransformerBlock"]
-    )
-    check_suffixes: bool = False
-    blocks_policy: str = dataclasses.field(
-        default="max", metadata={"allowed_values": ["max", "min"]}
-    )
-
-    def __post_init__(self):
-        self.maybe_apply_patch()
-
-    def maybe_apply_patch(self):
-        # Process some specificial cases, specific for transformers
-        # that has different forward patterns between single_transformer_blocks
-        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.transformer.__class__.__name__.startswith("Flux"):
-            self.transformer = maybe_patch_flux_transformer(
-                self.transformer,
-                blocks=self.blocks,
-            )
-
-    @staticmethod
-    def auto_block_adapter(
-        adapter: "BlockAdapter",
-        forward_pattern: Optional[ForwardPattern] = None,
-    ) -> "BlockAdapter":
-        assert adapter.auto, (
-            "Please manually set `auto` to True, or, manually "
-            "set all the transformer blocks configuration."
-        )
-        assert adapter.pipe is not None, "adapter.pipe can not be None."
-        pipe = adapter.pipe
-
-        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
-
-        transformer = pipe.transformer
-
-        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.find_blocks(
-            transformer=transformer,
-            allow_prefixes=adapter.allow_prefixes,
-            allow_suffixes=adapter.allow_suffixes,
-            check_prefixes=adapter.check_prefixes,
-            check_suffixes=adapter.check_suffixes,
-            blocks_policy=adapter.blocks_policy,
-            forward_pattern=forward_pattern,
-        )
-
-        return BlockAdapter(
-            pipe=pipe,
-            transformer=transformer,
-            blocks=blocks,
-            blocks_name=blocks_name,
-        )
-
-    @staticmethod
-    def check_block_adapter(adapter: "BlockAdapter") -> bool:
-        if (
-            isinstance(adapter.pipe, DiffusionPipeline)
-            and adapter.transformer is not None
-            and adapter.blocks is not None
-            and adapter.blocks_name is not None
-            and isinstance(adapter.blocks, torch.nn.ModuleList)
-        ):
-            return True
-
-        logger.warning("Check block adapter failed!")
-        return False
-
-    @staticmethod
-    def find_blocks(
-        transformer: torch.nn.Module,
-        allow_prefixes: List[str] = [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-        ],
-        allow_suffixes: List[str] = [
-            "TransformerBlock",
-        ],
-        check_prefixes: bool = True,
-        check_suffixes: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.nn.ModuleList, str]:
-        # Check prefixes
-        if check_prefixes:
-            blocks_names = []
-            for attr_name in dir(transformer):
-                for prefix in allow_prefixes:
-                    if attr_name.startswith(prefix):
-                        blocks_names.append(attr_name)
-        else:
-            blocks_names = dir(transformer)
-
-        # Check ModuleList
-        valid_names = []
-        valid_count = []
-        forward_pattern = kwargs.get("forward_pattern", None)
-        for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
-                if isinstance(blocks, torch.nn.ModuleList):
-                    block = blocks[0]
-                    block_cls_name = block.__class__.__name__
-                    # Check suffixes
-                    if isinstance(block, torch.nn.Module) and (
-                        any(
-                            (
-                                block_cls_name.endswith(allow_suffix)
-                                for allow_suffix in allow_suffixes
-                            )
-                        )
-                        or (not check_suffixes)
-                    ):
-                        # May check forward pattern
-                        if forward_pattern is not None:
-                            if BlockAdapter.match_blocks_pattern(
-                                blocks,
-                                forward_pattern,
-                                logging=False,
-                            ):
-                                valid_names.append(blocks_name)
-                                valid_count.append(len(blocks))
-                        else:
-                            valid_names.append(blocks_name)
-                            valid_count.append(len(blocks))
-
-        if not valid_names:
-            raise ValueError(
-                "Auto selected transformer blocks failed, please set it manually."
-            )
-
-        final_name = valid_names[0]
-        final_count = valid_count[0]
-        block_policy = kwargs.get("blocks_policy", "max")
-
-        for blocks_name, count in zip(valid_names, valid_count):
-            blocks = getattr(transformer, blocks_name)
-            logger.info(
-                f"Auto selected transformer blocks: {blocks_name}, "
-                f"class: {blocks[0].__class__.__name__}, "
-                f"num blocks: {count}"
-            )
-            if block_policy == "max":
-                if final_count < count:
-                    final_count = count
-                    final_name = blocks_name
-            else:
-                if final_count > count:
-                    final_count = count
-                    final_name = blocks_name
-
-        final_blocks = getattr(transformer, final_name)
-
-        logger.info(
-            f"Final selected transformer blocks: {final_name}, "
-            f"class: {final_blocks[0].__class__.__name__}, "
-            f"num blocks: {final_count}, block_policy: {block_policy}."
-        )
-
-        return final_blocks, final_name
-
-    @staticmethod
-    def match_block_pattern(
-        block: torch.nn.Module,
-        forward_pattern: ForwardPattern,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        forward_parameters = set(
-            inspect.signature(block.forward).parameters.keys()
-        )
-        num_outputs = str(
-            inspect.signature(block.forward).return_annotation
-        ).count("torch.Tensor")
-
-        in_matched = True
-        out_matched = True
-        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-            # output pattern not match
-            out_matched = False
-
-        for required_param in forward_pattern.In:
-            if required_param not in forward_parameters:
-                in_matched = False
-
-        return in_matched and out_matched
-
-    @staticmethod
-    def match_blocks_pattern(
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern,
-        logging: bool = True,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        assert isinstance(transformer_blocks, torch.nn.ModuleList)
-
-        pattern_matched_states = []
-        for block in transformer_blocks:
-            pattern_matched_states.append(
-                BlockAdapter.match_block_pattern(
-                    block,
-                    forward_pattern,
-                )
-            )
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched and logging:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
-
-
-@dataclasses.dataclass
-class UnifiedCacheParams:
-    block_adapter: BlockAdapter = None
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
-
-
-class UnifiedCacheAdapter:
-    _supported_pipelines = [
-        "Flux",
-        "Mochi",
-        "CogVideoX",
-        "Wan",
-        "HunyuanVideo",
-        "QwenImage",
-        "LTXVideo",
-        "Allegro",
-        "CogView3Plus",
-        "CogView4",
-        "Cosmos",
-        "EasyAnimate",
-        "SkyReelsV2",
-        "SD3",
-    ]
+# Unified Cached Adapter
+class CachedAdapter:
 
     def __call__(self, *args, **kwargs):
         return self.apply(*args, **kwargs)
 
-    @classmethod
-    def is_supported(cls, pipe: DiffusionPipeline) -> bool:
-        pipe_cls_name: str = pipe.__class__.__name__
-        for prefix in cls._supported_pipelines:
-            if pipe_cls_name.startswith(prefix):
-                return True
-        return False
-
-    @classmethod
-    def get_params(cls, pipe: DiffusionPipeline) -> UnifiedCacheParams:
-        pipe_cls_name: str = pipe.__class__.__name__
-        if pipe_cls_name.startswith("Flux"):
-            from diffusers import FluxTransformer2DModel
-
-            assert isinstance(pipe.transformer, FluxTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-        elif pipe_cls_name.startswith("Mochi"):
-            from diffusers import MochiTransformer3DModel
-
-            assert isinstance(pipe.transformer, MochiTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("CogVideoX"):
-            from diffusers import CogVideoXTransformer3DModel
-
-            assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("Wan"):
-            from diffusers import (
-                WanTransformer3DModel,
-                WanVACETransformer3DModel,
-            )
-
-            assert isinstance(
-                pipe.transformer,
-                (WanTransformer3DModel, WanVACETransformer3DModel),
-            )
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("HunyuanVideo"):
-            from diffusers import HunyuanVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("QwenImage"):
-            from diffusers import QwenImageTransformer2DModel
-
-            assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-        elif pipe_cls_name.startswith("LTXVideo"):
-            from diffusers import LTXVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("Allegro"):
-            from diffusers import AllegroTransformer3DModel
-
-            assert isinstance(pipe.transformer, AllegroTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("CogView3Plus"):
-            from diffusers import CogView3PlusTransformer2DModel
-
-            assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("CogView4"):
-            from diffusers import CogView4Transformer2DModel
-
-            assert isinstance(pipe.transformer, CogView4Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("Cosmos"):
-            from diffusers import CosmosTransformer3DModel
-
-            assert isinstance(pipe.transformer, CosmosTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("EasyAnimate"):
-            from diffusers import EasyAnimateTransformer3DModel
-
-            assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("SkyReelsV2"):
-            from diffusers import SkyReelsV2Transformer3DModel
-
-            assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("SD3"):
-            from diffusers import SD3Transformer2DModel
-
-            assert isinstance(pipe.transformer, SD3Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-        else:
-            raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
-
     @classmethod
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
         block_adapter: BlockAdapter = None,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        # forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
@@ -526,15 +37,14 @@ class UnifiedCacheAdapter:
         ), "pipe or block_adapter can not both None!"
 
         if pipe is not None:
-            if cls.is_supported(pipe):
+            if BlockAdapterRegistry.is_supported(pipe):
                 logger.info(
                     f"{pipe.__class__.__name__} is officially supported by cache-dit. "
                     "Use it's pre-defined BlockAdapter directly!"
                 )
-                params = cls.get_params(pipe)
+                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
-                    params.block_adapter,
-                    forward_pattern=params.forward_pattern,
+                    block_adapter,
                     **cache_context_kwargs,
                 )
             else:
@@ -548,7 +58,6 @@ class UnifiedCacheAdapter:
             )
             return cls.cachify(
                 block_adapter,
-                forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
 
@@ -556,31 +65,27 @@ class UnifiedCacheAdapter:
     def cachify(
         cls,
         block_adapter: BlockAdapter,
-        *,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
 
         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
                 block_adapter,
-                forward_pattern,
             )
 
         if BlockAdapter.check_block_adapter(block_adapter):
-            # Apply cache on pipeline: wrap cache context
+            block_adapter = BlockAdapter.normalize(block_adapter)
+            # 0. Apply cache on pipeline: wrap cache context
             cls.create_context(
-                block_adapter.pipe,
+                block_adapter,
                 **cache_context_kwargs,
             )
-            # Apply cache on transformer: mock cached transformer blocks
+            # 1. Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
                 block_adapter,
-                forward_pattern=forward_pattern,
             )
             cls.patch_params(
                 block_adapter,
-                forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
             return block_adapter.pipe
@@ -589,33 +94,36 @@ class UnifiedCacheAdapter:
     def patch_params(
         cls,
         block_adapter: BlockAdapter,
-        forward_pattern: ForwardPattern = None,
         **cache_context_kwargs,
     ):
-        block_adapter.transformer._forward_pattern = forward_pattern
+        block_adapter.transformer._forward_pattern = (
+            block_adapter.forward_pattern
+        )
+        block_adapter.transformer._has_separate_cfg = (
+            block_adapter.has_separate_cfg
+        )
         block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
         block_adapter.pipe.__class__._cache_context_kwargs = (
            cache_context_kwargs
         )
-
-    @classmethod
-    def has_separate_cfg(
-        cls,
-        pipe_or_transformer: DiffusionPipeline | Any,
-    ) -> bool:
-        cls_name = pipe_or_transformer.__class__.__name__
-        if cls_name.startswith("QwenImage"):
-            return True
-        elif cls_name.startswith("Wan"):
-            return True
-        return False
+        for blocks, forward_pattern in zip(
+            block_adapter.blocks, block_adapter.forward_pattern
+        ):
+            blocks._forward_pattern = forward_pattern
+            blocks._cache_context_kwargs = cache_context_kwargs
 
     @classmethod
     def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
         if not cache_context_kwargs["do_separate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-            cache_context_kwargs["do_separate_cfg"] = cls.has_separate_cfg(pipe)
+            cache_context_kwargs["do_separate_cfg"] = (
+                BlockAdapterRegistry.has_separate_cfg(pipe)
+            )
+            logger.info(
+                f"Use default 'do_separate_cfg': {cache_context_kwargs['do_separate_cfg']}, "
+                f"Pipeline: {pipe.__class__.__name__}."
+            )
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
@@ -627,65 +135,87 @@ class UnifiedCacheAdapter:
     @classmethod
     def create_context(
         cls,
-        pipe: DiffusionPipeline,
+        block_adapter: BlockAdapter,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-        if getattr(pipe, "_is_cached", False):
-            return pipe
+        if getattr(block_adapter.pipe, "_is_cached", False):
+            return block_adapter.pipe
 
         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            pipe,
+            block_adapter.pipe,
             **cache_context_kwargs,
         )
         # Apply cache on pipeline: wrap cache context
-        cache_kwargs, _ = cache_context.collect_cache_kwargs(
+        cache_kwargs, _ = CachedContext.collect_cache_kwargs(
             default_attrs={},
             **cache_context_kwargs,
         )
-        original_call = pipe.__class__.__call__
+        original_call = block_adapter.pipe.__class__.__call__
 
         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
-            with cache_context.cache_context(
-                cache_context.create_cache_context(
-                    **cache_kwargs,
-                )
-            ):
-                return original_call(self, *args, **kwargs)
+            with ExitStack() as stack:
+                # cache context will reset for each pipe inference
+                for blocks_name in block_adapter.blocks_name:
+                    stack.enter_context(
+                        CachedContext.cache_context(
+                            CachedContext.reset_cache_context(
+                                blocks_name,
+                                **cache_kwargs,
+                            ),
+                        )
+                    )
+                outputs = original_call(self, *args, **kwargs)
+                cls.patch_stats(block_adapter)
+                return outputs
 
-        pipe.__class__.__call__ = new_call
-        pipe.__class__._is_cached = True
-        return pipe
+        block_adapter.pipe.__class__.__call__ = new_call
+        block_adapter.pipe.__class__._is_cached = True
+        return block_adapter.pipe
+
+    @classmethod
+    def patch_stats(cls, block_adapter: BlockAdapter):
+        from cache_dit.cache_factory.cache_blocks.utils import (
+            patch_cached_stats,
+        )
+
+        patch_cached_stats(block_adapter.transformer)
+        for blocks, blocks_name in zip(
+            block_adapter.blocks, block_adapter.blocks_name
+        ):
+            patch_cached_stats(blocks, blocks_name)
 
     @classmethod
     def mock_blocks(
         cls,
         block_adapter: BlockAdapter,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
 
         if getattr(block_adapter.transformer, "_is_cached", False):
             return block_adapter.transformer
 
         # Check block forward pattern matching
-        assert BlockAdapter.match_blocks_pattern(
-            block_adapter.blocks,
-            forward_pattern=forward_pattern,
-        ), (
-            "No block forward pattern matched, "
-            f"supported lists: {ForwardPattern.supported_patterns()}"
-        )
+        block_adapter = BlockAdapter.normalize(block_adapter)
+        for forward_pattern, blocks in zip(
+            block_adapter.forward_pattern, block_adapter.blocks
+        ):
+            assert BlockAdapter.match_blocks_pattern(
+                blocks,
+                forward_pattern=forward_pattern,
+                check_num_outputs=block_adapter.check_num_outputs,
+            ), (
+                "No block forward pattern matched, "
+                f"supported lists: {ForwardPattern.supported_patterns()}"
+            )
 
         # Apply cache on transformer: mock cached transformer blocks
-        cached_blocks = torch.nn.ModuleList(
-            [
-                DBCachedTransformerBlocks(
-                    block_adapter.blocks,
-                    transformer=block_adapter.transformer,
-                    forward_pattern=forward_pattern,
-                )
-            ]
+        # TODO: Use blocks_name to spearate cached context for different
+        # blocks list. For example, single_transformer_blocks and
+        # transformer_blocks should have different cached context and
+        # forward pattern.
+        cached_blocks = cls.collect_cached_blocks(
+            block_adapter=block_adapter,
         )
         dummy_blocks = torch.nn.ModuleList()
 
@@ -696,13 +226,14 @@ class UnifiedCacheAdapter:
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
-                stack.enter_context(
-                    unittest.mock.patch.object(
-                        self,
-                        block_adapter.blocks_name,
-                        cached_blocks,
+                for blocks_name in block_adapter.blocks_name:
+                    stack.enter_context(
+                        unittest.mock.patch.object(
+                            self,
+                            blocks_name,
+                            cached_blocks[blocks_name],
+                        )
                     )
-                )
                 for dummy_name in block_adapter.dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
@@ -719,3 +250,30 @@ class UnifiedCacheAdapter:
         block_adapter.transformer._is_cached = True
 
         return block_adapter.transformer
+
+    @classmethod
+    def collect_cached_blocks(
+        cls,
+        block_adapter: BlockAdapter,
+    ) -> Dict[str, torch.nn.ModuleList]:
+        block_adapter = BlockAdapter.normalize(block_adapter)
+
+        cached_blocks_bind_context = {}
+
+        for i in range(len(block_adapter.blocks)):
+            cached_blocks_bind_context[block_adapter.blocks_name[i]] = (
+                torch.nn.ModuleList(
+                    [
+                        CachedBlocks(
+                            block_adapter.blocks[i],
+                            block_adapter.blocks_name[i],
+                            block_adapter.blocks_name[i],  # context name
+                            transformer=block_adapter.transformer,
+                            forward_pattern=block_adapter.forward_pattern[i],
+                            check_num_outputs=block_adapter.check_num_outputs,
+                        )
+                    ]
+                )
+            )
+
+        return cached_blocks_bind_context
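
To make the new entry point concrete, here is a minimal usage sketch assembled only from what is visible above: the CachedAdapter.apply(pipe=..., block_adapter=..., **cache_context_kwargs) classmethod and the do_separate_cfg / cache_type keys handled by check_context_kwargs. The model id, device placement, and kwargs are illustrative assumptions; in practice the defaults are normally filled in by the higher-level helper in cache_dit/cache_factory/cache_interface.py, which also changed in this release but is not shown in this diff.

import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory.cache_adapters import CachedAdapter

# Illustrative sketch, not the documented API surface of 0.2.27.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Flux is registered, so BlockAdapterRegistry supplies its pre-defined
# BlockAdapter and CachedAdapter wraps pipe.__call__ with a cache context.
pipe = CachedAdapter.apply(pipe=pipe, do_separate_cfg=False)

image = pipe("a cup of coffee on a wooden table", num_inference_steps=28).images[0]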