cache-dit 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (45)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/block_adapters/__init__.py +37 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +75 -4
  5. cache_dit/cache_factory/block_adapters/block_registers.py +44 -17
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +72 -30
  7. cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
  8. cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
  9. cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
  10. cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
  11. cache_dit/cache_factory/cache_interface.py +102 -28
  12. cache_dit/cache_factory/forward_pattern.py +14 -14
  13. cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
  14. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
  15. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
  16. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
  17. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -49
  18. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  19. cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
  20. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  21. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  22. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
  23. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  24. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
  25. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  26. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
  27. cache_dit/parallelism/parallel_backend.py +2 -0
  28. cache_dit/parallelism/parallel_config.py +10 -3
  29. cache_dit/parallelism/parallel_interface.py +14 -5
  30. cache_dit/quantize/backends/__init__.py +1 -0
  31. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  32. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  33. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
  34. cache_dit/quantize/quantize_backend.py +0 -0
  35. cache_dit/quantize/quantize_config.py +0 -0
  36. cache_dit/quantize/quantize_interface.py +3 -16
  37. cache_dit/utils.py +56 -20
  38. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/METADATA +24 -13
  39. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
  40. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  41. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  42. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
  43. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
  44. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
  45. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '1.0.8'
-__version_tuple__ = version_tuple = (1, 0, 8)
+__version__ = version = '1.0.10'
+__version_tuple__ = version_tuple = (1, 0, 10)

 __commit_id__ = commit_id = None
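
A quick way to confirm which wheel is active, using only the names defined in _version.py above (a minimal sketch):

import cache_dit._version as v

print(v.__version__)                   # '1.0.10'
print(v.version_tuple >= (1, 0, 10))   # True for this release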
cache_dit/cache_factory/__init__.py CHANGED
@@ -8,6 +8,7 @@ from cache_dit.cache_factory.patch_functors import PatchFunctor

 from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
+from cache_dit.cache_factory.block_adapters import FakeDiffusionPipeline

 from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
 from cache_dit.cache_factory.cache_contexts import DBCacheConfig
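
FakeDiffusionPipeline is now re-exported at the cache_factory level, so a bare transformer can be wrapped without importing from the block_adapters internals. A minimal sketch, assuming a diffusers FluxTransformer2DModel checkpoint is available locally (the model id is illustrative, not part of this diff):

from diffusers import FluxTransformer2DModel
from cache_dit.cache_factory import FakeDiffusionPipeline

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
# The wrapper only holds a reference; it is not a runnable pipeline.
fake_pipe = FakeDiffusionPipeline(transformer)
assert fake_pipe.transformer is transformer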
cache_dit/cache_factory/block_adapters/__init__.py CHANGED
@@ -1,5 +1,8 @@
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters.block_adapters import (
+    FakeDiffusionPipeline,
+)
 from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters.block_registers import (
     BlockAdapterRegistry,
@@ -27,6 +30,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
 ForwardPattern.Pattern_1,
 ForwardPattern.Pattern_1,
 ],
+check_forward_pattern=True,
 **kwargs,
 )
 else:
@@ -41,6 +45,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
 ForwardPattern.Pattern_1,
 ForwardPattern.Pattern_3,
 ],
+check_forward_pattern=True,
 **kwargs,
 )
@@ -55,6 +60,7 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_0,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -69,6 +75,7 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_0,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -104,6 +111,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
 ForwardPattern.Pattern_2,
 ForwardPattern.Pattern_2,
 ],
+check_forward_pattern=True,
 has_separate_cfg=True,
 **kwargs,
 )
@@ -114,6 +122,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.blocks,
 forward_pattern=ForwardPattern.Pattern_2,
+check_forward_pattern=True,
 has_separate_cfg=True,
 **kwargs,
 )
@@ -135,6 +144,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
 ForwardPattern.Pattern_0,
 ForwardPattern.Pattern_0,
 ],
+check_forward_pattern=True,
 # The type hint in diffusers is wrong
 check_num_outputs=False,
 **kwargs,
@@ -159,6 +169,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_1,
 patch_functor=QwenImageControlNetPatchFunctor(),
+check_forward_pattern=True,
 has_separate_cfg=True,
 )
 else:
@@ -167,6 +178,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_1,
+check_forward_pattern=True,
 has_separate_cfg=True,
 **kwargs,
 )
@@ -182,6 +194,7 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_2,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -196,6 +209,7 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_2,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -210,6 +224,7 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_0,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -224,6 +239,7 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_0,
+check_forward_pattern=True,
 has_separate_cfg=True,
 **kwargs,
 )
@@ -239,6 +255,7 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_2,
+check_forward_pattern=True,
 has_separate_cfg=True,
 **kwargs,
 )
@@ -254,6 +271,7 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_0,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -271,6 +289,7 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
 # encoder_hidden_states will never change in the blocks
 # forward loop.
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 has_separate_cfg=True,
 **kwargs,
 )
@@ -286,6 +305,7 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_1,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -300,6 +320,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_0,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -316,6 +337,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_3,
 patch_functor=DiTPatchFunctor(),
+check_forward_pattern=True,
 **kwargs,
 )
@@ -330,6 +352,7 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_layers,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -350,6 +373,7 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
 ForwardPattern.Pattern_0,
 ForwardPattern.Pattern_0,
 ],
+check_forward_pattern=True,
 **kwargs,
 )
@@ -367,6 +391,7 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.layers,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -381,6 +406,7 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.layers,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -395,6 +421,7 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -409,6 +436,7 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -423,6 +451,7 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -445,6 +474,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
 ForwardPattern.Pattern_1,
 ForwardPattern.Pattern_1,
 ],
+check_forward_pattern=True,
 **kwargs,
 )
 else:
@@ -459,6 +489,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
 ForwardPattern.Pattern_1,
 ForwardPattern.Pattern_3,
 ],
+check_forward_pattern=True,
 **kwargs,
 )
@@ -473,6 +504,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.single_transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -495,6 +527,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
 ForwardPattern.Pattern_3,
 ],
 patch_functor=ChromaPatchFunctor(),
+check_forward_pattern=True,
 has_separate_cfg=True,
 **kwargs,
 )
@@ -510,6 +543,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.prior,
 blocks=pipe.prior.transformer_blocks,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 **kwargs,
 )
@@ -559,6 +593,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
 blocks=pipe.transformer.blocks,
 forward_pattern=ForwardPattern.Pattern_3,
 patch_functor=HunyuanDiTPatchFunctor(),
+check_forward_pattern=True,
 **kwargs,
 )
@@ -575,6 +610,7 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
 blocks=pipe.transformer.blocks,
 forward_pattern=ForwardPattern.Pattern_3,
 patch_functor=HunyuanDiTPatchFunctor(),
+check_forward_pattern=True,
 **kwargs,
 )
@@ -613,6 +649,7 @@ def prx_adapter(pipe, **kwargs) -> BlockAdapter:
 transformer=pipe.transformer,
 blocks=pipe.transformer.blocks,
 forward_pattern=ForwardPattern.Pattern_3,
+check_forward_pattern=True,
 check_num_outputs=False,
 **kwargs,
 )
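
Every pre-registered adapter above now passes check_forward_pattern=True explicitly rather than relying on the old field default. A manual BlockAdapter mirroring, for example, the cogvideox_adapter hunk would look roughly like this (the pipeline loading call and model id are illustrative, not from this diff):

from diffusers import CogVideoXPipeline
from cache_dit.cache_factory.block_adapters import BlockAdapter
from cache_dit.cache_factory.forward_pattern import ForwardPattern

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b")
adapter = BlockAdapter(
    pipe=pipe,
    transformer=pipe.transformer,
    blocks=pipe.transformer.transformer_blocks,
    forward_pattern=ForwardPattern.Pattern_0,
    check_forward_pattern=True,  # now stated explicitly by the registered adapters
)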
cache_dit/cache_factory/block_adapters/block_adapters.py CHANGED
@@ -6,7 +6,7 @@ from collections.abc import Iterable

 from typing import Any, Tuple, List, Optional, Union

-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, ModelMixin
 from cache_dit.cache_factory.patch_functors import PatchFunctor
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.params_modifier import ParamsModifier
@@ -16,12 +16,22 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)


+class FakeDiffusionPipeline:
+    # A placeholder for pipelines when pipe is None.
+    def __init__(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+    ):
+        self.transformer = transformer  # Reference only
+
+
 @dataclasses.dataclass
 class BlockAdapter:

     # Transformer configurations.
     pipe: Union[
         DiffusionPipeline,
+        FakeDiffusionPipeline,
         Any,
     ] = None
@@ -73,7 +83,7 @@ class BlockAdapter:
         ]
     ] = None

-    check_forward_pattern: bool = True
+    check_forward_pattern: Optional[bool] = None
     check_num_outputs: bool = False

     # Pipeline Level Flags
@@ -110,12 +120,43 @@
     def __post_init__(self):
         if self.skip_post_init:
             return
+
+        self.maybe_fake_pipe()
         if any((self.pipe is not None, self.transformer is not None)):
             self.maybe_fill_attrs()
             self.maybe_patchify()
             self.maybe_skip_checks()

+    def maybe_fake_pipe(self):
+        if self.pipe is None:
+            self.pipe = FakeDiffusionPipeline()
+            logger.warning("pipe is None, use FakeDiffusionPipeline instead.")
+
     def maybe_skip_checks(self):
+        if self.check_forward_pattern is None:
+            if self.transformer is not None:
+                if self.nested_depth(self.transformer) == 0:
+                    transformer = self.transformer
+                elif self.nested_depth(self.transformer) == 1:
+                    transformer = self.transformer[0]
+                else:
+                    raise ValueError(
+                        "transformer nested depth can't more than 1, "
+                        f"current is: {self.nested_depth(self.transformer)}"
+                    )
+                if transformer.__module__.startswith("diffusers"):
+                    self.check_forward_pattern = True
+                    logger.info(
+                        f"Found transformer from diffusers: {transformer.__module__} "
+                        "enable check_forward_pattern by default."
+                    )
+                else:
+                    self.check_forward_pattern = False
+                    logger.info(
+                        f"Found transformer NOT from diffusers: {transformer.__module__} "
+                        "disable check_forward_pattern by default."
+                    )
+
         if getattr(self.transformer, "_hf_hook", None) is not None:
             logger.warning("_hf_hook is not None, force skip pattern check!")
             self.check_forward_pattern = False
@@ -208,7 +249,10 @@ class BlockAdapter:
         if self.transformer is not None:
             self.patch_functor.apply(self.transformer, *args, **kwargs)
         else:
-            assert hasattr(self.pipe, "transformer")
+            assert hasattr(self.pipe, "transformer"), (
+                "pipe.transformer can not be None when patch_functor "
+                "is provided and transformer is None."
+            )
             self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)

     @staticmethod
@@ -224,6 +268,10 @@ class BlockAdapter:
             adapter.forward_pattern is not None
         ), "adapter.forward_pattern can not be None."
         pipe = adapter.pipe
+        if isinstance(pipe, FakeDiffusionPipeline):
+            raise ValueError(
+                "Can not auto block adapter for FakeDiffusionPipeline."
+            )

         assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
@@ -489,6 +537,7 @@ class BlockAdapter:
     @staticmethod
     def normalize(
         adapter: "BlockAdapter",
+        unique: bool = True,
     ) -> "BlockAdapter":

         if getattr(adapter, "_is_normalized", False):
@@ -523,7 +572,10 @@ class BlockAdapter:
         adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
         adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
         adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
-        BlockAdapter.unique(adapter)
+        # Some times, the cache_config will be None.
+        # So we do not perform unique check here.
+        if unique:
+            BlockAdapter.unique(adapter)

         adapter._is_normalized = True

@@ -571,6 +623,10 @@ class BlockAdapter:
         if not getattr(adapter, "_is_normalized", False):
             raise RuntimeError("block_adapter must be normailzed.")

+    @classmethod
+    def is_normalized(cls, adapter: "BlockAdapter") -> bool:
+        return getattr(adapter, "_is_normalized", False)
+
     @classmethod
     def is_cached(cls, adapter: Any) -> bool:
         if isinstance(adapter, cls):
@@ -592,6 +648,21 @@ class BlockAdapter:
         else:
             return getattr(adapter, "_is_cached", False)

+    @classmethod
+    def is_parallelized(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return getattr(adapter.transformer[0], "_is_parallelized", False)
+        elif isinstance(adapter, DiffusionPipeline):
+            return getattr(adapter.transformer, "_is_parallelized", False)
+        elif isinstance(adapter, torch.nn.Module):
+            return getattr(adapter, "_is_parallelized", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_parallelized", False)
+        else:
+            return getattr(adapter, "_is_parallelized", False)
+
     @classmethod
     def nested_depth(cls, obj: Any):
         # str: 0; List[str]: 1; List[List[str]]: 2
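
Taken together, these changes let a BlockAdapter be built from a transformer alone: a missing pipe is replaced by FakeDiffusionPipeline in __post_init__, and check_forward_pattern, when left at its new None default, is enabled only for modules whose __module__ starts with "diffusers". A rough sketch with a hypothetical non-diffusers DiT (MyDiT and its blocks attribute are illustrative; the exact set of required fields may differ):

import torch

from cache_dit.cache_factory.block_adapters import BlockAdapter
from cache_dit.cache_factory.forward_pattern import ForwardPattern


class MyDiT(torch.nn.Module):
    # Stand-in transformer; __module__ is "__main__", not "diffusers".
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Identity() for _ in range(4)])


dit = MyDiT()
adapter = BlockAdapter(
    pipe=None,  # replaced by FakeDiffusionPipeline in maybe_fake_pipe()
    transformer=dit,
    blocks=dit.blocks,
    forward_pattern=ForwardPattern.Pattern_3,
    # check_forward_pattern stays None and resolves to False for non-diffusers modules
)

The new is_parallelized classmethod mirrors is_cached: it reads the _is_parallelized flag from a normalized adapter's transformer, a DiffusionPipeline's transformer, a bare module, or the first entry of a module list.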
cache_dit/cache_factory/block_adapters/block_registers.py CHANGED
@@ -1,7 +1,11 @@
-from typing import Any, Tuple, List, Dict, Callable
+import torch
+from typing import Any, Tuple, List, Dict, Callable, Union

 from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters.block_adapters import (
+    BlockAdapter,
+    FakeDiffusionPipeline,
+)

 from cache_dit.logger import init_logger

@@ -35,24 +39,42 @@ class BlockAdapterRegistry:
     @classmethod
     def get_adapter(
         cls,
-        pipe: DiffusionPipeline | str | Any,
+        pipe_or_module: DiffusionPipeline | torch.nn.Module | str | Any,
         **kwargs,
-    ) -> BlockAdapter:
-        if not isinstance(pipe, str):
-            pipe_cls_name: str = pipe.__class__.__name__
+    ) -> BlockAdapter | None:
+        if not isinstance(pipe_or_module, str):
+            cls_name: str = pipe_or_module.__class__.__name__
         else:
-            pipe_cls_name = pipe
+            cls_name = pipe_or_module

         for name in cls._adapters:
-            if pipe_cls_name.startswith(name):
-                return cls._adapters[name](pipe, **kwargs)
-
-        return BlockAdapter()
+            if cls_name.startswith(name):
+                if not isinstance(pipe_or_module, DiffusionPipeline):
+                    assert isinstance(pipe_or_module, torch.nn.Module)
+                    # NOTE: Make pre-registered adapters support Transformer-only case.
+                    # WARN: This branch is not officially supported and only for testing
+                    # purpose. We construct a fake diffusion pipeline that contains the
+                    # given transformer module. Currently, only works for DiT models which
+                    # only have one transformer module. Case like multiple transformers
+                    # is not supported, e.g, Wan2.2. Please use BlockAdapter directly for
+                    # such cases.
+                    return cls._adapters[name](
+                        FakeDiffusionPipeline(pipe_or_module), **kwargs
+                    )
+                else:
+                    return cls._adapters[name](pipe_or_module, **kwargs)
+
+        return None

     @classmethod
     def has_separate_cfg(
         cls,
-        pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            FakeDiffusionPipeline,
+            BlockAdapter,
+            Any,
+        ],
     ) -> bool:

         # Prefer custom setting from block adapter.
@@ -60,11 +82,16 @@ class BlockAdapterRegistry:
             return pipe_or_adapter.has_separate_cfg

         has_separate_cfg = False
+        if isinstance(pipe_or_adapter, FakeDiffusionPipeline):
+            return False
+
         if isinstance(pipe_or_adapter, DiffusionPipeline):
-            has_separate_cfg = cls.get_adapter(
+            adapter = cls.get_adapter(
                 pipe_or_adapter,
                 skip_post_init=True,  # check cfg setting only
-            ).has_separate_cfg
+            )
+            if adapter is not None:
+                has_separate_cfg = adapter.has_separate_cfg

         if has_separate_cfg:
             return True
@@ -77,11 +104,11 @@ class BlockAdapterRegistry:
         return False

     @classmethod
-    def is_supported(cls, pipe) -> bool:
-        pipe_cls_name: str = pipe.__class__.__name__
+    def is_supported(cls, pipe_or_module) -> bool:
+        cls_name: str = pipe_or_module.__class__.__name__

         for name in cls._adapters:
-            if pipe_cls_name.startswith(name):
+            if cls_name.startswith(name):
                 return True
         return False
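
BlockAdapterRegistry.get_adapter now also accepts a bare transformer module, wrapping it in a FakeDiffusionPipeline before calling the registered adapter factory, and it returns None (instead of an empty BlockAdapter) when no registered class-name prefix matches. A usage sketch, assuming flux_transformer is a diffusers FluxTransformer2DModel instance (the variable is illustrative):

from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry

adapter = BlockAdapterRegistry.get_adapter(flux_transformer)
if adapter is None:
    # Class-name prefix not registered; fall back to a hand-built BlockAdapter.
    print("no registered adapter for", flux_transformer.__class__.__name__)

As the inline WARN comment notes, this transformer-only branch is meant for testing single-transformer DiT models; pipelines with multiple transformers, such as Wan2.2, still need an explicit BlockAdapter.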