cache-dit 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -2,6 +2,7 @@ import torch
 
 import inspect
 import dataclasses
+from collections.abc import Iterable
 
 from typing import Any, Tuple, List, Optional, Union
 
@@ -74,7 +75,7 @@ class BlockAdapter:
         List[List[ParamsModifier]],
     ] = None
 
-    check_num_outputs: bool = True
+    check_num_outputs: bool = False
 
     # Pipeline Level Flags
     # Patch Functor: Flux, etc.
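Note that check_num_outputs now defaults to False, so output-count validation is opt-in from 0.2.30 onward. A minimal sketch of restoring the old strictness; `pipe`, `pattern`, and the blocks attribute are illustrative assumptions, not taken from the diff:

    from diffusers import DiffusionPipeline
    from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

    def strict_adapter(pipe: DiffusionPipeline, pattern) -> BlockAdapter:
        # `pipe` is assumed to be an already-loaded pipeline with a
        # `.transformer` module; `pattern` some ForwardPattern value.
        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=pipe.transformer.transformer_blocks,
            forward_pattern=pattern,
            check_num_outputs=True,  # opt back into the pre-0.2.30 check
        )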
@@ -82,9 +83,6 @@ class BlockAdapter:
     # Flags for separate cfg
     has_separate_cfg: bool = False
 
-    # Other Flags
-    disable_patch: bool = False
-
     # Flags to control auto block adapter
     # NOTE: NOT supported for multi-transformers.
     auto: bool = False
@@ -107,15 +105,94 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
+    # Other Flags
+    skip_post_init: bool = False
+
     def __post_init__(self):
-        assert any((self.pipe is not None, self.transformer is not None))
-        self.patchify()
+        if self.skip_post_init:
+            return
+        if any((self.pipe is not None, self.transformer is not None)):
+            self.maybe_fill_attrs()
+            self.maybe_patchify()
+
+    def maybe_fill_attrs(self):
+        # NOTE: This func should be called before normalize().
+        # Allow empty `blocks_name`; we will auto-fill it.
+        # TODO: preprocess more empty attrs.
+        if (
+            self.transformer is not None
+            and self.blocks is not None
+            and self.blocks_name is None
+        ):
 
-    def patchify(self, *args, **kwargs):
+            def _find(transformer, blocks):
+                attr_names = dir(transformer)
+                assert isinstance(blocks, torch.nn.ModuleList)
+                blocks_name = None
+                for attr_name in attr_names:
+                    if (
+                        attr := getattr(transformer, attr_name, None)
+                    ) is not None:
+                        if isinstance(attr, torch.nn.ModuleList) and id(
+                            attr
+                        ) == id(blocks):
+                            blocks_name = attr_name
+                            break
+                assert (
+                    blocks_name is not None
+                ), "No blocks_name match, please set it manually!"
+                return blocks_name
+
+            if self.nested_depth(self.transformer) == 0:
+                if self.nested_depth(self.blocks) == 0:  # single ModuleList
+                    self.blocks_name = _find(self.transformer, self.blocks)
+                elif self.nested_depth(self.blocks) == 1:  # List[ModuleList]
+                    self.blocks_name = [
+                        _find(self.transformer, blocks)
+                        for blocks in self.blocks
+                    ]
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can't be more than 1 if transformer "
+                        f"is not a list, current is: {self.nested_depth(self.blocks)}"
+                    )
+            elif self.nested_depth(self.transformer) == 1:  # list of transformers
+                if self.nested_depth(self.blocks) == 1:  # List[ModuleList]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = [
+                        _find(transformer, blocks)
+                        for transformer, blocks in zip(
+                            self.transformer, self.blocks
+                        )
+                    ]
+                elif self.nested_depth(self.blocks) == 2:  # List[List[ModuleList]]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = []
+                    for i in range(len(self.blocks)):
+                        self.blocks_name.append(
+                            [
+                                _find(self.transformer[i], blocks)
+                                for blocks in self.blocks[i]
+                            ]
+                        )
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can only be 1 or 2 "
+                        "if transformer is a list, current is: "
+                        f"{self.nested_depth(self.blocks)}"
+                    )
+            else:
+                raise ValueError(
+                    "transformer nested depth can't be more than 1, "
+                    f"current is: {self.nested_depth(self.transformer)}"
+                )
+            logger.info(f"Auto fill blocks_name: {self.blocks_name}.")
+
+    def maybe_patchify(self, *args, **kwargs):
         # Process some special cases, specific to transformers
         # that have different forward patterns between single_transformer_blocks
         # and transformer_blocks, such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None and not self.disable_patch:
+        if self.patch_functor is not None:
             if self.transformer is not None:
                 self.patch_functor.apply(self.transformer, *args, **kwargs)
             else:
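The new skip_post_init flag and maybe_fill_attrs() change how a BlockAdapter can be constructed: blocks_name may now be omitted when transformer and blocks are set, and __post_init__ resolves it by scanning dir(transformer) for the ModuleList whose id() matches. A hedged sketch; `pipe` and `pattern` are assumed inputs as above:

    from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

    def auto_named_adapter(pipe, pattern) -> BlockAdapter:
        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=pipe.transformer.transformer_blocks,
            forward_pattern=pattern,
            # blocks_name omitted on purpose: maybe_fill_attrs() resolves
            # it to "transformer_blocks" by the id() identity match.
        )

    def meta_only_adapter(pipe) -> BlockAdapter:
        # skip_post_init=True bypasses auto-fill and patching entirely;
        # the registry below uses this for cheap, metadata-only adapters.
        return BlockAdapter(pipe=pipe, skip_post_init=True)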
@@ -141,7 +218,7 @@ class BlockAdapter:
         transformer = pipe.transformer
 
         # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.find_blocks(
+        blocks, blocks_name = BlockAdapter.find_match_blocks(
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
@@ -164,6 +241,10 @@ class BlockAdapter:
     def check_block_adapter(
         adapter: "BlockAdapter",
     ) -> bool:
+
+        if getattr(adapter, "_is_normalized", False):
+            return True
+
         def _check_warning(attr: str):
             if getattr(adapter, attr, None) is None:
                 logger.warning(f"{attr} is None!")
@@ -185,24 +266,23 @@ class BlockAdapter:
         if not _check_warning("forward_pattern"):
             return False
 
-        if isinstance(adapter.blocks, list):
-            for i, blocks in enumerate(adapter.blocks):
-                if not isinstance(blocks, torch.nn.ModuleList):
-                    logger.warning(f"blocks[{i}] is not ModuleList.")
-                    return False
+        if BlockAdapter.nested_depth(adapter.blocks) == 0:
+            blocks = adapter.blocks
         else:
-            if not isinstance(adapter.blocks, torch.nn.ModuleList):
-                logger.warning("blocks is not ModuleList.")
-                return False
+            blocks = BlockAdapter.flatten(adapter.blocks)[0]
+
+        if not isinstance(blocks, torch.nn.ModuleList):
+            logger.warning("blocks is not ModuleList.")
+            return False
 
         return True
 
     @staticmethod
-    def find_blocks(
+    def find_match_blocks(
         transformer: torch.nn.Module,
         allow_prefixes: List[str] = [
-            "transformer",
-            "single_transformer",
+            "transformer_blocks",
+            "single_transformer_blocks",
             "blocks",
             "layers",
             "single_stream_blocks",
@@ -230,10 +310,10 @@ class BlockAdapter:
         valid_count = []
         forward_pattern = kwargs.pop("forward_pattern", None)
         for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
+            if (blocks := getattr(transformer, blocks_name, None)) is not None:
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
-                    block_cls_name = block.__class__.__name__
+                    block_cls_name: str = block.__class__.__name__
                     # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
@@ -293,6 +373,18 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+    ) -> List[torch.nn.ModuleList]:
+        total_blocks = []
+        for attr in dir(transformer):
+            if (blocks := getattr(transformer, attr, None)) is not None:
+                if isinstance(blocks, torch.nn.ModuleList):
+                    if isinstance(blocks[0], torch.nn.Module):
+                        total_blocks.append(blocks)
+        return total_blocks
+
     @staticmethod
     def match_block_pattern(
         block: torch.nn.Module,
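With the old name freed by the find_match_blocks rename, find_blocks becomes a plain enumerator: it returns every non-empty torch.nn.ModuleList attribute of a module. A self-contained sketch with a toy module; `Tiny` is hypothetical:

    import torch
    from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4)])
            self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    found = BlockAdapter.find_blocks(Tiny())
    assert len(found) == 2  # both ModuleLists are picked up via dir()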
@@ -373,103 +465,51 @@ class BlockAdapter:
         if getattr(adapter, "_is_normalized", False):
             return adapter
 
-        if not isinstance(adapter.transformer, list):
+        if BlockAdapter.nested_depth(adapter.transformer) == 0:
             adapter.transformer = [adapter.transformer]
 
-        if isinstance(adapter.blocks, torch.nn.ModuleList):
-            # blocks_0 = [[blocks_0,],] -> match [TRN_0,]
-            adapter.blocks = [[adapter.blocks]]
-        elif isinstance(adapter.blocks, list):
-            if isinstance(adapter.blocks[0], torch.nn.ModuleList):
-                # [blocks_0, blocks_1] -> [[blocks_0, blocks_1],] -> match [TRN_0,]
-                if len(adapter.blocks) == len(adapter.transformer):
-                    adapter.blocks = [[blocks] for blocks in adapter.blocks]
-                else:
-                    adapter.blocks = [adapter.blocks]
-            elif isinstance(adapter.blocks[0], list):
-                # [[blocks_0, blocks_1],[blocks_2, blocks_3],] -> match [TRN_0, TRN_1,]
-                pass
-
-        if isinstance(adapter.blocks_name, str):
-            adapter.blocks_name = [[adapter.blocks_name]]
-        elif isinstance(adapter.blocks_name, list):
-            if isinstance(adapter.blocks_name[0], str):
-                if len(adapter.blocks_name) == len(adapter.transformer):
-                    adapter.blocks_name = [
-                        [blocks_name] for blocks_name in adapter.blocks_name
-                    ]
-                else:
-                    adapter.blocks_name = [adapter.blocks_name]
-            elif isinstance(adapter.blocks_name[0], list):
-                pass
-
-        if isinstance(adapter.forward_pattern, ForwardPattern):
-            adapter.forward_pattern = [[adapter.forward_pattern]]
-        elif isinstance(adapter.forward_pattern, list):
-            if isinstance(adapter.forward_pattern[0], ForwardPattern):
-                if len(adapter.forward_pattern) == len(adapter.transformer):
-                    adapter.forward_pattern = [
-                        [forward_pattern]
-                        for forward_pattern in adapter.forward_pattern
-                    ]
-                else:
-                    adapter.forward_pattern = [adapter.forward_pattern]
-            elif isinstance(adapter.forward_pattern[0], list):
-                pass
-
-        if isinstance(adapter.dummy_blocks_names, list):
-            if len(adapter.dummy_blocks_names) > 0:
-                if isinstance(adapter.dummy_blocks_names[0], str):
-                    if len(adapter.dummy_blocks_names) == len(
-                        adapter.transformer
-                    ):
-                        adapter.dummy_blocks_names = [
-                            [dummy_blocks_names]
-                            for dummy_blocks_names in adapter.dummy_blocks_names
-                        ]
-                    else:
-                        adapter.dummy_blocks_names = [
-                            adapter.dummy_blocks_names
-                        ]
-                elif isinstance(adapter.dummy_blocks_names[0], list):
-                    pass
-            else:
-                # Empty dummy_blocks_names
-                adapter.dummy_blocks_names = [
-                    [] for _ in range(len(adapter.transformer))
-                ]
-
-        if adapter.params_modifiers is not None:
-            if isinstance(adapter.params_modifiers, ParamsModifier):
-                adapter.params_modifiers = [[adapter.params_modifiers]]
-            elif isinstance(adapter.params_modifiers, list):
-                if isinstance(adapter.params_modifiers[0], ParamsModifier):
-                    if len(adapter.params_modifiers) == len(
-                        adapter.transformer
-                    ):
-                        adapter.params_modifiers = [
-                            [params_modifiers]
-                            for params_modifiers in adapter.params_modifiers
-                        ]
+        def _normalize_attr(attr: Any):
+            normalized_attr = attr
+            if attr is None:
+                return normalized_attr
+
+            if BlockAdapter.nested_depth(attr) == 0:
+                normalized_attr = [[attr]]
+            elif BlockAdapter.nested_depth(attr) == 1:  # List
+                if attr:  # not-empty
+                    if len(attr) == len(adapter.transformer):
+                        normalized_attr = [[a] for a in attr]
                     else:
-                        adapter.params_modifiers = [adapter.params_modifiers]
-                elif isinstance(adapter.params_modifiers[0], list):
-                    pass
+                        normalized_attr = [attr]
+                else:  # [] empty
+                    normalized_attr = [
+                        [] for _ in range(len(adapter.transformer))
+                    ]
+
+            assert len(adapter.transformer) == len(normalized_attr)
+            return normalized_attr
+
+        adapter.blocks = _normalize_attr(adapter.blocks)
+        adapter.blocks_name = _normalize_attr(adapter.blocks_name)
+        adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
+        adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
+        adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
+        BlockAdapter.unique(adapter)
 
-        assert len(adapter.transformer) == len(adapter.blocks)
-        assert len(adapter.transformer) == len(adapter.blocks_name)
-        assert len(adapter.transformer) == len(adapter.forward_pattern)
-        assert len(adapter.transformer) == len(adapter.dummy_blocks_names)
-        if adapter.params_modifiers is not None:
-            assert len(adapter.transformer) == len(adapter.params_modifiers)
+        adapter._is_normalized = True
+
+        return adapter
 
+    @classmethod
+    def unique(cls, adapter: "BlockAdapter"):
+        # NOTE: Users should never call this function
        for i in range(len(adapter.blocks)):
             assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
             assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
 
+        # Generate unique blocks names
         if len(adapter.unique_blocks_name) == 0:
             for i in range(len(adapter.transformer)):
-                # Generate unique blocks names
                 adapter.unique_blocks_name.append(
                     [
                         f"{name}_{hash(id(blocks))}"
@@ -479,10 +519,10 @@ class BlockAdapter:
                        )
                    ]
                )
+        else:
+            assert len(adapter.transformer) == len(adapter.unique_blocks_name)
 
-        assert len(adapter.transformer) == len(adapter.unique_blocks_name)
-
-        # Match Forward Pattern
+        # Also check the matched forward pattern
         for i in range(len(adapter.transformer)):
             for forward_pattern, blocks in zip(
                 adapter.forward_pattern[i], adapter.blocks[i]
@@ -496,10 +536,6 @@ class BlockAdapter:
                     f"supported lists: {ForwardPattern.supported_patterns()}"
                 )
 
-        adapter._is_normalized = True
-
-        return adapter
-
     @classmethod
     def assert_normalized(cls, adapter: "BlockAdapter"):
         if not getattr(adapter, "_is_normalized", False):
@@ -524,15 +560,49 @@ class BlockAdapter:
             assert isinstance(adapter[0], torch.nn.Module)
             return getattr(adapter[0], "_is_cached", False)
         else:
-            raise TypeError(f"Can't check this type: {adapter}!")
+            raise TypeError(f"Can't check this type: {type(adapter)}!")
 
     @classmethod
-    def flatten(cls, attr: List[List[Any]]):
-        if isinstance(attr, list):
-            if not isinstance(attr[0], list):
-                return attr
-            flatten_attr = []
-            for i in range(len(attr)):
-                flatten_attr.extend(attr[i])
-            return flatten_attr
-        return attr
+    def nested_depth(cls, obj: Any):
+        # str: 0; List[str]: 1; List[List[str]]: 2
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if isinstance(obj, atom_types):
+            return 0
+        if not isinstance(obj, Iterable):
+            return 0
+        if isinstance(obj, dict):
+            items = obj.values()
+        else:
+            items = obj
+
+        max_depth = 0
+        for item in items:
+            current_depth = cls.nested_depth(item)
+            if current_depth > max_depth:
+                max_depth = current_depth
+        return 1 + max_depth
+
+    @classmethod
+    def flatten(cls, attr: List[Any]) -> List[Any]:
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if not isinstance(attr, list):
+            return attr
+        flattened = []
+        for item in attr:
+            if isinstance(item, list) and not isinstance(item, atom_types):
+                flattened.extend(cls.flatten(item))
+            else:
+                flattened.append(item)
+        return flattened
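nested_depth() and the recursive flatten() replace the old one-level flatten; both treat str, bytes, Module, ModuleList, and Tensor as atoms, so a ModuleList counts as depth 0 rather than as an iterable of blocks. The expected values below follow directly from the code above:

    from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

    assert BlockAdapter.nested_depth("layers") == 0             # atom -> 0
    assert BlockAdapter.nested_depth(["a", "b"]) == 1           # List[str] -> 1
    assert BlockAdapter.nested_depth([["a"], ["b", "c"]]) == 2  # List[List[str]] -> 2
    assert BlockAdapter.flatten([["a"], ["b", "c"]]) == ["a", "b", "c"]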
@@ -1,4 +1,4 @@
-from typing import Any, Tuple, List, Dict
+from typing import Any, Tuple, List, Dict, Callable
 
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
@@ -9,20 +9,23 @@ logger = init_logger(__name__)
 
 
 class BlockAdapterRegistry:
-    _adapters: Dict[str, BlockAdapter] = {}
-    _predefined_adapters_has_spearate_cfg: List[str] = {
+    _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
+    _predefined_adapters_has_spearate_cfg: List[str] = [
         "QwenImage",
         "Wan",
         "CogView4",
         "Cosmos",
         "SkyReelsV2",
         "Chroma",
-    }
+    ]
 
     @classmethod
-    def register(cls, name):
-        def decorator(func):
-            cls._adapters[name] = func
+    def register(cls, name: str, supported: bool = True):
+        def decorator(
+            func: Callable[..., BlockAdapter]
+        ) -> Callable[..., BlockAdapter]:
+            if supported:
+                cls._adapters[name] = func
             return func
 
         return decorator
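register() now takes a supported switch, so an adapter factory can stay in the codebase but be left out of the registry (for example behind a diffusers version check). A hypothetical registration; `MyPipeline`, the factory body, and the registry import path are assumptions:

    from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
    # Import path assumed; the registry module's filename is not shown in this diff.
    from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry

    @BlockAdapterRegistry.register("MyPipeline", supported=False)
    def my_pipeline_adapter(pipe, **kwargs) -> BlockAdapter:
        # Never reaches _adapters because supported=False; the decorator
        # still returns the function unchanged.
        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=pipe.transformer.transformer_blocks,
            forward_pattern=kwargs.get("forward_pattern"),
        )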
@@ -47,15 +50,24 @@ class BlockAdapterRegistry:
     @classmethod
     def has_separate_cfg(
         cls,
-        pipe: DiffusionPipeline | str | Any,
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     ) -> bool:
-        if cls.get_adapter(
-            pipe,
-            disable_patch=True,
-        ).has_separate_cfg:
+
+        # Prefer custom setting from block adapter.
+        if isinstance(pipe_or_adapter, BlockAdapter):
+            return pipe_or_adapter.has_separate_cfg
+
+        has_separate_cfg = False
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            has_separate_cfg = cls.get_adapter(
+                pipe_or_adapter,
+                skip_post_init=True,  # check cfg setting only
+            ).has_separate_cfg
+
+        if has_separate_cfg:
             return True
 
-        pipe_cls_name = pipe.__class__.__name__
+        pipe_cls_name = pipe_or_adapter.__class__.__name__
         for name in cls._predefined_adapters_has_spearate_cfg:
             if pipe_cls_name.startswith(name):
                 return True
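has_separate_cfg() now accepts a BlockAdapter and short-circuits to its own flag; pipelines still fall back to the registry lookup (with skip_post_init=True, so no patching side effects) and then to the class-name prefix list. A hedged sketch; `pipe` is assumed to be a loaded pipeline whose class name starts with one of the predefined prefixes, e.g. "QwenImage":

    from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
    # Import path assumed, as in the previous sketch.
    from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry

    def cfg_flags(pipe) -> tuple:
        adapter = BlockAdapter(
            pipe=pipe, has_separate_cfg=True, skip_post_init=True
        )
        return (
            BlockAdapterRegistry.has_separate_cfg(adapter),  # True: adapter flag wins
            BlockAdapterRegistry.has_separate_cfg(pipe),     # registry / prefix fallback
        )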