cache-dit 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

cache_dit/__init__.py CHANGED
@@ -9,8 +9,8 @@ from cache_dit.cache_factory import enable_cache
  from cache_dit.cache_factory import cache_type
  from cache_dit.cache_factory import block_range
  from cache_dit.cache_factory import CacheType
+ from cache_dit.cache_factory import BlockAdapter
  from cache_dit.cache_factory import ForwardPattern
- from cache_dit.cache_factory import BlockAdapterParams
  from cache_dit.compile import set_compile_configs
  from cache_dit.utils import summary
  from cache_dit.utils import strify
@@ -19,8 +19,6 @@ from cache_dit.logger import init_logger
  NONE = CacheType.NONE
  DBCache = CacheType.DBCache

- BlockAdapter = BlockAdapterParams
-
  Forward_Pattern_0 = ForwardPattern.Pattern_0
  Forward_Pattern_1 = ForwardPattern.Pattern_1
  Forward_Pattern_2 = ForwardPattern.Pattern_2
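
At the package level, the practical effect of this change is an import rename. A minimal migration sketch, based only on the import lines above and not part of the diff itself:

```python
# In 0.2.22 the public name was an alias: BlockAdapter = BlockAdapterParams.
# In 0.2.23, BlockAdapter is the real dataclass re-exported from
# cache_dit.cache_factory, and BlockAdapterParams is no longer exported.
from cache_dit import BlockAdapter, ForwardPattern  # valid on both versions

# Old code that imported the removed name must switch:
#   from cache_dit import BlockAdapterParams   # 0.2.22 only
```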
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.2.22'
- __version_tuple__ = version_tuple = (0, 2, 22)
+ __version__ = version = '0.2.23'
+ __version_tuple__ = version_tuple = (0, 2, 23)

  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -2,7 +2,7 @@ from cache_dit.cache_factory.forward_pattern import ForwardPattern
  from cache_dit.cache_factory.cache_types import CacheType
  from cache_dit.cache_factory.cache_types import cache_type
  from cache_dit.cache_factory.cache_types import block_range
- from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
+ from cache_dit.cache_factory.cache_adapters import BlockAdapter
  from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
  from cache_dit.cache_factory.cache_interface import enable_cache
  from cache_dit.cache_factory.utils import load_options
cache_dit/cache_factory/cache_adapters.py CHANGED
@@ -5,7 +5,7 @@ import unittest
  import functools
  import dataclasses

- from typing import Any
+ from typing import Any, Tuple, List
  from contextlib import ExitStack
  from diffusers import DiffusionPipeline
  from cache_dit.cache_factory.patch.flux import (
@@ -23,28 +23,152 @@ logger = init_logger(__name__)


  @dataclasses.dataclass
- class BlockAdapterParams:
+ class BlockAdapter:
  pipe: DiffusionPipeline = None
  transformer: torch.nn.Module = None
  blocks: torch.nn.ModuleList = None
  # transformer_blocks, blocks, etc.
  blocks_name: str = None
  dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
+ # flags to control auto block adapter
+ auto: bool = False
+ allow_prefixes: List[str] = dataclasses.field(
+ default_factory=lambda: [
+ "transformer",
+ "single_transformer",
+ "blocks",
+ "layers",
+ ]
+ )
+ allow_suffixes: List[str] = dataclasses.field(
+ default_factory=lambda: ["TransformerBlock"]
+ )
+ check_suffixes: bool = False
+ blocks_policy: str = dataclasses.field(
+ default="max", metadata={"allowed_values": ["max", "min"]}
+ )
+
+ @staticmethod
+ def auto_block_adapter(adapter: "BlockAdapter") -> "BlockAdapter":
+ assert adapter.auto, (
+ "Please manually set `auto` to True, or, manually "
+ "set all the transformer blocks configuration."
+ )
+ assert adapter.pipe is not None, "adapter.pipe can not be None."
+ pipe = adapter.pipe
+
+ assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+ transformer = pipe.transformer
+
+ # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+ blocks, blocks_name = BlockAdapter.find_blocks(
+ transformer=transformer,
+ allow_prefixes=adapter.allow_prefixes,
+ allow_suffixes=adapter.allow_suffixes,
+ check_suffixes=adapter.check_suffixes,
+ blocks_policy=adapter.blocks_policy,
+ )
+
+ return BlockAdapter(
+ pipe=pipe,
+ transformer=transformer,
+ blocks=blocks,
+ blocks_name=blocks_name,
+ )

- def check_adapter_params(self) -> bool:
+ @staticmethod
+ def check_block_adapter(adapter: "BlockAdapter") -> bool:
  if (
- isinstance(self.pipe, DiffusionPipeline)
- and self.transformer is not None
- and self.blocks is not None
- and isinstance(self.blocks, torch.nn.ModuleList)
+ isinstance(adapter.pipe, DiffusionPipeline)
+ and adapter.transformer is not None
+ and adapter.blocks is not None
+ and adapter.blocks_name is not None
+ and isinstance(adapter.blocks, torch.nn.ModuleList)
  ):
  return True
  return False

+ @staticmethod
+ def find_blocks(
+ transformer: torch.nn.Module,
+ allow_prefixes: List[str] = [
+ "transformer",
+ "single_transformer",
+ "blocks",
+ "layers",
+ ],
+ allow_suffixes: List[str] = [
+ "TransformerBlock",
+ ],
+ check_suffixes: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.nn.ModuleList, str]:
+
+ blocks_names = []
+ for attr_name in dir(transformer):
+ for prefix in allow_prefixes:
+ if attr_name.startswith(prefix):
+ blocks_names.append(attr_name)
+
+ # Type check
+ valid_names = []
+ valid_count = []
+ for blocks_name in blocks_names:
+ if blocks := getattr(transformer, blocks_name, None):
+ if isinstance(blocks, torch.nn.ModuleList):
+ block = blocks[0]
+ block_cls_name = block.__class__.__name__
+ if isinstance(block, torch.nn.Module) and (
+ any(
+ (
+ block_cls_name.endswith(allow_suffix)
+ for allow_suffix in allow_suffixes
+ )
+ )
+ or (not check_suffixes)
+ ):
+ valid_names.append(blocks_name)
+ valid_count.append(len(blocks))
+
+ if not valid_names:
+ raise ValueError(
+ "Auto selected transformer blocks failed, please set it manually."
+ )
+
+ final_name = valid_names[0]
+ final_count = valid_count[0]
+ block_policy = kwargs.get("blocks_policy", "max")
+ for blocks_name, count in zip(valid_names, valid_count):
+ blocks = getattr(transformer, blocks_name)
+ logger.info(
+ f"Auto selected transformer blocks: {blocks_name}, "
+ f"class: {blocks[0].__class__.__name__}, "
+ f"num blocks: {count}"
+ )
+ if block_policy == "max":
+ if final_count < count:
+ final_count = count
+ final_name = blocks_name
+ else:
+ if final_count > count:
+ final_count = count
+ final_name = blocks_name
+
+ final_blocks = getattr(transformer, final_name)
+
+ logger.info(
+ f"Final selected transformer blocks: {final_name}, "
+ f"class: {final_blocks[0].__class__.__name__}, "
+ f"num blocks: {final_count}, block_policy: {block_policy}."
+ )
+
+ return final_blocks, final_name
+

  @dataclasses.dataclass
  class UnifiedCacheParams:
- adapter_params: BlockAdapterParams = None
+ block_adapter: BlockAdapter = None
  forward_pattern: ForwardPattern = ForwardPattern.Pattern_0

@@ -85,7 +209,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, FluxTransformer2DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=(
@@ -102,7 +226,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, MochiTransformer3DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -116,7 +240,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -136,7 +260,7 @@ class UnifiedCacheAdapter:
  (WanTransformer3DModel, WanVACETransformer3DModel),
  )
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.blocks,
@@ -150,7 +274,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  blocks=(
  pipe.transformer.transformer_blocks
@@ -166,7 +290,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -180,7 +304,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -194,7 +318,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, AllegroTransformer3DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -208,7 +332,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -222,7 +346,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, CogView4Transformer2DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -236,7 +360,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, CosmosTransformer3DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -250,7 +374,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -264,7 +388,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.blocks,
@@ -278,7 +402,7 @@ class UnifiedCacheAdapter:

  assert isinstance(pipe.transformer, SD3Transformer2DModel)
  return UnifiedCacheParams(
- adapter_params=BlockAdapterParams(
+ block_adapter=BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
@@ -294,13 +418,13 @@ class UnifiedCacheAdapter:
  def apply(
  cls,
  pipe: DiffusionPipeline = None,
- adapter_params: BlockAdapterParams = None,
+ block_adapter: BlockAdapter = None,
  forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
  **cache_context_kwargs,
  ) -> DiffusionPipeline:
  assert (
- pipe is not None or adapter_params is not None
- ), "pipe or adapter_params can not both None!"
+ pipe is not None or block_adapter is not None
+ ), "pipe or block_adapter can not both None!"

  if pipe is not None:
  if cls.is_supported(pipe):
@@ -310,7 +434,7 @@ class UnifiedCacheAdapter:
  )
  params = cls.get_params(pipe)
  return cls.cachify(
- params.adapter_params,
+ params.block_adapter,
  forward_pattern=params.forward_pattern,
  **cache_context_kwargs,
  )
@@ -324,7 +448,7 @@ class UnifiedCacheAdapter:
  "Adapting cache acceleration using custom BlockAdapter!"
  )
  return cls.cachify(
- adapter_params,
+ block_adapter,
  forward_pattern=forward_pattern,
  **cache_context_kwargs,
  )
@@ -332,22 +456,25 @@ class UnifiedCacheAdapter:
  @classmethod
  def cachify(
  cls,
- adapter_params: BlockAdapterParams,
+ block_adapter: BlockAdapter,
  *,
  forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
  **cache_context_kwargs,
  ) -> DiffusionPipeline:
- if adapter_params.check_adapter_params():
- assert isinstance(adapter_params.blocks, torch.nn.ModuleList)
+
+ if block_adapter.auto:
+ block_adapter = BlockAdapter.auto_block_adapter(block_adapter)
+
+ if BlockAdapter.check_block_adapter(block_adapter):
+ assert isinstance(block_adapter.blocks, torch.nn.ModuleList)
  # Apply cache on pipeline: wrap cache context
- cls.create_context(adapter_params.pipe, **cache_context_kwargs)
+ cls.create_context(block_adapter.pipe, **cache_context_kwargs)
  # Apply cache on transformer: mock cached transformer blocks
  cls.mock_blocks(
- adapter_params,
+ block_adapter,
  forward_pattern=forward_pattern,
  )
-
- return adapter_params.pipe
+ return block_adapter.pipe

  @classmethod
  def has_separate_cfg(
@@ -413,22 +540,33 @@ class UnifiedCacheAdapter:
  @classmethod
  def mock_blocks(
  cls,
- adapter_params: BlockAdapterParams,
+ block_adapter: BlockAdapter,
  forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
  ) -> torch.nn.Module:
- if getattr(adapter_params.transformer, "_is_cached", False):
- return adapter_params.transformer
+
+ if (
+ block_adapter.transformer is None
+ or block_adapter.blocks_name is None
+ or block_adapter.blocks is None
+ ):
+ assert block_adapter.auto, (
+ "Please manually set `auto` to True, or, "
+ "manually set transformer blocks configuration."
+ )
+
+ if getattr(block_adapter.transformer, "_is_cached", False):
+ return block_adapter.transformer

  # Firstly, process some specificial cases (TODO: more patches)
- if adapter_params.transformer.__class__.__name__.startswith("Flux"):
- adapter_params.transformer = maybe_patch_flux_transformer(
- adapter_params.transformer,
- blocks=adapter_params.blocks,
+ if block_adapter.transformer.__class__.__name__.startswith("Flux"):
+ block_adapter.transformer = maybe_patch_flux_transformer(
+ block_adapter.transformer,
+ blocks=block_adapter.blocks,
  )

  # Check block forward pattern matching
  assert cls.match_pattern(
- adapter_params.blocks,
+ block_adapter.blocks,
  forward_pattern=forward_pattern,
  ), (
  "No block forward pattern matched, "
@@ -439,22 +577,17 @@ class UnifiedCacheAdapter:
  cached_blocks = torch.nn.ModuleList(
  [
  DBCachedTransformerBlocks(
- adapter_params.blocks,
- transformer=adapter_params.transformer,
+ block_adapter.blocks,
+ transformer=block_adapter.transformer,
  forward_pattern=forward_pattern,
  )
  ]
  )
  dummy_blocks = torch.nn.ModuleList()

- original_forward = adapter_params.transformer.forward
+ original_forward = block_adapter.transformer.forward

- assert isinstance(adapter_params.dummy_blocks_names, list)
- if adapter_params.blocks_name is None:
- adapter_params.blocks_name = cls.find_blocks_name(
- adapter_params.transformer
- )
- assert adapter_params.blocks_name is not None
+ assert isinstance(block_adapter.dummy_blocks_names, list)

  @functools.wraps(original_forward)
  def new_forward(self, *args, **kwargs):
@@ -462,11 +595,11 @@ class UnifiedCacheAdapter:
  stack.enter_context(
  unittest.mock.patch.object(
  self,
- adapter_params.blocks_name,
+ block_adapter.blocks_name,
  cached_blocks,
  )
  )
- for dummy_name in adapter_params.dummy_blocks_names:
+ for dummy_name in block_adapter.dummy_blocks_names:
  stack.enter_context(
  unittest.mock.patch.object(
  self,
@@ -476,12 +609,12 @@ class UnifiedCacheAdapter:
  )
  return original_forward(*args, **kwargs)

- adapter_params.transformer.forward = new_forward.__get__(
- adapter_params.transformer
+ block_adapter.transformer.forward = new_forward.__get__(
+ block_adapter.transformer
  )
- adapter_params.transformer._is_cached = True
+ block_adapter.transformer._is_cached = True

- return adapter_params.transformer
+ return block_adapter.transformer

  @classmethod
  def match_pattern(
@@ -525,22 +658,3 @@ class UnifiedCacheAdapter:
  )

  return pattern_matched
-
- @classmethod
- def find_blocks_name(cls, transformer):
- blocks_name = None
- allow_prefixes = ["transformer", "blocks"]
- for attr_name in dir(transformer):
- if blocks_name is None:
- for prefix in allow_prefixes:
- # transformer_blocks, blocks
- if attr_name.startswith(prefix):
- blocks_name = attr_name
- logger.info(f"Auto selected blocks name: {blocks_name}")
- # only find one transformer blocks name
- break
- if blocks_name is None:
- logger.warning(
- "Auto selected blocks name failed, please set it manually."
- )
- return blocks_name
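
The removed `find_blocks_name` heuristic is superseded by the opt-in `auto` path: `cachify` now calls `BlockAdapter.auto_block_adapter`, which delegates to `find_blocks` to locate a `torch.nn.ModuleList` by attribute prefix (optionally filtered by class-name suffix) and keeps the largest or smallest candidate according to `blocks_policy`. A minimal usage sketch, assuming `pipe` is an already loaded Diffusers pipeline whose transformer stores its blocks under one of the default prefixes (the pipeline itself is not part of this diff):

```python
import cache_dit
from cache_dit import BlockAdapter, ForwardPattern

# Assumption: `pipe` is a loaded DiffusionPipeline (e.g. Qwen-Image) whose
# transformer exposes a ModuleList named like "transformer_blocks".
cached_pipe = cache_dit.enable_cache(
    BlockAdapter(pipe=pipe, auto=True),        # let cache-dit find the blocks
    forward_pattern=ForwardPattern.Pattern_1,  # pattern the blocks satisfy
)

# The selection step can also be run on its own for inspection:
blocks, blocks_name = BlockAdapter.find_blocks(transformer=pipe.transformer)
print(f"selected {blocks_name} with {len(blocks)} blocks")
```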
cache_dit/cache_factory/cache_blocks.py CHANGED
@@ -1,5 +1,6 @@
  import inspect
  import torch
+ import torch.distributed as dist

  from cache_dit.cache_factory import cache_context
  from cache_dit.cache_factory import ForwardPattern
@@ -179,10 +180,15 @@ class DBCachedTransformerBlocks(torch.nn.Module):
  @torch.compiler.disable
  def _is_parallelized(self):
  # Compatible with distributed inference.
- return all(
+ return any(
  (
- self.transformer is not None,
- getattr(self.transformer, "_is_parallelized", False),
+ all(
+ (
+ self.transformer is not None,
+ getattr(self.transformer, "_is_parallelized", False),
+ )
+ ),
+ (dist.is_initialized() and dist.get_world_size() > 1),
  )
  )
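
This change widens when caching treats a run as distributed: 0.2.22 required the transformer itself to carry `_is_parallelized`, while 0.2.23 additionally counts an initialized `torch.distributed` process group with a world size above 1. A stand-alone sketch of the new predicate, extracted from the hunk above (`transformer` stands in for `self.transformer`):

```python
import torch.distributed as dist


def is_parallelized(transformer) -> bool:
    # 0.2.23 logic: parallel if the transformer is flagged as parallelized,
    # OR if a multi-process torch.distributed group is already initialized.
    return any(
        (
            all(
                (
                    transformer is not None,
                    getattr(transformer, "_is_parallelized", False),
                )
            ),
            (dist.is_initialized() and dist.get_world_size() > 1),
        )
    )
```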
cache_dit/cache_factory/cache_interface.py CHANGED
@@ -1,7 +1,7 @@
  from diffusers import DiffusionPipeline
  from cache_dit.cache_factory.forward_pattern import ForwardPattern
  from cache_dit.cache_factory.cache_types import CacheType
- from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
+ from cache_dit.cache_factory.cache_adapters import BlockAdapter
  from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter

  from cache_dit.logger import init_logger
@@ -11,7 +11,7 @@ logger = init_logger(__name__)

  def enable_cache(
  # BlockAdapter & forward pattern
- pipe_or_adapter: DiffusionPipeline | BlockAdapterParams,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
  forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
  # Cache context kwargs
  Fn_compute_blocks: int = 8,
@@ -38,7 +38,7 @@ def enable_cache(
  with F8B0, 8 warmup steps, and unlimited cached steps.

  Args:
- pipe_or_adapter (`DiffusionPipeline` or `BlockAdapterParams`, *required*):
+ pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
  The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
  For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
  for the usgae of BlockAdapter.
@@ -128,22 +128,22 @@ def enable_cache(
  "n_derivatives": taylorseer_order
  }

- if isinstance(pipe_or_adapter, BlockAdapterParams):
+ if isinstance(pipe_or_adapter, BlockAdapter):
  return UnifiedCacheAdapter.apply(
  pipe=None,
- adapter_params=pipe_or_adapter,
+ block_adapter=pipe_or_adapter,
  forward_pattern=forward_pattern,
  **cache_context_kwargs,
  )
  elif isinstance(pipe_or_adapter, DiffusionPipeline):
  return UnifiedCacheAdapter.apply(
  pipe=pipe_or_adapter,
- adapter_params=None,
+ block_adapter=None,
  forward_pattern=forward_pattern,
  **cache_context_kwargs,
  )
  else:
  raise ValueError(
- "Please pass DiffusionPipeline or BlockAdapterParams"
- "(BlockAdapter) for the 1 position param: pipe_or_adapter"
+ "Please pass DiffusionPipeline or BlockAdapter"
+ "for the 1's position param: pipe_or_adapter"
  )
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.22
+ Version: 0.2.23
  Summary: 🤗 CacheDiT: An Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -173,16 +173,23 @@ But in some cases, you may have a **modified** Diffusion Pipeline or Transformer
  ```python
  from cache_dit import ForwardPattern, BlockAdapter

- # Please check docs/BlockAdapter.md for more details.
+ # Use BlockAdapter with `auto` mode.
+ cache_dit.enable_cache(
+ BlockAdapter(pipe=pipe, auto=True), # Qwen-Image, etc.
+ # Check `📚Forward Pattern Matching` documentation and hack the code of
+ # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
+ forward_pattern=ForwardPattern.Pattern_1,
+ )
+
+ # Or, manualy setup transformer configurations.
  cache_dit.enable_cache(
  BlockAdapter(
  pipe=pipe, # Qwen-Image, etc.
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
- ),
- # Check `📚Forward Pattern Matching` documentation and hack the code of
- # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
- forward_pattern=ForwardPattern.Pattern_1,
+ blocks_name="transformer_blocks",
+ ),
+ forward_pattern=ForwardPattern.Pattern_1,
  )
  ```
  For such situations, **BlockAdapter** can help you quickly apply various cache acceleration features to your own Diffusion Pipelines and Transformers. Please check the [📚BlockAdapter.md](./docs/BlockAdapter.md) for more details.
@@ -1,14 +1,14 @@
- cache_dit/__init__.py,sha256=wVOaj_LSDsgYygL0cDdUU80_6RINh_JctQFyDalZN7k,946
- cache_dit/_version.py,sha256=I7oxlElEVr-U2wT5qgQ2G41IxS87cokjF8Z2fKVHGrc,706
+ cache_dit/__init__.py,sha256=KwhX9NfYkWSvDFuuUVeVjcuiZiGS_22y386l8j4afMo,905
+ cache_dit/_version.py,sha256=6GZdGbiFdhndXqR5oFLOd8VGzUvRkESP-NyStAZWYUw,706
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
  cache_dit/utils.py,sha256=3UgVhfmTFG28w6CV-Rfxp5u1uzLrRozocHwLCTGiQ5M,5865
  cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
- cache_dit/cache_factory/__init__.py,sha256=cXqBAVvldXNStpxAmNIJnpfJEf2miDlzzyjIqDauFI8,505
- cache_dit/cache_factory/cache_adapters.py,sha256=bNcUz4SP3XFpVbkgSlehLdAqKbEXjQJcm-5oS8pKqxg,20289
- cache_dit/cache_factory/cache_blocks.py,sha256=kMEOoNvygzeiM2yvUSAPkKpHeQTOpXQYH2qz34TqXzs,18457
+ cache_dit/cache_factory/__init__.py,sha256=evWenCin1kuBGa6W5BCKMrDZc1C1R2uVPSg0BjXgdXE,499
+ cache_dit/cache_factory/cache_adapters.py,sha256=QTSwjdCmHDeF80TLp6D3KhzQS_oMPna0_bESgJBrdkg,23978
+ cache_dit/cache_factory/cache_blocks.py,sha256=ZeazBsYvLIjI5M_OnLL2xP2W7zMeM0rxVfBBwIVHBRs,18661
  cache_dit/cache_factory/cache_context.py,sha256=4thx9NYxVaYZ_Nr2quUVE8bsNmTsXhZK0F960rccOc8,39000
- cache_dit/cache_factory/cache_interface.py,sha256=V1FbtwI78Qj-yoDnz956o5lpnPxH8bMmiZNhiuiYLQo,8090
+ cache_dit/cache_factory/cache_interface.py,sha256=PohG_2oy747O37YSsWz_DwxxTXN7ORhQatyEbg_6fQs,8045
  cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
  cache_dit/cache_factory/forward_pattern.py,sha256=B2YeqV2t_zo2Ar8m7qimPBjwQgoXHGp2grPZmEAhi8s,1286
  cache_dit/cache_factory/taylorseer.py,sha256=WeK2WlAJa4Px_pnAKokmnZXeqQYylQkPw4-EDqBIqeQ,3770
@@ -25,9 +25,9 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
  cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
  cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
- cache_dit-0.2.22.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
- cache_dit-0.2.22.dist-info/METADATA,sha256=BABVrkyVTakN0jel9xgApSd9IzDBRLqJHLHhauqka50,19566
- cache_dit-0.2.22.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- cache_dit-0.2.22.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
- cache_dit-0.2.22.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
- cache_dit-0.2.22.dist-info/RECORD,,
+ cache_dit-0.2.23.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.2.23.dist-info/METADATA,sha256=Dq2f8TlyTmv36otIJ2F-fRGkJlZmpW2SY6O14P2AYKo,19772
+ cache_dit-0.2.23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.2.23.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+ cache_dit-0.2.23.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.2.23.dist-info/RECORD,,