cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- cache_dit/__init__.py +2 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -0
- cache_dit/cache_factory/block_adapters/__init__.py +105 -111
- cache_dit/cache_factory/block_adapters/block_adapters.py +314 -41
- cache_dit/cache_factory/block_adapters/block_registers.py +15 -6
- cache_dit/cache_factory/cache_adapters.py +244 -116
- cache_dit/cache_factory/cache_blocks/__init__.py +55 -4
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +36 -37
- cache_dit/cache_factory/cache_blocks/pattern_base.py +83 -76
- cache_dit/cache_factory/cache_blocks/utils.py +26 -8
- cache_dit/cache_factory/cache_contexts/__init__.py +4 -1
- cache_dit/cache_factory/cache_contexts/cache_context.py +14 -876
- cache_dit/cache_factory/cache_contexts/cache_manager.py +847 -0
- cache_dit/cache_factory/cache_interface.py +91 -24
- cache_dit/cache_factory/patch_functors/functor_chroma.py +1 -1
- cache_dit/cache_factory/patch_functors/functor_flux.py +1 -1
- cache_dit/utils.py +164 -58
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/METADATA +59 -34
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/RECORD +24 -24
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,9 @@ import torch
 
 import inspect
 import dataclasses
+from collections.abc import Iterable
 
-from typing import Any, Tuple, List, Optional
+from typing import Any, Tuple, List, Optional, Union
 
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
@@ -14,21 +15,76 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+class ParamsModifier:
+    def __init__(self, **kwargs):
+        self._context_kwargs = kwargs.copy()
+
+
 @dataclasses.dataclass
 class BlockAdapter:
+
     # Transformer configurations.
-    pipe:
-
+    pipe: Union[
+        DiffusionPipeline,
+        Any,
+    ] = None
+
+    # single transformer (most cases) or list of transformers (Wan2.2, etc)
+    transformer: Union[
+        torch.nn.Module,
+        List[torch.nn.Module],
+    ] = None
+
+    # Block Level Flags
+    # Each transformer contains a list of blocks-list,
+    # blocks_name-list, dummy_blocks_names-list, etc.
+    blocks: Union[
+        torch.nn.ModuleList,
+        List[torch.nn.ModuleList],
+        List[List[torch.nn.ModuleList]],
+    ] = None
 
-    # ------------ Block Level Flags ------------
-    blocks: torch.nn.ModuleList | List[torch.nn.ModuleList] = None
     # transformer_blocks, blocks, etc.
-    blocks_name:
-
-
+    blocks_name: Union[
+        str,
+        List[str],
+        List[List[str]],
+    ] = None
+
+    unique_blocks_name: Union[
+        str,
+        List[str],
+        List[List[str]],
+    ] = dataclasses.field(default_factory=list)
+
+    dummy_blocks_names: Union[
+        List[str],
+        List[List[str]],
+    ] = dataclasses.field(default_factory=list)
+
+    forward_pattern: Union[
+        ForwardPattern,
+        List[ForwardPattern],
+        List[List[ForwardPattern]],
+    ] = None
+
+    # modify cache context params for specific blocks.
+    params_modifiers: Union[
+        ParamsModifier,
+        List[ParamsModifier],
+        List[List[ParamsModifier]],
+    ] = None
+
     check_num_outputs: bool = True
 
+    # Pipeline Level Flags
+    # Patch Functor: Flux, etc.
+    patch_functor: Optional[PatchFunctor] = None
+    # Flags for separate cfg
+    has_separate_cfg: bool = False
+
     # Flags to control auto block adapter
+    # NOTE: NOT support for multi-transformers.
     auto: bool = False
     allow_prefixes: List[str] = dataclasses.field(
         default_factory=lambda: [
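The reworked field types make the per-transformer nesting explicit: each attribute may now be a single value, a list (one entry per transformer), or a list of lists (several block lists per transformer). A minimal usage sketch for a two-transformer pipeline in the Wan2.2 style; the import paths, the `transformer_2`/`blocks` attribute names, `ForwardPattern.Pattern_2`, and the `max_warmup_steps` kwarg are illustrative assumptions, not taken from this diff:

    from cache_dit.cache_factory.block_adapters import BlockAdapter, ParamsModifier
    from cache_dit.cache_factory.forward_pattern import ForwardPattern

    def make_two_transformer_adapter(pipe):
        return BlockAdapter(
            pipe=pipe,
            # One entry per transformer...
            transformer=[pipe.transformer, pipe.transformer_2],
            # ...and one blocks list per transformer. blocks_name may be
            # omitted: maybe_fill_attrs() resolves it by object identity.
            blocks=[pipe.transformer.blocks, pipe.transformer_2.blocks],
            forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
            # Override cache-context params for the second transformer only.
            params_modifiers=[[], [ParamsModifier(max_warmup_steps=4)]],
            has_separate_cfg=True,
        )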
@@ -49,24 +105,92 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
-    #
-
-
-    # ------------ Pipeline Level Flags ------------
-    # Patch Functor: Flux, etc.
-    patch_functor: Optional[PatchFunctor] = None
-    # Flags for separate cfg
-    has_separate_cfg: bool = False
+    # Other Flags
+    skip_post_init: bool = False
 
     def __post_init__(self):
+        if self.skip_post_init:
+            return
         assert any((self.pipe is not None, self.transformer is not None))
-        self.
+        self.maybe_fill_attrs()
+        self.maybe_patchify()
+
+    def maybe_fill_attrs(self):
+        # NOTE: This func should be call before normalize.
+        # Allow empty `blocks_names`, we will auto fill it.
+        # TODO: preprocess more empty attrs.
+        if (
+            self.transformer is not None
+            and self.blocks is not None
+            and self.blocks_name is None
+        ):
+
+            def _find(transformer, blocks):
+                attr_names = dir(transformer)
+                assert isinstance(blocks, torch.nn.ModuleList)
+                blocks_name = None
+                for attr_name in attr_names:
+                    if attr := getattr(transformer, attr_name, None):
+                        if isinstance(attr, torch.nn.ModuleList) and id(
+                            attr
+                        ) == id(blocks):
+                            blocks_name = attr_name
+                            break
+                assert (
+                    blocks_name is not None
+                ), "No blocks_name match, please set it manually!"
+                return blocks_name
+
+            if self.nested_depth(self.transformer) == 0:
+                if self.nested_depth(self.blocks) == 0:  # str
+                    self.blocks_name = _find(self.transformer, self.blocks)
+                elif self.nested_depth(self.blocks) == 1:
+                    self.blocks_name = [
+                        _find(self.transformer, blocks)
+                        for blocks in self.blocks
+                    ]
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can't more than 1 if transformer "
+                        f"is not a list, current is: {self.nested_depth(self.blocks)}"
+                    )
+            elif self.nested_depth(self.transformer) == 1:  # List[str]
+                if self.nested_depth(self.blocks) == 1:  # List[str]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = [
+                        _find(transformer, blocks)
+                        for transformer, blocks in zip(
+                            self.transformer, self.blocks
+                        )
+                    ]
+                elif self.nested_depth(self.blocks) == 2:  # List[List[str]]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = []
+                    for i in range(len(self.blocks)):
+                        self.blocks_name.append(
+                            [
+                                _find(self.transformer[i], blocks)
+                                for blocks in self.blocks[i]
+                            ]
+                        )
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can only be 1 or 2 "
+                        "if transformer is a list, current is: "
+                        f"{self.nested_depth(self.blocks)}"
+                    )
+            else:
+                raise ValueError(
+                    "transformer nested depth can't more than 1, "
+                    f"current is: {self.nested_depth(self.transformer)}"
+                )
+            logger.info(f"Auto fill blocks_name: {self.blocks_name}.")
 
-    def
+    def maybe_patchify(self, *args, **kwargs):
         # Process some specificial cases, specific for transformers
         # that has different forward patterns between single_transformer_blocks
         # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None
+        if self.patch_functor is not None:
             if self.transformer is not None:
                 self.patch_functor.apply(self.transformer, *args, **kwargs)
             else:
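The `_find` helper above resolves a blocks attribute name by object identity rather than by string matching. A self-contained sketch of the same technique on a toy module (all names here are hypothetical):

    import torch

    class ToyTransformer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer_blocks = torch.nn.ModuleList(
                [torch.nn.Linear(8, 8) for _ in range(2)]
            )

    def find_blocks_name(transformer, blocks):
        # Scan attributes and compare identities, as _find() does above.
        for attr_name in dir(transformer):
            attr = getattr(transformer, attr_name, None)
            if isinstance(attr, torch.nn.ModuleList) and attr is blocks:
                return attr_name
        return None

    model = ToyTransformer()
    assert find_blocks_name(model, model.transformer_blocks) == "transformer_blocks"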
@@ -92,7 +216,7 @@ class BlockAdapter:
         transformer = pipe.transformer
 
         # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.
+        blocks, blocks_name = BlockAdapter.find_match_blocks(
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
@@ -115,6 +239,10 @@
     def check_block_adapter(
         adapter: "BlockAdapter",
     ) -> bool:
+
+        if getattr(adapter, "_is_normlized", False):
+            return True
+
         def _check_warning(attr: str):
             if getattr(adapter, attr, None) is None:
                 logger.warning(f"{attr} is None!")
@@ -136,24 +264,23 @@ class BlockAdapter:
         if not _check_warning("forward_pattern"):
             return False
 
-        if
-
-            if not isinstance(blocks, torch.nn.ModuleList):
-                logger.warning(f"blocks[{i}] is not ModuleList.")
-                return False
+        if BlockAdapter.nested_depth(adapter.blocks) == 0:
+            blocks = adapter.blocks
         else:
-
-
-
+            blocks = BlockAdapter.flatten(adapter.blocks)[0]
+
+        if not isinstance(blocks, torch.nn.ModuleList):
+            logger.warning("blocks is not ModuleList.")
+            return False
 
         return True
 
     @staticmethod
-    def
+    def find_match_blocks(
         transformer: torch.nn.Module,
         allow_prefixes: List[str] = [
-            "
-            "
+            "transformer_blocks",
+            "single_transformer_blocks",
            "blocks",
            "layers",
            "single_stream_blocks",
@@ -181,10 +308,10 @@ class BlockAdapter:
         valid_count = []
         forward_pattern = kwargs.pop("forward_pattern", None)
         for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
+            if (blocks := getattr(transformer, blocks_name, None)) is not None:
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
-                    block_cls_name = block.__class__.__name__
+                    block_cls_name: str = block.__class__.__name__
                     # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
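The walrus-expression change in this hunk is a behavioral fix, not a style tweak: a bare truthiness test silently rejects an empty `ModuleList`, while the explicit `is not None` check keeps it. A minimal illustration:

    import torch

    empty = torch.nn.ModuleList()
    assert not bool(empty)    # falsy: the old `if blocks := ...` skipped it
    assert empty is not None  # the new check still accepts it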
@@ -244,6 +371,18 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+    ) -> List[torch.nn.ModuleList]:
+        total_blocks = []
+        for attr in dir(transformer):
+            if (blocks := getattr(transformer, attr, None)) is not None:
+                if isinstance(blocks, torch.nn.ModuleList):
+                    if isinstance(blocks[0], torch.nn.Module):
+                        total_blocks.append(blocks)
+        return total_blocks
+
     @staticmethod
     def match_block_pattern(
         block: torch.nn.Module,
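A sketch of how the new `find_blocks()` helper behaves on a toy model; the import path is assumed from the file list above:

    import torch
    from cache_dit.cache_factory.block_adapters import BlockAdapter

    class ToyDiT(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])
            self.single_transformer_blocks = torch.nn.ModuleList(
                [torch.nn.Linear(4, 4)]
            )

    # Both non-empty ModuleList attributes are collected.
    assert len(BlockAdapter.find_blocks(ToyDiT())) == 2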
@@ -320,14 +459,148 @@ class BlockAdapter:
     def normalize(
         adapter: "BlockAdapter",
     ) -> "BlockAdapter":
-        if not isinstance(adapter.blocks, list):
-            adapter.blocks = [adapter.blocks]
-        if not isinstance(adapter.blocks_name, list):
-            adapter.blocks_name = [adapter.blocks_name]
-        if not isinstance(adapter.forward_pattern, list):
-            adapter.forward_pattern = [adapter.forward_pattern]
 
-
-
+        if getattr(adapter, "_is_normalized", False):
+            return adapter
+
+        if BlockAdapter.nested_depth(adapter.transformer) == 0:
+            adapter.transformer = [adapter.transformer]
+
+        def _normalize_attr(attr: Any):
+            normalized_attr = attr
+            if attr is None:
+                return normalized_attr
+
+            if BlockAdapter.nested_depth(attr) == 0:
+                normalized_attr = [[attr]]
+            elif BlockAdapter.nested_depth(attr) == 1:  # List
+                if attr:  # not-empty
+                    if len(attr) == len(adapter.transformer):
+                        normalized_attr = [[a] for a in attr]
+                    else:
+                        normalized_attr = [attr]
+                else:  # [] empty
+                    normalized_attr = [
+                        [] for _ in range(len(adapter.transformer))
+                    ]
+
+            assert len(adapter.transformer) == len(normalized_attr)
+            return normalized_attr
+
+        adapter.blocks = _normalize_attr(adapter.blocks)
+        adapter.blocks_name = _normalize_attr(adapter.blocks_name)
+        adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
+        adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
+        adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
+        BlockAdapter.unique(adapter)
+
+        adapter._is_normalized = True
 
         return adapter
+
+    @classmethod
+    def unique(cls, adapter: "BlockAdapter"):
+        # NOTE: Users should never call this function
+        for i in range(len(adapter.blocks)):
+            assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
+            assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
+
+        # Generate unique blocks names
+        if len(adapter.unique_blocks_name) == 0:
+            for i in range(len(adapter.transformer)):
+                adapter.unique_blocks_name.append(
+                    [
+                        f"{name}_{hash(id(blocks))}"
+                        for blocks, name in zip(
+                            adapter.blocks[i],
+                            adapter.blocks_name[i],
+                        )
+                    ]
+                )
+        else:
+            assert len(adapter.transformer) == len(adapter.unique_blocks_name)
+
+        # Also check Match Forward Pattern
+        for i in range(len(adapter.transformer)):
+            for forward_pattern, blocks in zip(
+                adapter.forward_pattern[i], adapter.blocks[i]
+            ):
+                assert BlockAdapter.match_blocks_pattern(
+                    blocks,
+                    forward_pattern=forward_pattern,
+                    check_num_outputs=adapter.check_num_outputs,
+                ), (
+                    "No block forward pattern matched, "
+                    f"supported lists: {ForwardPattern.supported_patterns()}"
+                )
+
+    @classmethod
+    def assert_normalized(cls, adapter: "BlockAdapter"):
+        if not getattr(adapter, "_is_normalized", False):
+            raise RuntimeError("block_adapter must be normailzed.")
+
+    @classmethod
+    def is_cached(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return all(
+                (
+                    getattr(adapter.pipe, "_is_cached", False),
+                    getattr(adapter.transformer[0], "_is_cached", False),
+                )
+            )
+        elif isinstance(
+            adapter,
+            (DiffusionPipeline, torch.nn.Module),
+        ):
+            return getattr(adapter, "_is_cached", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_cached", False)
+        else:
+            raise TypeError(f"Can't check this type: {adapter}!")
+
+    @classmethod
+    def nested_depth(cls, obj: Any):
+        # str: 0; List[str]: 1; List[List[str]]: 2
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if isinstance(obj, atom_types):
+            return 0
+        if not isinstance(obj, Iterable):
+            return 0
+        if isinstance(obj, dict):
+            items = obj.values()
+        else:
+            items = obj
+
+        max_depth = 0
+        for item in items:
+            current_depth = cls.nested_depth(item)
+            if current_depth > max_depth:
+                max_depth = current_depth
+        return 1 + max_depth
+
+    @classmethod
+    def flatten(cls, attr: List[Any]) -> List[Any]:
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if not isinstance(attr, list):
+            return attr
+        flattened = []
+        for item in attr:
+            if isinstance(item, list) and not isinstance(item, atom_types):
+                flattened.extend(cls.flatten(item))
+            else:
+                flattened.append(item)
+        return flattened
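`nested_depth()` and `flatten()` define the nesting convention the whole refactor relies on: strings, Modules, ModuleLists, and Tensors count as atoms of depth 0, and only plain list nesting adds depth. A sketch (import path assumed):

    import torch
    from cache_dit.cache_factory.block_adapters import BlockAdapter

    blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    assert BlockAdapter.nested_depth("blocks") == 0    # str is an atom
    assert BlockAdapter.nested_depth([blocks]) == 1    # List[ModuleList]
    assert BlockAdapter.nested_depth([[blocks]]) == 2  # List[List[ModuleList]]

    # flatten() undoes the [[...]] nesting that normalize() introduces.
    assert BlockAdapter.flatten([[blocks], [blocks]]) == [blocks, blocks]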
@@ -47,15 +47,24 @@ class BlockAdapterRegistry:
     @classmethod
     def has_separate_cfg(
         cls,
-
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     ) -> bool:
-
-
-
-
+
+        # Prefer custom setting from block adapter.
+        if isinstance(pipe_or_adapter, BlockAdapter):
+            return pipe_or_adapter.has_separate_cfg
+
+        has_separate_cfg = False
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            has_separate_cfg = cls.get_adapter(
+                pipe_or_adapter,
+                skip_post_init=True,  # check cfg setting only
+            ).has_separate_cfg
+
+        if has_separate_cfg:
             return True
 
-        pipe_cls_name =
+        pipe_cls_name = pipe_or_adapter.__class__.__name__
         for name in cls._predefined_adapters_has_spearate_cfg:
             if pipe_cls_name.startswith(name):
                 return True