cache-dit 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

cache_dit/__init__.py CHANGED
@@ -9,8 +9,8 @@ from cache_dit.cache_factory import enable_cache
 from cache_dit.cache_factory import cache_type
 from cache_dit.cache_factory import block_range
 from cache_dit.cache_factory import CacheType
+from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory import BlockAdapterParams
 from cache_dit.compile import set_compile_configs
 from cache_dit.utils import summary
 from cache_dit.utils import strify
@@ -19,8 +19,6 @@ from cache_dit.logger import init_logger
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
 
-BlockAdapter = BlockAdapterParams
-
 Forward_Pattern_0 = ForwardPattern.Pattern_0
 Forward_Pattern_1 = ForwardPattern.Pattern_1
 Forward_Pattern_2 = ForwardPattern.Pattern_2
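
Net effect of these two hunks: BlockAdapterParams is gone from the top-level namespace and BlockAdapter is now re-exported directly instead of being an alias. A minimal sketch of the manual construction path, assuming an already-loaded diffusers pipeline (the model id and its transformer_blocks attribute are illustrative, not taken from this diff):

import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/model-id")  # illustrative model id

# 0.2.22: cache_dit.BlockAdapter was an alias of BlockAdapterParams.
# 0.2.24: cache_dit.BlockAdapter is the dataclass itself; BlockAdapterParams no longer exists.
adapter = cache_dit.BlockAdapter(
    pipe=pipe,
    transformer=pipe.transformer,
    blocks=pipe.transformer.transformer_blocks,
    blocks_name="transformer_blocks",
)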
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.22'
-__version_tuple__ = version_tuple = (0, 2, 22)
+__version__ = version = '0.2.24'
+__version_tuple__ = version_tuple = (0, 2, 24)
 
 __commit_id__ = commit_id = None
@@ -2,7 +2,7 @@ from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.cache_types import cache_type
 from cache_dit.cache_factory.cache_types import block_range
-from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
+from cache_dit.cache_factory.cache_adapters import BlockAdapter
 from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
 from cache_dit.cache_factory.cache_interface import enable_cache
 from cache_dit.cache_factory.utils import load_options
@@ -5,7 +5,7 @@ import unittest
 import functools
 import dataclasses
 
-from typing import Any
+from typing import Any, Tuple, List, Optional
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.patch.flux import (
@@ -23,28 +23,251 @@ logger = init_logger(__name__)
 
 
 @dataclasses.dataclass
-class BlockAdapterParams:
+class BlockAdapter:
     pipe: DiffusionPipeline = None
     transformer: torch.nn.Module = None
     blocks: torch.nn.ModuleList = None
     # transformer_blocks, blocks, etc.
     blocks_name: str = None
     dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
+    # flags to control auto block adapter
+    auto: bool = False
+    allow_prefixes: List[str] = dataclasses.field(
+        default_factory=lambda: [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ]
+    )
+    check_prefixes: bool = True
+    allow_suffixes: List[str] = dataclasses.field(
+        default_factory=lambda: ["TransformerBlock"]
+    )
+    check_suffixes: bool = False
+    blocks_policy: str = dataclasses.field(
+        default="max", metadata={"allowed_values": ["max", "min"]}
+    )
+
+    def __post_init__(self):
+        self.maybe_apply_patch()
+
+    def maybe_apply_patch(self):
+        # Process some specificial cases, specific for transformers
+        # that has different forward patterns between single_transformer_blocks
+        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
+        if self.transformer.__class__.__name__.startswith("Flux"):
+            self.transformer = maybe_patch_flux_transformer(
+                self.transformer,
+                blocks=self.blocks,
+            )
+
+    @staticmethod
+    def auto_block_adapter(
+        adapter: "BlockAdapter",
+        forward_pattern: Optional[ForwardPattern] = None,
+    ) -> "BlockAdapter":
+        assert adapter.auto, (
+            "Please manually set `auto` to True, or, manually "
+            "set all the transformer blocks configuration."
+        )
+        assert adapter.pipe is not None, "adapter.pipe can not be None."
+        pipe = adapter.pipe
+
+        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+        transformer = pipe.transformer
+
+        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+        blocks, blocks_name = BlockAdapter.find_blocks(
+            transformer=transformer,
+            allow_prefixes=adapter.allow_prefixes,
+            allow_suffixes=adapter.allow_suffixes,
+            check_prefixes=adapter.check_prefixes,
+            check_suffixes=adapter.check_suffixes,
+            blocks_policy=adapter.blocks_policy,
+            forward_pattern=forward_pattern,
+        )
+
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=transformer,
+            blocks=blocks,
+            blocks_name=blocks_name,
+        )
 
-    def check_adapter_params(self) -> bool:
+    @staticmethod
+    def check_block_adapter(adapter: "BlockAdapter") -> bool:
         if (
-            isinstance(self.pipe, DiffusionPipeline)
-            and self.transformer is not None
-            and self.blocks is not None
-            and isinstance(self.blocks, torch.nn.ModuleList)
+            isinstance(adapter.pipe, DiffusionPipeline)
+            and adapter.transformer is not None
+            and adapter.blocks is not None
+            and adapter.blocks_name is not None
+            and isinstance(adapter.blocks, torch.nn.ModuleList)
         ):
             return True
+
+        logger.warning("Check block adapter failed!")
         return False
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+        allow_prefixes: List[str] = [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ],
+        allow_suffixes: List[str] = [
+            "TransformerBlock",
+        ],
+        check_prefixes: bool = True,
+        check_suffixes: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.nn.ModuleList, str]:
+        # Check prefixes
+        if check_prefixes:
+            blocks_names = []
+            for attr_name in dir(transformer):
+                for prefix in allow_prefixes:
+                    if attr_name.startswith(prefix):
+                        blocks_names.append(attr_name)
+        else:
+            blocks_names = dir(transformer)
+
+        # Check ModuleList
+        valid_names = []
+        valid_count = []
+        forward_pattern = kwargs.get("forward_pattern", None)
+        for blocks_name in blocks_names:
+            if blocks := getattr(transformer, blocks_name, None):
+                if isinstance(blocks, torch.nn.ModuleList):
+                    block = blocks[0]
+                    block_cls_name = block.__class__.__name__
+                    # Check suffixes
+                    if isinstance(block, torch.nn.Module) and (
+                        any(
+                            (
+                                block_cls_name.endswith(allow_suffix)
+                                for allow_suffix in allow_suffixes
+                            )
+                        )
+                        or (not check_suffixes)
+                    ):
+                        # May check forward pattern
+                        if forward_pattern is not None:
+                            if BlockAdapter.match_blocks_pattern(
+                                blocks,
+                                forward_pattern,
+                                logging=False,
+                            ):
+                                valid_names.append(blocks_name)
+                                valid_count.append(len(blocks))
+                        else:
+                            valid_names.append(blocks_name)
+                            valid_count.append(len(blocks))
+
+        if not valid_names:
+            raise ValueError(
+                "Auto selected transformer blocks failed, please set it manually."
+            )
+
+        final_name = valid_names[0]
+        final_count = valid_count[0]
+        block_policy = kwargs.get("blocks_policy", "max")
+
+        for blocks_name, count in zip(valid_names, valid_count):
+            blocks = getattr(transformer, blocks_name)
+            logger.info(
+                f"Auto selected transformer blocks: {blocks_name}, "
+                f"class: {blocks[0].__class__.__name__}, "
+                f"num blocks: {count}"
+            )
+            if block_policy == "max":
+                if final_count < count:
+                    final_count = count
+                    final_name = blocks_name
+            else:
+                if final_count > count:
+                    final_count = count
+                    final_name = blocks_name
+
+        final_blocks = getattr(transformer, final_name)
+
+        logger.info(
+            f"Final selected transformer blocks: {final_name}, "
+            f"class: {final_blocks[0].__class__.__name__}, "
+            f"num blocks: {final_count}, block_policy: {block_policy}."
+        )
+
+        return final_blocks, final_name
+
+    @staticmethod
+    def match_block_pattern(
+        block: torch.nn.Module,
+        forward_pattern: ForwardPattern,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        forward_parameters = set(
+            inspect.signature(block.forward).parameters.keys()
+        )
+        num_outputs = str(
+            inspect.signature(block.forward).return_annotation
+        ).count("torch.Tensor")
+
+        in_matched = True
+        out_matched = True
+        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+            # output pattern not match
+            out_matched = False
+
+        for required_param in forward_pattern.In:
+            if required_param not in forward_parameters:
+                in_matched = False
+
+        return in_matched and out_matched
+
+    @staticmethod
+    def match_blocks_pattern(
+        transformer_blocks: torch.nn.ModuleList,
+        forward_pattern: ForwardPattern,
+        logging: bool = True,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+        pattern_matched_states = []
+        for block in transformer_blocks:
+            pattern_matched_states.append(
+                BlockAdapter.match_block_pattern(
+                    block,
+                    forward_pattern,
+                )
+            )
+
+        pattern_matched = all(pattern_matched_states)  # all block match
+        if pattern_matched and logging:
+            block_cls_name = transformer_blocks[0].__class__.__name__
+            logger.info(
+                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
+                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
+            )
+
+        return pattern_matched
+
 
 @dataclasses.dataclass
 class UnifiedCacheParams:
-    adapter_params: BlockAdapterParams = None
+    block_adapter: BlockAdapter = None
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
 
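
The new auto flag together with auto_block_adapter and find_blocks lets the adapter discover the transformer blocks by scanning attribute names instead of requiring them up front. A hedged sketch of that path, using only the signatures shown above (pipe is assumed to be an already-loaded diffusers pipeline exposing pipe.transformer):

from cache_dit import BlockAdapter, ForwardPattern

# Defer block discovery to auto_block_adapter.
adapter = BlockAdapter(pipe=pipe, auto=True)

# Scans attributes whose names start with "transformer", "single_transformer",
# "blocks" or "layers", keeps ModuleLists whose blocks match the forward
# pattern, and picks the largest one under the default blocks_policy="max".
adapter = BlockAdapter.auto_block_adapter(
    adapter,
    forward_pattern=ForwardPattern.Pattern_0,
)
assert BlockAdapter.check_block_adapter(adapter)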
 
@@ -85,7 +308,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, FluxTransformer2DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=(
@@ -102,7 +325,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, MochiTransformer3DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -116,7 +339,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -136,7 +359,7 @@ class UnifiedCacheAdapter:
                 (WanTransformer3DModel, WanVACETransformer3DModel),
             )
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.blocks,
@@ -150,7 +373,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     blocks=(
                         pipe.transformer.transformer_blocks
@@ -166,7 +389,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -180,7 +403,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -194,7 +417,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, AllegroTransformer3DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -208,7 +431,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -222,7 +445,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, CogView4Transformer2DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -236,7 +459,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, CosmosTransformer3DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -250,7 +473,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -264,7 +487,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.blocks,
@@ -278,7 +501,7 @@ class UnifiedCacheAdapter:
 
             assert isinstance(pipe.transformer, SD3Transformer2DModel)
             return UnifiedCacheParams(
-                adapter_params=BlockAdapterParams(
+                block_adapter=BlockAdapter(
                     pipe=pipe,
                     transformer=pipe.transformer,
                     blocks=pipe.transformer.transformer_blocks,
@@ -294,13 +517,13 @@ class UnifiedCacheAdapter:
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
-        adapter_params: BlockAdapterParams = None,
+        block_adapter: BlockAdapter = None,
        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
        **cache_context_kwargs,
    ) -> DiffusionPipeline:
        assert (
-            pipe is not None or adapter_params is not None
-        ), "pipe or adapter_params can not both None!"
+            pipe is not None or block_adapter is not None
+        ), "pipe or block_adapter can not both None!"
 
        if pipe is not None:
            if cls.is_supported(pipe):
@@ -310,7 +533,7 @@ class UnifiedCacheAdapter:
                )
                params = cls.get_params(pipe)
                return cls.cachify(
-                    params.adapter_params,
+                    params.block_adapter,
                    forward_pattern=params.forward_pattern,
                    **cache_context_kwargs,
                )
@@ -324,7 +547,7 @@ class UnifiedCacheAdapter:
                "Adapting cache acceleration using custom BlockAdapter!"
            )
            return cls.cachify(
-                adapter_params,
+                block_adapter,
                forward_pattern=forward_pattern,
                **cache_context_kwargs,
            )
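
The hunks above rename the public keyword of UnifiedCacheAdapter.apply from adapter_params to block_adapter, so custom-adapter callers must update the call site. A hedged before/after sketch (pipe is again an already-loaded pipeline):

from cache_dit import BlockAdapter, ForwardPattern
from cache_dit.cache_factory import UnifiedCacheAdapter

# 0.2.22: UnifiedCacheAdapter.apply(adapter_params=BlockAdapterParams(...), ...)
# 0.2.24:
cached_pipe = UnifiedCacheAdapter.apply(
    block_adapter=BlockAdapter(pipe=pipe, auto=True),
    forward_pattern=ForwardPattern.Pattern_0,
)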
@@ -332,22 +555,48 @@ class UnifiedCacheAdapter:
     @classmethod
     def cachify(
         cls,
-        adapter_params: BlockAdapterParams,
+        block_adapter: BlockAdapter,
         *,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-        if adapter_params.check_adapter_params():
-            assert isinstance(adapter_params.blocks, torch.nn.ModuleList)
+
+        if block_adapter.auto:
+            block_adapter = BlockAdapter.auto_block_adapter(
+                block_adapter,
+                forward_pattern,
+            )
+
+        if BlockAdapter.check_block_adapter(block_adapter):
             # Apply cache on pipeline: wrap cache context
-            cls.create_context(adapter_params.pipe, **cache_context_kwargs)
+            cls.create_context(
+                block_adapter.pipe,
+                **cache_context_kwargs,
+            )
             # Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
-                adapter_params,
+                block_adapter,
                 forward_pattern=forward_pattern,
             )
+            cls.patch_params(
+                block_adapter,
+                forward_pattern=forward_pattern,
+                **cache_context_kwargs,
+            )
+        return block_adapter.pipe
 
-        return adapter_params.pipe
+    @classmethod
+    def patch_params(
+        cls,
+        block_adapter: BlockAdapter,
+        forward_pattern: ForwardPattern = None,
+        **cache_context_kwargs,
+    ):
+        block_adapter.transformer._forward_pattern = forward_pattern
+        block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
+        block_adapter.pipe.__class__._cache_context_kwargs = (
+            cache_context_kwargs
+        )
 
     @classmethod
     def has_separate_cfg(
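
The new patch_params step records the forward pattern and the cache kwargs on the transformer, and the kwargs on the pipeline class, replacing the pipe.__class__._cache_options attribute removed in the next hunk. Continuing the sketch above, the recorded values could be inspected roughly like this (attribute names taken from the hunk; the inspection itself is illustrative):

# After UnifiedCacheAdapter.apply(...) has returned cached_pipe:
print(cached_pipe.transformer._forward_pattern)        # ForwardPattern passed to cachify
print(cached_pipe.transformer._cache_context_kwargs)   # kwargs forwarded to the cache context
print(cached_pipe.__class__._cache_context_kwargs)     # same kwargs, stored on the pipeline class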
@@ -407,28 +656,21 @@ class UnifiedCacheAdapter:
 
         pipe.__class__.__call__ = new_call
         pipe.__class__._is_cached = True
-        pipe.__class__._cache_options = cache_kwargs
         return pipe
 
     @classmethod
     def mock_blocks(
         cls,
-        adapter_params: BlockAdapterParams,
+        block_adapter: BlockAdapter,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
-        if getattr(adapter_params.transformer, "_is_cached", False):
-            return adapter_params.transformer
-
-        # Firstly, process some specificial cases (TODO: more patches)
-        if adapter_params.transformer.__class__.__name__.startswith("Flux"):
-            adapter_params.transformer = maybe_patch_flux_transformer(
-                adapter_params.transformer,
-                blocks=adapter_params.blocks,
-            )
+
+        if getattr(block_adapter.transformer, "_is_cached", False):
+            return block_adapter.transformer
 
         # Check block forward pattern matching
-        assert cls.match_pattern(
-            adapter_params.blocks,
+        assert BlockAdapter.match_blocks_pattern(
+            block_adapter.blocks,
             forward_pattern=forward_pattern,
         ), (
             "No block forward pattern matched, "
@@ -439,22 +681,17 @@ class UnifiedCacheAdapter:
         cached_blocks = torch.nn.ModuleList(
             [
                 DBCachedTransformerBlocks(
-                    adapter_params.blocks,
-                    transformer=adapter_params.transformer,
+                    block_adapter.blocks,
+                    transformer=block_adapter.transformer,
                     forward_pattern=forward_pattern,
                 )
             ]
         )
         dummy_blocks = torch.nn.ModuleList()
 
-        original_forward = adapter_params.transformer.forward
+        original_forward = block_adapter.transformer.forward
 
-        assert isinstance(adapter_params.dummy_blocks_names, list)
-        if adapter_params.blocks_name is None:
-            adapter_params.blocks_name = cls.find_blocks_name(
-                adapter_params.transformer
-            )
-            assert adapter_params.blocks_name is not None
+        assert isinstance(block_adapter.dummy_blocks_names, list)
 
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
@@ -462,11 +699,11 @@ class UnifiedCacheAdapter:
                 stack.enter_context(
                     unittest.mock.patch.object(
                         self,
-                        adapter_params.blocks_name,
+                        block_adapter.blocks_name,
                         cached_blocks,
                     )
                 )
-                for dummy_name in adapter_params.dummy_blocks_names:
+                for dummy_name in block_adapter.dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
                             self,
@@ -476,71 +713,9 @@ class UnifiedCacheAdapter:
                     )
                 return original_forward(*args, **kwargs)
 
-        adapter_params.transformer.forward = new_forward.__get__(
-            adapter_params.transformer
+        block_adapter.transformer.forward = new_forward.__get__(
+            block_adapter.transformer
         )
-        adapter_params.transformer._is_cached = True
+        block_adapter.transformer._is_cached = True
 
-        return adapter_params.transformer
-
-    @classmethod
-    def match_pattern(
-        cls,
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-    ) -> bool:
-        pattern_matched_states = []
-
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        for block in transformer_blocks:
-            forward_parameters = set(
-                inspect.signature(block.forward).parameters.keys()
-            )
-            num_outputs = str(
-                inspect.signature(block.forward).return_annotation
-            ).count("torch.Tensor")
-
-            in_matched = True
-            out_matched = True
-            if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-                # output pattern not match
-                out_matched = False
-
-            for required_param in forward_pattern.In:
-                if required_param not in forward_parameters:
-                    in_matched = False
-
-            pattern_matched_states.append(in_matched and out_matched)
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
-
-    @classmethod
-    def find_blocks_name(cls, transformer):
-        blocks_name = None
-        allow_prefixes = ["transformer", "blocks"]
-        for attr_name in dir(transformer):
-            if blocks_name is None:
-                for prefix in allow_prefixes:
-                    # transformer_blocks, blocks
-                    if attr_name.startswith(prefix):
-                        blocks_name = attr_name
-                        logger.info(f"Auto selected blocks name: {blocks_name}")
-                        # only find one transformer blocks name
-                        break
-        if blocks_name is None:
-            logger.warning(
-                "Auto selected blocks name failed, please set it manually."
-            )
-        return blocks_name
+        return block_adapter.transformer
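
The UnifiedCacheAdapter.match_pattern and find_blocks_name helpers removed here are superseded by the static BlockAdapter.match_blocks_pattern and BlockAdapter.find_blocks introduced earlier in this diff. They can also be called on their own, e.g. to verify a block list before enabling caching (hedged sketch; the transformer_blocks attribute is illustrative):

from cache_dit import BlockAdapter, ForwardPattern

ok = BlockAdapter.match_blocks_pattern(
    pipe.transformer.transformer_blocks,
    forward_pattern=ForwardPattern.Pattern_0,
)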
@@ -1,5 +1,6 @@
 import inspect
 import torch
+import torch.distributed as dist
 
 from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
@@ -179,10 +180,15 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     @torch.compiler.disable
     def _is_parallelized(self):
         # Compatible with distributed inference.
-        return all(
+        return any(
             (
-                self.transformer is not None,
-                getattr(self.transformer, "_is_parallelized", False),
+                all(
+                    (
+                        self.transformer is not None,
+                        getattr(self.transformer, "_is_parallelized", False),
+                    )
+                ),
+                (dist.is_initialized() and dist.get_world_size() > 1),
             )
         )
 
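
The parallelism check is relaxed from all(...) to any(...): a run now counts as parallelized either when the transformer carries the _is_parallelized flag or when an initialized torch.distributed process group has a world size greater than one. A standalone hedged sketch of the new condition:

import torch.distributed as dist

def is_parallelized(transformer) -> bool:
    # Flag set by parallelism backends on the transformer module.
    flagged = transformer is not None and getattr(transformer, "_is_parallelized", False)
    # Condition added in 0.2.24: an active multi-process distributed group also counts.
    return flagged or (dist.is_initialized() and dist.get_world_size() > 1)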