cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

@@ -2,8 +2,9 @@ import torch
 
 import inspect
 import dataclasses
+from collections.abc import Iterable
 
-from typing import Any, Tuple, List, Optional
+from typing import Any, Tuple, List, Optional, Union
 
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
@@ -14,21 +15,76 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+class ParamsModifier:
+    def __init__(self, **kwargs):
+        self._context_kwargs = kwargs.copy()
+
+
 @dataclasses.dataclass
 class BlockAdapter:
+
     # Transformer configurations.
-    pipe: DiffusionPipeline | Any = None
-    transformer: torch.nn.Module = None
+    pipe: Union[
+        DiffusionPipeline,
+        Any,
+    ] = None
+
+    # single transformer (most cases) or list of transformers (Wan2.2, etc)
+    transformer: Union[
+        torch.nn.Module,
+        List[torch.nn.Module],
+    ] = None
+
+    # Block Level Flags
+    # Each transformer contains a list of blocks-list,
+    # blocks_name-list, dummy_blocks_names-list, etc.
+    blocks: Union[
+        torch.nn.ModuleList,
+        List[torch.nn.ModuleList],
+        List[List[torch.nn.ModuleList]],
+    ] = None
 
-    # ------------ Block Level Flags ------------
-    blocks: torch.nn.ModuleList | List[torch.nn.ModuleList] = None
     # transformer_blocks, blocks, etc.
-    blocks_name: str | List[str] = None
-    dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
-    forward_pattern: ForwardPattern | List[ForwardPattern] = None
+    blocks_name: Union[
+        str,
+        List[str],
+        List[List[str]],
+    ] = None
+
+    unique_blocks_name: Union[
+        str,
+        List[str],
+        List[List[str]],
+    ] = dataclasses.field(default_factory=list)
+
+    dummy_blocks_names: Union[
+        List[str],
+        List[List[str]],
+    ] = dataclasses.field(default_factory=list)
+
+    forward_pattern: Union[
+        ForwardPattern,
+        List[ForwardPattern],
+        List[List[ForwardPattern]],
+    ] = None
+
+    # modify cache context params for specific blocks.
+    params_modifiers: Union[
+        ParamsModifier,
+        List[ParamsModifier],
+        List[List[ParamsModifier]],
+    ] = None
+
     check_num_outputs: bool = True
 
+    # Pipeline Level Flags
+    # Patch Functor: Flux, etc.
+    patch_functor: Optional[PatchFunctor] = None
+    # Flags for separate cfg
+    has_separate_cfg: bool = False
+
     # Flags to control auto block adapter
+    # NOTE: NOT support for multi-transformers.
     auto: bool = False
     allow_prefixes: List[str] = dataclasses.field(
         default_factory=lambda: [
@@ -49,24 +105,92 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
-    # NOTE: Other flags.
-    disable_patch: bool = False
-
-    # ------------ Pipeline Level Flags ------------
-    # Patch Functor: Flux, etc.
-    patch_functor: Optional[PatchFunctor] = None
-    # Flags for separate cfg
-    has_separate_cfg: bool = False
+    # Other Flags
+    skip_post_init: bool = False
 
     def __post_init__(self):
+        if self.skip_post_init:
+            return
         assert any((self.pipe is not None, self.transformer is not None))
-        self.patchify()
+        self.maybe_fill_attrs()
+        self.maybe_patchify()
+
+    def maybe_fill_attrs(self):
+        # NOTE: This func should be call before normalize.
+        # Allow empty `blocks_names`, we will auto fill it.
+        # TODO: preprocess more empty attrs.
+        if (
+            self.transformer is not None
+            and self.blocks is not None
+            and self.blocks_name is None
+        ):
+
+            def _find(transformer, blocks):
+                attr_names = dir(transformer)
+                assert isinstance(blocks, torch.nn.ModuleList)
+                blocks_name = None
+                for attr_name in attr_names:
+                    if attr := getattr(transformer, attr_name, None):
+                        if isinstance(attr, torch.nn.ModuleList) and id(
+                            attr
+                        ) == id(blocks):
+                            blocks_name = attr_name
+                            break
+                assert (
+                    blocks_name is not None
+                ), "No blocks_name match, please set it manually!"
+                return blocks_name
+
+            if self.nested_depth(self.transformer) == 0:
+                if self.nested_depth(self.blocks) == 0:  # str
+                    self.blocks_name = _find(self.transformer, self.blocks)
+                elif self.nested_depth(self.blocks) == 1:
+                    self.blocks_name = [
+                        _find(self.transformer, blocks)
+                        for blocks in self.blocks
+                    ]
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can't more than 1 if transformer "
+                        f"is not a list, current is: {self.nested_depth(self.blocks)}"
+                    )
+            elif self.nested_depth(self.transformer) == 1:  # List[str]
+                if self.nested_depth(self.blocks) == 1:  # List[str]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = [
+                        _find(transformer, blocks)
+                        for transformer, blocks in zip(
+                            self.transformer, self.blocks
+                        )
+                    ]
+                elif self.nested_depth(self.blocks) == 2:  # List[List[str]]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = []
+                    for i in range(len(self.blocks)):
+                        self.blocks_name.append(
+                            [
+                                _find(self.transformer[i], blocks)
+                                for blocks in self.blocks[i]
+                            ]
+                        )
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can only be 1 or 2 "
+                        "if transformer is a list, current is: "
+                        f"{self.nested_depth(self.blocks)}"
+                    )
+            else:
+                raise ValueError(
+                    "transformer nested depth can't more than 1, "
+                    f"current is: {self.nested_depth(self.transformer)}"
+                )
+            logger.info(f"Auto fill blocks_name: {self.blocks_name}.")
 
-    def patchify(self, *args, **kwargs):
+    def maybe_patchify(self, *args, **kwargs):
         # Process some specificial cases, specific for transformers
         # that has different forward patterns between single_transformer_blocks
         # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None and not self.disable_patch:
+        if self.patch_functor is not None:
             if self.transformer is not None:
                 self.patch_functor.apply(self.transformer, *args, **kwargs)
             else:
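For orientation, here is a rough sketch of how the new list-typed fields and the auto-fill performed by maybe_fill_attrs might be used together for a pipeline that carries two transformers (the Wan2.2 case mentioned in the field comments). Everything pipeline-specific below is hypothetical: `pipe`, its `transformer`/`transformer_2` attributes, the top-level import path, and the chosen ForwardPattern member are illustrative placeholders, not taken from this release.

```python
# Hedged sketch, not from the package: `pipe` stands for a loaded
# two-transformer DiffusionPipeline; `transformer_2`, the import path and
# the ForwardPattern member are assumptions made for illustration only.
from cache_dit import BlockAdapter, ForwardPattern  # import path assumed

adapter = BlockAdapter(
    pipe=pipe,
    # One entry per transformer in the pipeline.
    transformer=[pipe.transformer, pipe.transformer_2],
    # Nested lists: one inner blocks-list per transformer.
    blocks=[
        [pipe.transformer.blocks],
        [pipe.transformer_2.blocks],
    ],
    # blocks_name is left unset on purpose: __post_init__ -> maybe_fill_attrs()
    # locates the matching ModuleList attribute name on each transformer.
    forward_pattern=[
        [ForwardPattern.Pattern_2],  # pattern member chosen for illustration
        [ForwardPattern.Pattern_2],
    ],
    has_separate_cfg=True,
)
```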
@@ -92,7 +216,7 @@ class BlockAdapter:
             transformer = pipe.transformer
 
         # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.find_blocks(
+        blocks, blocks_name = BlockAdapter.find_match_blocks(
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
@@ -115,6 +239,10 @@ class BlockAdapter:
     def check_block_adapter(
         adapter: "BlockAdapter",
     ) -> bool:
+
+        if getattr(adapter, "_is_normlized", False):
+            return True
+
         def _check_warning(attr: str):
             if getattr(adapter, attr, None) is None:
                 logger.warning(f"{attr} is None!")
@@ -136,24 +264,23 @@ class BlockAdapter:
         if not _check_warning("forward_pattern"):
             return False
 
-        if isinstance(adapter.blocks, list):
-            for i, blocks in enumerate(adapter.blocks):
-                if not isinstance(blocks, torch.nn.ModuleList):
-                    logger.warning(f"blocks[{i}] is not ModuleList.")
-                    return False
+        if BlockAdapter.nested_depth(adapter.blocks) == 0:
+            blocks = adapter.blocks
         else:
-            if not isinstance(adapter.blocks, torch.nn.ModuleList):
-                logger.warning("blocks is not ModuleList.")
-                return False
+            blocks = BlockAdapter.flatten(adapter.blocks)[0]
+
+        if not isinstance(blocks, torch.nn.ModuleList):
+            logger.warning("blocks is not ModuleList.")
+            return False
 
         return True
 
     @staticmethod
-    def find_blocks(
+    def find_match_blocks(
         transformer: torch.nn.Module,
         allow_prefixes: List[str] = [
-            "transformer",
-            "single_transformer",
+            "transformer_blocks",
+            "single_transformer_blocks",
             "blocks",
             "layers",
             "single_stream_blocks",
@@ -181,10 +308,10 @@ class BlockAdapter:
         valid_count = []
         forward_pattern = kwargs.pop("forward_pattern", None)
         for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
+            if (blocks := getattr(transformer, blocks_name, None)) is not None:
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
-                    block_cls_name = block.__class__.__name__
+                    block_cls_name: str = block.__class__.__name__
                     # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
@@ -244,6 +371,18 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+    ) -> List[torch.nn.ModuleList]:
+        total_blocks = []
+        for attr in dir(transformer):
+            if (blocks := getattr(transformer, attr, None)) is not None:
+                if isinstance(blocks, torch.nn.ModuleList):
+                    if isinstance(blocks[0], torch.nn.Module):
+                        total_blocks.append(blocks)
+        return total_blocks
+
     @staticmethod
     def match_block_pattern(
         block: torch.nn.Module,
@@ -320,14 +459,148 @@ class BlockAdapter:
     def normalize(
         adapter: "BlockAdapter",
     ) -> "BlockAdapter":
-        if not isinstance(adapter.blocks, list):
-            adapter.blocks = [adapter.blocks]
-        if not isinstance(adapter.blocks_name, list):
-            adapter.blocks_name = [adapter.blocks_name]
-        if not isinstance(adapter.forward_pattern, list):
-            adapter.forward_pattern = [adapter.forward_pattern]
 
-        assert len(adapter.blocks) == len(adapter.blocks_name)
-        assert len(adapter.blocks) == len(adapter.forward_pattern)
+        if getattr(adapter, "_is_normalized", False):
+            return adapter
+
+        if BlockAdapter.nested_depth(adapter.transformer) == 0:
+            adapter.transformer = [adapter.transformer]
+
+        def _normalize_attr(attr: Any):
+            normalized_attr = attr
+            if attr is None:
+                return normalized_attr
+
+            if BlockAdapter.nested_depth(attr) == 0:
+                normalized_attr = [[attr]]
+            elif BlockAdapter.nested_depth(attr) == 1:  # List
+                if attr:  # not-empty
+                    if len(attr) == len(adapter.transformer):
+                        normalized_attr = [[a] for a in attr]
+                    else:
+                        normalized_attr = [attr]
+                else:  # [] empty
+                    normalized_attr = [
+                        [] for _ in range(len(adapter.transformer))
+                    ]
+
+            assert len(adapter.transformer) == len(normalized_attr)
+            return normalized_attr
+
+        adapter.blocks = _normalize_attr(adapter.blocks)
+        adapter.blocks_name = _normalize_attr(adapter.blocks_name)
+        adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
+        adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
+        adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
+        BlockAdapter.unique(adapter)
+
+        adapter._is_normalized = True
 
         return adapter
+
+    @classmethod
+    def unique(cls, adapter: "BlockAdapter"):
+        # NOTE: Users should never call this function
+        for i in range(len(adapter.blocks)):
+            assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
+            assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
+
+        # Generate unique blocks names
+        if len(adapter.unique_blocks_name) == 0:
+            for i in range(len(adapter.transformer)):
+                adapter.unique_blocks_name.append(
+                    [
+                        f"{name}_{hash(id(blocks))}"
+                        for blocks, name in zip(
+                            adapter.blocks[i],
+                            adapter.blocks_name[i],
+                        )
+                    ]
+                )
+        else:
+            assert len(adapter.transformer) == len(adapter.unique_blocks_name)
+
+        # Also check Match Forward Pattern
+        for i in range(len(adapter.transformer)):
+            for forward_pattern, blocks in zip(
+                adapter.forward_pattern[i], adapter.blocks[i]
+            ):
+                assert BlockAdapter.match_blocks_pattern(
+                    blocks,
+                    forward_pattern=forward_pattern,
+                    check_num_outputs=adapter.check_num_outputs,
+                ), (
+                    "No block forward pattern matched, "
+                    f"supported lists: {ForwardPattern.supported_patterns()}"
+                )
+
+    @classmethod
+    def assert_normalized(cls, adapter: "BlockAdapter"):
+        if not getattr(adapter, "_is_normalized", False):
+            raise RuntimeError("block_adapter must be normailzed.")
+
+    @classmethod
+    def is_cached(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return all(
+                (
+                    getattr(adapter.pipe, "_is_cached", False),
+                    getattr(adapter.transformer[0], "_is_cached", False),
+                )
+            )
+        elif isinstance(
+            adapter,
+            (DiffusionPipeline, torch.nn.Module),
+        ):
+            return getattr(adapter, "_is_cached", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_cached", False)
+        else:
+            raise TypeError(f"Can't check this type: {adapter}!")
+
+    @classmethod
+    def nested_depth(cls, obj: Any):
+        # str: 0; List[str]: 1; List[List[str]]: 2
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if isinstance(obj, atom_types):
+            return 0
+        if not isinstance(obj, Iterable):
+            return 0
+        if isinstance(obj, dict):
+            items = obj.values()
+        else:
+            items = obj
 
+        max_depth = 0
+        for item in items:
+            current_depth = cls.nested_depth(item)
+            if current_depth > max_depth:
+                max_depth = current_depth
+        return 1 + max_depth
+
+    @classmethod
+    def flatten(cls, attr: List[Any]) -> List[Any]:
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if not isinstance(attr, list):
+            return attr
+        flattened = []
+        for item in attr:
+            if isinstance(item, list) and not isinstance(item, atom_types):
+                flattened.extend(cls.flatten(item))
+            else:
+                flattened.append(item)
+        return flattened
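The normalization above hinges on the depth convention enforced by `nested_depth` and on `flatten`. A minimal sketch of the expected behaviour, with throwaway string values standing in for real blocks:

```python
# Depth convention used by normalize()/nested_depth(); str/bytes, Module,
# ModuleList and Tensor all count as atoms (depth 0).
assert BlockAdapter.nested_depth("blocks") == 0                   # single atom
assert BlockAdapter.nested_depth(["blocks", "layers"]) == 1       # flat list
assert BlockAdapter.nested_depth([["blocks"], ["layers"]]) == 2   # list of lists

# flatten() collapses arbitrary nesting into one flat list.
assert BlockAdapter.flatten([["a"], ["b", "c"]]) == ["a", "b", "c"]

# After normalize(), every per-block attribute becomes List[List[...]] with one
# outer entry per transformer, and unique_blocks_name entries are generated as
# f"{blocks_name}_{hash(id(blocks))}" per blocks-list.
```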
@@ -47,15 +47,24 @@ class BlockAdapterRegistry:
     @classmethod
     def has_separate_cfg(
         cls,
-        pipe: DiffusionPipeline | str | Any,
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     ) -> bool:
-        if cls.get_adapter(
-            pipe,
-            disable_patch=True,
-        ).has_separate_cfg:
+
+        # Prefer custom setting from block adapter.
+        if isinstance(pipe_or_adapter, BlockAdapter):
+            return pipe_or_adapter.has_separate_cfg
+
+        has_separate_cfg = False
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            has_separate_cfg = cls.get_adapter(
+                pipe_or_adapter,
+                skip_post_init=True,  # check cfg setting only
+            ).has_separate_cfg
+
+        if has_separate_cfg:
             return True
 
-        pipe_cls_name = pipe.__class__.__name__
+        pipe_cls_name = pipe_or_adapter.__class__.__name__
         for name in cls._predefined_adapters_has_spearate_cfg:
             if pipe_cls_name.startswith(name):
                 return True
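Finally, a hedged sketch of calling the reworked registry check; `pipe` and `adapter` are placeholders for a loaded DiffusionPipeline and a BlockAdapter built as in the earlier sketch:

```python
# An explicit BlockAdapter flag now takes precedence over the registry:
BlockAdapterRegistry.has_separate_cfg(adapter)  # returns adapter.has_separate_cfg

# For a bare pipeline, the registry first builds a throwaway adapter with
# skip_post_init=True (so no patching or attribute auto-fill runs) and, failing
# that, falls back to matching the pipeline class name against its predefined list.
BlockAdapterRegistry.has_separate_cfg(pipe)
```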