cache-dit 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of cache-dit has been flagged as potentially problematic.

cache_dit/__init__.py CHANGED
@@ -7,20 +7,18 @@ except ImportError:
 from cache_dit.cache_factory import load_options
 from cache_dit.cache_factory import enable_cache
 from cache_dit.cache_factory import cache_type
-from cache_dit.cache_factory import default_options
 from cache_dit.cache_factory import block_range
 from cache_dit.cache_factory import CacheType
+from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory import BlockAdapterParams
 from cache_dit.compile import set_compile_configs
 from cache_dit.utils import summary
+from cache_dit.utils import strify
 from cache_dit.logger import init_logger
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
 
-BlockAdapter = BlockAdapterParams
-
 Forward_Pattern_0 = ForwardPattern.Pattern_0
 Forward_Pattern_1 = ForwardPattern.Pattern_1
 Forward_Pattern_2 = ForwardPattern.Pattern_2
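
The top-level API shifts accordingly: `BlockAdapter` is now a real export rather than an alias of `BlockAdapterParams`, `default_options` is gone, and `strify` joins `summary`. A quick sanity-check sketch against an installed 0.2.23 wheel; the commented call mirrors the 0.2.21 `enable_cache` signature and is only assumed to carry over:

```python
import cache_dit
from cache_dit import BlockAdapter, ForwardPattern  # exported directly in 0.2.23

# Removed in this release: the top-level default_options helper and the
# BlockAdapter = BlockAdapterParams alias.
assert not hasattr(cache_dit, "default_options")
assert hasattr(cache_dit, "strify")  # new utils export alongside summary

# Assumed usage, mirroring the pre-0.2.23 signature (pipe is any supported
# diffusers pipeline):
# pipe = cache_dit.enable_cache(pipe, forward_pattern=ForwardPattern.Pattern_0)
```
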
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.21'
-__version_tuple__ = version_tuple = (0, 2, 21)
+__version__ = version = '0.2.23'
+__version_tuple__ = version_tuple = (0, 2, 23)
 
 __commit_id__ = commit_id = None
@@ -1,65 +1,8 @@
-from typing import Dict, List
-from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
+from cache_dit.cache_factory.cache_types import cache_type
+from cache_dit.cache_factory.cache_types import block_range
+from cache_dit.cache_factory.cache_adapters import BlockAdapter
 from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
-from cache_dit.cache_factory.utils import load_cache_options_from_yaml
-
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def load_options(path: str):
-    return load_cache_options_from_yaml(path)
-
-
-def cache_type(
-    type_hint: "CacheType | str",
-) -> CacheType:
-    return CacheType.type(cache_type=type_hint)
-
-
-def default_options(
-    cache_type: CacheType = CacheType.DBCache,
-) -> Dict:
-    return CacheType.default_options(cache_type)
-
-
-def block_range(
-    start: int,
-    end: int,
-    step: int = 1,
-) -> List[int]:
-    return CacheType.block_range(
-        start,
-        end,
-        step,
-    )
-
-
-def enable_cache(
-    pipe_or_adapter: DiffusionPipeline | BlockAdapterParams,
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-    **cache_options_kwargs,
-) -> DiffusionPipeline:
-    if isinstance(pipe_or_adapter, BlockAdapterParams):
-        return UnifiedCacheAdapter.apply(
-            pipe=None,
-            adapter_params=pipe_or_adapter,
-            forward_pattern=forward_pattern,
-            **cache_options_kwargs,
-        )
-    elif isinstance(pipe_or_adapter, DiffusionPipeline):
-        return UnifiedCacheAdapter.apply(
-            pipe=pipe_or_adapter,
-            adapter_params=None,
-            forward_pattern=forward_pattern,
-            **cache_options_kwargs,
-        )
-    else:
-        raise ValueError(
-            "Please pass DiffusionPipeline or BlockAdapterParams"
-            "(BlockAdapter) for the 1 position param: pipe_or_adapter"
-        )
+from cache_dit.cache_factory.cache_interface import enable_cache
+from cache_dit.cache_factory.utils import load_options
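
The factory package keeps re-exporting the same names, but they now come from dedicated modules instead of being defined inline: `enable_cache` moves to `cache_interface`, `load_options` to `utils`, and `cache_type`/`block_range` to `cache_types`. Import paths that only touch `cache_dit.cache_factory` should keep working; a sketch using only the re-exports visible in this hunk:

```python
# These imports match the re-exports shown above; nothing else is assumed.
from cache_dit.cache_factory import enable_cache              # now defined in cache_interface
from cache_dit.cache_factory import load_options              # now defined in cache_factory.utils
from cache_dit.cache_factory import cache_type, block_range   # re-exported from cache_types
from cache_dit.cache_factory import BlockAdapter              # renamed from BlockAdapterParams
```
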
@@ -5,7 +5,7 @@ import unittest
 import functools
 import dataclasses
 
-from typing import Any
+from typing import Any, Tuple, List
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.patch.flux import (
@@ -23,28 +23,152 @@ logger = init_logger(__name__)
 
 
 @dataclasses.dataclass
-class BlockAdapterParams:
+class BlockAdapter:
     pipe: DiffusionPipeline = None
     transformer: torch.nn.Module = None
     blocks: torch.nn.ModuleList = None
     # transformer_blocks, blocks, etc.
     blocks_name: str = None
     dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
+    # flags to control auto block adapter
+    auto: bool = False
+    allow_prefixes: List[str] = dataclasses.field(
+        default_factory=lambda: [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ]
+    )
+    allow_suffixes: List[str] = dataclasses.field(
+        default_factory=lambda: ["TransformerBlock"]
+    )
+    check_suffixes: bool = False
+    blocks_policy: str = dataclasses.field(
+        default="max", metadata={"allowed_values": ["max", "min"]}
+    )
+
+    @staticmethod
+    def auto_block_adapter(adapter: "BlockAdapter") -> "BlockAdapter":
+        assert adapter.auto, (
+            "Please manually set `auto` to True, or, manually "
+            "set all the transformer blocks configuration."
+        )
+        assert adapter.pipe is not None, "adapter.pipe can not be None."
+        pipe = adapter.pipe
+
+        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+        transformer = pipe.transformer
+
+        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+        blocks, blocks_name = BlockAdapter.find_blocks(
+            transformer=transformer,
+            allow_prefixes=adapter.allow_prefixes,
+            allow_suffixes=adapter.allow_suffixes,
+            check_suffixes=adapter.check_suffixes,
+            blocks_policy=adapter.blocks_policy,
+        )
 
-    def check_adapter_params(self) -> bool:
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=transformer,
+            blocks=blocks,
+            blocks_name=blocks_name,
+        )
+
+    @staticmethod
+    def check_block_adapter(adapter: "BlockAdapter") -> bool:
         if (
-            isinstance(self.pipe, DiffusionPipeline)
-            and self.transformer is not None
-            and self.blocks is not None
-            and isinstance(self.blocks, torch.nn.ModuleList)
+            isinstance(adapter.pipe, DiffusionPipeline)
+            and adapter.transformer is not None
+            and adapter.blocks is not None
+            and adapter.blocks_name is not None
+            and isinstance(adapter.blocks, torch.nn.ModuleList)
         ):
             return True
         return False
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+        allow_prefixes: List[str] = [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ],
+        allow_suffixes: List[str] = [
+            "TransformerBlock",
+        ],
+        check_suffixes: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.nn.ModuleList, str]:
+
+        blocks_names = []
+        for attr_name in dir(transformer):
+            for prefix in allow_prefixes:
+                if attr_name.startswith(prefix):
+                    blocks_names.append(attr_name)
+
+        # Type check
+        valid_names = []
+        valid_count = []
+        for blocks_name in blocks_names:
+            if blocks := getattr(transformer, blocks_name, None):
+                if isinstance(blocks, torch.nn.ModuleList):
+                    block = blocks[0]
+                    block_cls_name = block.__class__.__name__
+                    if isinstance(block, torch.nn.Module) and (
+                        any(
+                            (
+                                block_cls_name.endswith(allow_suffix)
+                                for allow_suffix in allow_suffixes
+                            )
+                        )
+                        or (not check_suffixes)
+                    ):
+                        valid_names.append(blocks_name)
+                        valid_count.append(len(blocks))
+
+        if not valid_names:
+            raise ValueError(
+                "Auto selected transformer blocks failed, please set it manually."
+            )
+
+        final_name = valid_names[0]
+        final_count = valid_count[0]
+        block_policy = kwargs.get("blocks_policy", "max")
+        for blocks_name, count in zip(valid_names, valid_count):
+            blocks = getattr(transformer, blocks_name)
+            logger.info(
+                f"Auto selected transformer blocks: {blocks_name}, "
+                f"class: {blocks[0].__class__.__name__}, "
+                f"num blocks: {count}"
+            )
+            if block_policy == "max":
+                if final_count < count:
+                    final_count = count
+                    final_name = blocks_name
+            else:
+                if final_count > count:
+                    final_count = count
+                    final_name = blocks_name
+
+        final_blocks = getattr(transformer, final_name)
+
+        logger.info(
+            f"Final selected transformer blocks: {final_name}, "
+            f"class: {final_blocks[0].__class__.__name__}, "
+            f"num blocks: {final_count}, block_policy: {block_policy}."
+        )
+
+        return final_blocks, final_name
+
 
 @dataclasses.dataclass
 class UnifiedCacheParams:
-    adapter_params: BlockAdapterParams = None
+    block_adapter: BlockAdapter = None
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
 
 
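
The new `auto` path discovers the transformer blocks by attribute name (prefixes such as `transformer`, `single_transformer`, `blocks`, `layers`), optionally filters by class-name suffix, and picks the largest or smallest `ModuleList` according to `blocks_policy`. A minimal sketch of driving it directly; the Flux checkpoint is only an example model, and `find_blocks` is called with the defaults shown in the hunk:

```python
from diffusers import FluxPipeline
from cache_dit.cache_factory import BlockAdapter

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")  # example model

# Let the adapter discover transformer/blocks/blocks_name on its own.
adapter = BlockAdapter(pipe=pipe, auto=True, blocks_policy="max")
adapter = BlockAdapter.auto_block_adapter(adapter)
assert BlockAdapter.check_block_adapter(adapter)

# find_blocks can also be called on a bare transformer module.
blocks, blocks_name = BlockAdapter.find_blocks(transformer=pipe.transformer)
print(blocks_name, len(blocks))
```
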
@@ -85,7 +209,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, FluxTransformer2DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=(
@@ -102,7 +226,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, MochiTransformer3DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -116,7 +240,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -136,7 +260,7 @@ class UnifiedCacheAdapter:
                (WanTransformer3DModel, WanVACETransformer3DModel),
            )
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.blocks,
@@ -150,7 +274,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    blocks=(
                        pipe.transformer.transformer_blocks
@@ -166,7 +290,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -180,7 +304,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -194,7 +318,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, AllegroTransformer3DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -208,7 +332,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -222,7 +346,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, CogView4Transformer2DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -236,7 +360,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, CosmosTransformer3DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -250,7 +374,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -264,7 +388,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.blocks,
@@ -278,7 +402,7 @@ class UnifiedCacheAdapter:
 
            assert isinstance(pipe.transformer, SD3Transformer2DModel)
            return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                    pipe=pipe,
                    transformer=pipe.transformer,
                    blocks=pipe.transformer.transformer_blocks,
@@ -294,13 +418,13 @@ class UnifiedCacheAdapter:
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
-        adapter_params: BlockAdapterParams = None,
+        block_adapter: BlockAdapter = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
-            pipe is not None or adapter_params is not None
-        ), "pipe or adapter_params can not both None!"
+            pipe is not None or block_adapter is not None
+        ), "pipe or block_adapter can not both None!"
 
         if pipe is not None:
             if cls.is_supported(pipe):
@@ -310,7 +434,7 @@ class UnifiedCacheAdapter:
                )
                params = cls.get_params(pipe)
                return cls.cachify(
-                    params.adapter_params,
+                    params.block_adapter,
                    forward_pattern=params.forward_pattern,
                    **cache_context_kwargs,
                )
@@ -324,7 +448,7 @@ class UnifiedCacheAdapter:
                "Adapting cache acceleration using custom BlockAdapter!"
            )
            return cls.cachify(
-                adapter_params,
+                block_adapter,
                forward_pattern=forward_pattern,
                **cache_context_kwargs,
            )
@@ -332,25 +456,28 @@ class UnifiedCacheAdapter:
     @classmethod
     def cachify(
         cls,
-        adapter_params: BlockAdapterParams,
+        block_adapter: BlockAdapter,
         *,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-        if adapter_params.check_adapter_params():
-            assert isinstance(adapter_params.blocks, torch.nn.ModuleList)
+
+        if block_adapter.auto:
+            block_adapter = BlockAdapter.auto_block_adapter(block_adapter)
+
+        if BlockAdapter.check_block_adapter(block_adapter):
+            assert isinstance(block_adapter.blocks, torch.nn.ModuleList)
             # Apply cache on pipeline: wrap cache context
-            cls.create_context(adapter_params.pipe, **cache_context_kwargs)
+            cls.create_context(block_adapter.pipe, **cache_context_kwargs)
             # Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
-                adapter_params,
+                block_adapter,
                forward_pattern=forward_pattern,
            )
-
-        return adapter_params.pipe
+        return block_adapter.pipe
 
     @classmethod
-    def has_separate_classifier_free_guidance(
+    def has_separate_cfg(
         cls,
         pipe_or_transformer: DiffusionPipeline | Any,
     ) -> bool:
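
Call sites follow the rename: `apply` and `cachify` take `block_adapter=` instead of `adapter_params=`, and `cachify` now resolves `auto=True` adapters before validating them. A hedged call sketch: the Flux checkpoint is only an example, the `do_separate_cfg=False` kwarg and the assumption that it flows through `**cache_context_kwargs` into `check_context_kwargs` come from the next hunk, not from documented API:

```python
from diffusers import FluxPipeline
from cache_dit.cache_factory import BlockAdapter, ForwardPattern
from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")  # example model

# Supported pipelines: the preset BlockAdapter is selected internally.
# do_separate_cfg is read by the new check_context_kwargs (see the hunk below);
# whether a default is injected elsewhere is not visible in this diff.
pipe = UnifiedCacheAdapter.apply(pipe=pipe, do_separate_cfg=False)

# Custom/auto path: the keyword is now block_adapter (formerly adapter_params).
# pipe = UnifiedCacheAdapter.apply(
#     block_adapter=BlockAdapter(pipe=pipe, auto=True),
#     forward_pattern=ForwardPattern.Pattern_0,
#     do_separate_cfg=False,
# )
```
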
@@ -364,20 +491,9 @@ class UnifiedCacheAdapter:
     @classmethod
     def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
-        if not cache_context_kwargs:
-            cache_context_kwargs = CacheType.default_options(CacheType.DBCache)
-            if cls.has_separate_classifier_free_guidance(pipe):
-                cache_context_kwargs["do_separate_classifier_free_guidance"] = (
-                    True
-                )
-            logger.warning(
-                "cache_context_kwargs is empty, use default "
-                f"cache options: {cache_context_kwargs}"
-            )
-        else:
-            # Allow empty cache_type, we only support DBCache now.
-            if cache_context_kwargs.get("cache_type", None):
-                cache_context_kwargs["cache_type"] = CacheType.DBCache
+        if not cache_context_kwargs["do_separate_cfg"]:
+            # Check cfg for some specific case if users don't set it as True
+            cache_context_kwargs["do_separate_cfg"] = cls.has_separate_cfg(pipe)
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
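
The empty-kwargs fallback to `CacheType.default_options` is gone; instead the renamed `do_separate_cfg` flag is force-enabled via `has_separate_cfg(pipe)` whenever the caller leaves it falsy. Note that the new branch indexes `cache_context_kwargs["do_separate_cfg"]` directly, so the key is expected to be present. A standalone restatement of the new branch (the helper below is illustrative, not cache-dit API):

```python
def resolve_do_separate_cfg(user_value: bool, pipe_has_separate_cfg: bool) -> bool:
    # Mirrors: if not kwargs["do_separate_cfg"]:
    #              kwargs["do_separate_cfg"] = cls.has_separate_cfg(pipe)
    if not user_value:
        return pipe_has_separate_cfg
    return user_value

assert resolve_do_separate_cfg(False, True) is True    # flipped on for separate-CFG pipelines
assert resolve_do_separate_cfg(True, False) is True    # an explicit True is kept
assert resolve_do_separate_cfg(False, False) is False  # never downgraded by the check
```
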
@@ -424,22 +540,33 @@ class UnifiedCacheAdapter:
     @classmethod
     def mock_blocks(
         cls,
-        adapter_params: BlockAdapterParams,
+        block_adapter: BlockAdapter,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
-        if getattr(adapter_params.transformer, "_is_cached", False):
-            return adapter_params.transformer
+
+        if (
+            block_adapter.transformer is None
+            or block_adapter.blocks_name is None
+            or block_adapter.blocks is None
+        ):
+            assert block_adapter.auto, (
+                "Please manually set `auto` to True, or, "
+                "manually set transformer blocks configuration."
+            )
+
+        if getattr(block_adapter.transformer, "_is_cached", False):
+            return block_adapter.transformer
 
         # Firstly, process some specificial cases (TODO: more patches)
-        if adapter_params.transformer.__class__.__name__.startswith("Flux"):
-            adapter_params.transformer = maybe_patch_flux_transformer(
-                adapter_params.transformer,
-                blocks=adapter_params.blocks,
+        if block_adapter.transformer.__class__.__name__.startswith("Flux"):
+            block_adapter.transformer = maybe_patch_flux_transformer(
+                block_adapter.transformer,
+                blocks=block_adapter.blocks,
             )
 
         # Check block forward pattern matching
         assert cls.match_pattern(
-            adapter_params.blocks,
+            block_adapter.blocks,
             forward_pattern=forward_pattern,
         ), (
             "No block forward pattern matched, "
@@ -450,22 +577,17 @@ class UnifiedCacheAdapter:
         cached_blocks = torch.nn.ModuleList(
             [
                 DBCachedTransformerBlocks(
-                    adapter_params.blocks,
-                    transformer=adapter_params.transformer,
+                    block_adapter.blocks,
+                    transformer=block_adapter.transformer,
                     forward_pattern=forward_pattern,
                 )
             ]
         )
         dummy_blocks = torch.nn.ModuleList()
 
-        original_forward = adapter_params.transformer.forward
+        original_forward = block_adapter.transformer.forward
 
-        assert isinstance(adapter_params.dummy_blocks_names, list)
-        if adapter_params.blocks_name is None:
-            adapter_params.blocks_name = cls.find_blocks_name(
-                adapter_params.transformer
-            )
-        assert adapter_params.blocks_name is not None
+        assert isinstance(block_adapter.dummy_blocks_names, list)
 
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
@@ -473,11 +595,11 @@ class UnifiedCacheAdapter:
                stack.enter_context(
                    unittest.mock.patch.object(
                        self,
-                        adapter_params.blocks_name,
+                        block_adapter.blocks_name,
                        cached_blocks,
                    )
                )
-                for dummy_name in adapter_params.dummy_blocks_names:
+                for dummy_name in block_adapter.dummy_blocks_names:
                    stack.enter_context(
                        unittest.mock.patch.object(
                            self,
@@ -487,12 +609,12 @@ class UnifiedCacheAdapter:
                    )
                return original_forward(*args, **kwargs)
 
-        adapter_params.transformer.forward = new_forward.__get__(
-            adapter_params.transformer
+        block_adapter.transformer.forward = new_forward.__get__(
+            block_adapter.transformer
         )
-        adapter_params.transformer._is_cached = True
+        block_adapter.transformer._is_cached = True
 
-        return adapter_params.transformer
+        return block_adapter.transformer
 
     @classmethod
     def match_pattern(
536
658
  )
537
659
 
538
660
  return pattern_matched
539
-
540
- @classmethod
541
- def find_blocks_name(cls, transformer):
542
- blocks_name = None
543
- allow_prefixes = ["transformer", "blocks"]
544
- for attr_name in dir(transformer):
545
- if blocks_name is None:
546
- for prefix in allow_prefixes:
547
- # transformer_blocks, blocks
548
- if attr_name.startswith(prefix):
549
- blocks_name = attr_name
550
- logger.info(f"Auto selected blocks name: {blocks_name}")
551
- # only find one transformer blocks name
552
- break
553
- if blocks_name is None:
554
- logger.warning(
555
- "Auto selected blocks name failed, please set it manually."
556
- )
557
- return blocks_name
@@ -1,5 +1,6 @@
 import inspect
 import torch
+import torch.distributed as dist
 
 from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
@@ -179,10 +180,15 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     @torch.compiler.disable
     def _is_parallelized(self):
         # Compatible with distributed inference.
-        return all(
+        return any(
             (
-                self.transformer is not None,
-                getattr(self.transformer, "_is_parallelized", False),
+                all(
+                    (
+                        self.transformer is not None,
+                        getattr(self.transformer, "_is_parallelized", False),
+                    )
+                ),
+                (dist.is_initialized() and dist.get_world_size() > 1),
             )
         )
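
The reworked predicate now treats an active multi-rank torch.distributed process group as parallelized even when the transformer carries no `_is_parallelized` flag; previously both transformer-level conditions had to hold. An equivalent standalone restatement (the free function below is illustrative, not the library method):

```python
import torch.distributed as dist

def is_parallelized(transformer) -> bool:
    # Old behavior: only the transformer-level flag counted.
    flagged = transformer is not None and getattr(transformer, "_is_parallelized", False)
    # New in 0.2.23: an initialized process group with world_size > 1 also counts.
    multi_rank = dist.is_initialized() and dist.get_world_size() > 1
    return flagged or multi_rank
```
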