cache-dit 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +166 -160
- cache_dit/cache_factory/block_adapters/block_adapters.py +195 -125
- cache_dit/cache_factory/block_adapters/block_registers.py +25 -13
- cache_dit/cache_factory/cache_adapters.py +209 -86
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +70 -67
- cache_dit/cache_factory/cache_blocks/utils.py +16 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +22 -10
- cache_dit/cache_factory/cache_interface.py +26 -14
- cache_dit/cache_factory/cache_types.py +5 -5
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -2
- cache_dit/cache_factory/patch_functors/functor_flux.py +3 -2
- cache_dit/utils.py +168 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/METADATA +34 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/RECORD +21 -21
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/block_adapters/block_adapters.py

```diff
@@ -2,6 +2,7 @@ import torch
 
 import inspect
 import dataclasses
+from collections.abc import Iterable
 
 from typing import Any, Tuple, List, Optional, Union
 
```
```diff
@@ -74,7 +75,7 @@ class BlockAdapter:
         List[List[ParamsModifier]],
     ] = None
 
-    check_num_outputs: bool = True
+    check_num_outputs: bool = False
 
     # Pipeline Level Flags
     # Patch Functor: Flux, etc.
```
```diff
@@ -82,9 +83,6 @@ class BlockAdapter:
     # Flags for separate cfg
     has_separate_cfg: bool = False
 
-    # Other Flags
-    disable_patch: bool = False
-
     # Flags to control auto block adapter
     # NOTE: NOT support for multi-transformers.
     auto: bool = False
```
```diff
@@ -107,15 +105,94 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
+    # Other Flags
+    skip_post_init: bool = False
+
     def __post_init__(self):
-
-
+        if self.skip_post_init:
+            return
+        if any((self.pipe is not None, self.transformer is not None)):
+            self.maybe_fill_attrs()
+            self.maybe_patchify()
+
+    def maybe_fill_attrs(self):
+        # NOTE: This func should be call before normalize.
+        # Allow empty `blocks_names`, we will auto fill it.
+        # TODO: preprocess more empty attrs.
+        if (
+            self.transformer is not None
+            and self.blocks is not None
+            and self.blocks_name is None
+        ):
 
-
+            def _find(transformer, blocks):
+                attr_names = dir(transformer)
+                assert isinstance(blocks, torch.nn.ModuleList)
+                blocks_name = None
+                for attr_name in attr_names:
+                    if (
+                        attr := getattr(transformer, attr_name, None)
+                    ) is not None:
+                        if isinstance(attr, torch.nn.ModuleList) and id(
+                            attr
+                        ) == id(blocks):
+                            blocks_name = attr_name
+                            break
+                assert (
+                    blocks_name is not None
+                ), "No blocks_name match, please set it manually!"
+                return blocks_name
+
+            if self.nested_depth(self.transformer) == 0:
+                if self.nested_depth(self.blocks) == 0:  # str
+                    self.blocks_name = _find(self.transformer, self.blocks)
+                elif self.nested_depth(self.blocks) == 1:
+                    self.blocks_name = [
+                        _find(self.transformer, blocks)
+                        for blocks in self.blocks
+                    ]
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can't more than 1 if transformer "
+                        f"is not a list, current is: {self.nested_depth(self.blocks)}"
+                    )
+            elif self.nested_depth(self.transformer) == 1:  # List[str]
+                if self.nested_depth(self.blocks) == 1:  # List[str]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = [
+                        _find(transformer, blocks)
+                        for transformer, blocks in zip(
+                            self.transformer, self.blocks
+                        )
+                    ]
+                elif self.nested_depth(self.blocks) == 2:  # List[List[str]]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = []
+                    for i in range(len(self.blocks)):
+                        self.blocks_name.append(
+                            [
+                                _find(self.transformer[i], blocks)
+                                for blocks in self.blocks[i]
+                            ]
+                        )
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can only be 1 or 2 "
+                        "if transformer is a list, current is: "
+                        f"{self.nested_depth(self.blocks)}"
+                    )
+            else:
+                raise ValueError(
+                    "transformer nested depth can't more than 1, "
+                    f"current is: {self.nested_depth(self.transformer)}"
+                )
+            logger.info(f"Auto fill blocks_name: {self.blocks_name}.")
+
+    def maybe_patchify(self, *args, **kwargs):
         # Process some specificial cases, specific for transformers
         # that has different forward patterns between single_transformer_blocks
         # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None
+        if self.patch_functor is not None:
             if self.transformer is not None:
                 self.patch_functor.apply(self.transformer, *args, **kwargs)
             else:
```
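Taken together, the new `__post_init__` flow means an adapter constructed with a live `pipe` or `transformer` now auto-fills `blocks_name` by identity-matching the given `ModuleList` against the transformer's attributes, and `skip_post_init=True` opts out of the whole pass. A minimal standalone sketch of that identity-match idea (toy classes and a hypothetical helper name, not the package's own code):

```python
import torch

class ToyTransformer(torch.nn.Module):
    # Two ModuleList attributes, like transformer_blocks vs.
    # single_transformer_blocks in diffusers models.
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])
        self.single_transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

def find_blocks_name(transformer: torch.nn.Module, blocks) -> str:
    # Match by object identity (`is`), as `_find` does via id() comparison:
    # two ModuleLists with equal contents must not be confused.
    for attr_name in dir(transformer):
        attr = getattr(transformer, attr_name, None)
        if isinstance(attr, torch.nn.ModuleList) and attr is blocks:
            return attr_name
    raise ValueError("No blocks_name match, please set it manually!")

t = ToyTransformer()
print(find_blocks_name(t, t.single_transformer_blocks))
# -> single_transformer_blocks
```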
```diff
@@ -141,7 +218,7 @@ class BlockAdapter:
         transformer = pipe.transformer
 
         # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.
+        blocks, blocks_name = BlockAdapter.find_match_blocks(
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
```
```diff
@@ -164,6 +241,10 @@ class BlockAdapter:
     def check_block_adapter(
         adapter: "BlockAdapter",
     ) -> bool:
+
+        if getattr(adapter, "_is_normlized", False):
+            return True
+
         def _check_warning(attr: str):
             if getattr(adapter, attr, None) is None:
                 logger.warning(f"{attr} is None!")
```
```diff
@@ -185,24 +266,23 @@ class BlockAdapter:
         if not _check_warning("forward_pattern"):
             return False
 
-        if
-
-            if not isinstance(blocks, torch.nn.ModuleList):
-                logger.warning(f"blocks[{i}] is not ModuleList.")
-                return False
+        if BlockAdapter.nested_depth(adapter.blocks) == 0:
+            blocks = adapter.blocks
         else:
-
-
-
+            blocks = BlockAdapter.flatten(adapter.blocks)[0]
+
+        if not isinstance(blocks, torch.nn.ModuleList):
+            logger.warning("blocks is not ModuleList.")
+            return False
 
         return True
 
     @staticmethod
-    def
+    def find_match_blocks(
         transformer: torch.nn.Module,
         allow_prefixes: List[str] = [
-            "
-            "
+            "transformer_blocks",
+            "single_transformer_blocks",
             "blocks",
             "layers",
             "single_stream_blocks",
```
```diff
@@ -230,10 +310,10 @@ class BlockAdapter:
         valid_count = []
         forward_pattern = kwargs.pop("forward_pattern", None)
         for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
+            if (blocks := getattr(transformer, blocks_name, None)) is not None:
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
-                    block_cls_name = block.__class__.__name__
+                    block_cls_name: str = block.__class__.__name__
                     # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
```
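The walrus-operator change above is a behavioral fix, not a style tweak: `torch.nn.ModuleList` implements `__len__`, so an empty list is falsy and the old truthiness test would skip an attribute that actually exists. A quick check:

```python
import torch

empty = torch.nn.ModuleList()
print(bool(empty))        # False: zero length makes it falsy (old check skips it)
print(empty is not None)  # True: the new `is not None` check keeps it
```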
```diff
@@ -293,6 +373,18 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+    ) -> List[torch.nn.ModuleList]:
+        total_blocks = []
+        for attr in dir(transformer):
+            if (blocks := getattr(transformer, attr, None)) is not None:
+                if isinstance(blocks, torch.nn.ModuleList):
+                    if isinstance(blocks[0], torch.nn.Module):
+                        total_blocks.append(blocks)
+        return total_blocks
+
     @staticmethod
     def match_block_pattern(
         block: torch.nn.Module,
```
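For illustration, a hypothetical use of the new `find_blocks` helper, assuming cache-dit 0.2.30 and its dependencies are installed; it collects every `torch.nn.ModuleList` attribute whose first entry is a `torch.nn.Module`:

```python
import torch
from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Identity()])
        self.layers = torch.nn.ModuleList([torch.nn.Identity()])

# Both ModuleList attributes are discovered via the dir() scan.
print(len(BlockAdapter.find_blocks(TinyModel())))  # 2
```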
```diff
@@ -373,103 +465,51 @@ class BlockAdapter:
         if getattr(adapter, "_is_normalized", False):
             return adapter
 
-        if
+        if BlockAdapter.nested_depth(adapter.transformer) == 0:
             adapter.transformer = [adapter.transformer]
 
-
-
-
-
-
-
-
-
-
-
-
-        # [[blocks_0, blocks_1],[blocks_2, blocks_3],] -> match [TRN_0, TRN_1,]
-        pass
-
-        if isinstance(adapter.blocks_name, str):
-            adapter.blocks_name = [[adapter.blocks_name]]
-        elif isinstance(adapter.blocks_name, list):
-            if isinstance(adapter.blocks_name[0], str):
-                if len(adapter.blocks_name) == len(adapter.transformer):
-                    adapter.blocks_name = [
-                        [blocks_name] for blocks_name in adapter.blocks_name
-                    ]
-                else:
-                    adapter.blocks_name = [adapter.blocks_name]
-            elif isinstance(adapter.blocks_name[0], list):
-                pass
-
-        if isinstance(adapter.forward_pattern, ForwardPattern):
-            adapter.forward_pattern = [[adapter.forward_pattern]]
-        elif isinstance(adapter.forward_pattern, list):
-            if isinstance(adapter.forward_pattern[0], ForwardPattern):
-                if len(adapter.forward_pattern) == len(adapter.transformer):
-                    adapter.forward_pattern = [
-                        [forward_pattern]
-                        for forward_pattern in adapter.forward_pattern
-                    ]
-                else:
-                    adapter.forward_pattern = [adapter.forward_pattern]
-            elif isinstance(adapter.forward_pattern[0], list):
-                pass
-
-        if isinstance(adapter.dummy_blocks_names, list):
-            if len(adapter.dummy_blocks_names) > 0:
-                if isinstance(adapter.dummy_blocks_names[0], str):
-                    if len(adapter.dummy_blocks_names) == len(
-                        adapter.transformer
-                    ):
-                        adapter.dummy_blocks_names = [
-                            [dummy_blocks_names]
-                            for dummy_blocks_names in adapter.dummy_blocks_names
-                        ]
-                    else:
-                        adapter.dummy_blocks_names = [
-                            adapter.dummy_blocks_names
-                        ]
-                elif isinstance(adapter.dummy_blocks_names[0], list):
-                    pass
-            else:
-                # Empty dummy_blocks_names
-                adapter.dummy_blocks_names = [
-                    [] for _ in range(len(adapter.transformer))
-                ]
-
-        if adapter.params_modifiers is not None:
-            if isinstance(adapter.params_modifiers, ParamsModifier):
-                adapter.params_modifiers = [[adapter.params_modifiers]]
-            elif isinstance(adapter.params_modifiers, list):
-                if isinstance(adapter.params_modifiers[0], ParamsModifier):
-                    if len(adapter.params_modifiers) == len(
-                        adapter.transformer
-                    ):
-                        adapter.params_modifiers = [
-                            [params_modifiers]
-                            for params_modifiers in adapter.params_modifiers
-                        ]
+        def _normalize_attr(attr: Any):
+            normalized_attr = attr
+            if attr is None:
+                return normalized_attr
+
+            if BlockAdapter.nested_depth(attr) == 0:
+                normalized_attr = [[attr]]
+            elif BlockAdapter.nested_depth(attr) == 1:  # List
+                if attr:  # not-empty
+                    if len(attr) == len(adapter.transformer):
+                        normalized_attr = [[a] for a in attr]
                     else:
-
-
-
+                        normalized_attr = [attr]
+                else:  # [] empty
+                    normalized_attr = [
+                        [] for _ in range(len(adapter.transformer))
+                    ]
+
+            assert len(adapter.transformer) == len(normalized_attr)
+            return normalized_attr
+
+        adapter.blocks = _normalize_attr(adapter.blocks)
+        adapter.blocks_name = _normalize_attr(adapter.blocks_name)
+        adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
+        adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
+        adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
+        BlockAdapter.unique(adapter)
 
-
-
-
-        assert len(adapter.transformer) == len(adapter.dummy_blocks_names)
-        if adapter.params_modifiers is not None:
-            assert len(adapter.transformer) == len(adapter.params_modifiers)
+        adapter._is_normalized = True
+
+        return adapter
 
+    @classmethod
+    def unique(cls, adapter: "BlockAdapter"):
+        # NOTE: Users should never call this function
         for i in range(len(adapter.blocks)):
             assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
             assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
 
+        # Generate unique blocks names
         if len(adapter.unique_blocks_name) == 0:
             for i in range(len(adapter.transformer)):
-                # Generate unique blocks names
                 adapter.unique_blocks_name.append(
                     [
                         f"{name}_{hash(id(blocks))}"
```
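The net effect of `_normalize_attr` is one uniform two-level layout per transformer: a depth-0 value becomes `[[x]]`; a flat list is split one-per-transformer when its length matches `len(adapter.transformer)` and wrapped once otherwise; an empty list expands to one empty sub-list per transformer; `None` and already depth-2 input pass through. A shape-only sketch of those rules, using a hypothetical standalone helper:

```python
# Hypothetical mirror of _normalize_attr's shape rules (sketch only).
def normalize_attr(attr, num_transformers: int):
    if attr is None:
        return attr
    if not isinstance(attr, list):        # depth 0: bare value
        return [[attr]]
    if not attr:                          # depth 1: [] empty
        return [[] for _ in range(num_transformers)]
    if isinstance(attr[0], list):         # depth 2: already normalized
        return attr
    if len(attr) == num_transformers:     # depth 1: one entry per transformer
        return [[a] for a in attr]
    return [attr]                         # depth 1: all entries, one transformer

print(normalize_attr("blocks", 2))    # [['blocks']]
print(normalize_attr(["a", "b"], 2))  # [['a'], ['b']]
print(normalize_attr(["a", "b"], 1))  # [['a', 'b']]
print(normalize_attr([], 2))          # [[], []]
```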
```diff
@@ -479,10 +519,10 @@ class BlockAdapter:
                     )
                 ]
             )
+        else:
+            assert len(adapter.transformer) == len(adapter.unique_blocks_name)
 
-
-
-        # Match Forward Pattern
+        # Also check Match Forward Pattern
         for i in range(len(adapter.transformer)):
             for forward_pattern, blocks in zip(
                 adapter.forward_pattern[i], adapter.blocks[i]
```
```diff
@@ -496,10 +536,6 @@ class BlockAdapter:
                         f"supported lists: {ForwardPattern.supported_patterns()}"
                     )
 
-        adapter._is_normalized = True
-
-        return adapter
-
     @classmethod
     def assert_normalized(cls, adapter: "BlockAdapter"):
         if not getattr(adapter, "_is_normalized", False):
```
```diff
@@ -524,15 +560,49 @@ class BlockAdapter:
             assert isinstance(adapter[0], torch.nn.Module)
             return getattr(adapter[0], "_is_cached", False)
         else:
-            raise TypeError(f"Can't check this type: {adapter}!")
+            raise TypeError(f"Can't check this type: {type(adapter)}!")
 
     @classmethod
-    def
-
-
-
-
-
-
-
+    def nested_depth(cls, obj: Any):
+        # str: 0; List[str]: 1; List[List[str]]: 2
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if isinstance(obj, atom_types):
+            return 0
+        if not isinstance(obj, Iterable):
+            return 0
+        if isinstance(obj, dict):
+            items = obj.values()
+        else:
+            items = obj
+
+        max_depth = 0
+        for item in items:
+            current_depth = cls.nested_depth(item)
+            if current_depth > max_depth:
+                max_depth = current_depth
+        return 1 + max_depth
+
+    @classmethod
+    def flatten(cls, attr: List[Any]) -> List[Any]:
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if not isinstance(attr, list):
+            return attr
+        flattened = []
+        for item in attr:
+            if isinstance(item, list) and not isinstance(item, atom_types):
+                flattened.extend(cls.flatten(item))
+            else:
+                flattened.append(item)
+        return flattened
```
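Assuming cache-dit 0.2.30 is installed, the conventions of the two new helpers can be verified directly: atoms (`str`, `bytes`, `Module`, `ModuleList`, `Tensor`) count as depth 0 and each list level adds one, while `flatten` recursively splices nested lists:

```python
import torch
from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

blocks = torch.nn.ModuleList([torch.nn.Identity()])
print(BlockAdapter.nested_depth("transformer_blocks"))  # 0 (str is an atom)
print(BlockAdapter.nested_depth(["a", "b"]))            # 1
print(BlockAdapter.nested_depth([[blocks]]))            # 2 (ModuleList is an atom)
print(BlockAdapter.flatten([["a"], ["b", ["c"]]]))      # ['a', 'b', 'c']
```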
cache_dit/cache_factory/block_adapters/block_registers.py

```diff
@@ -1,4 +1,4 @@
-from typing import Any, Tuple, List, Dict
+from typing import Any, Tuple, List, Dict, Callable
 
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
```
```diff
@@ -9,20 +9,23 @@ logger = init_logger(__name__)
 
 
 class BlockAdapterRegistry:
-    _adapters: Dict[str, BlockAdapter] = {}
-    _predefined_adapters_has_spearate_cfg: List[str] =
+    _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
+    _predefined_adapters_has_spearate_cfg: List[str] = [
         "QwenImage",
         "Wan",
         "CogView4",
         "Cosmos",
         "SkyReelsV2",
         "Chroma",
-
+    ]
 
     @classmethod
-    def register(cls, name):
-        def decorator(
-
+    def register(cls, name: str, supported: bool = True):
+        def decorator(
+            func: Callable[..., BlockAdapter]
+        ) -> Callable[..., BlockAdapter]:
+            if supported:
+                cls._adapters[name] = func
             return func
 
         return decorator
```
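Each registry entry is now a factory (`Callable[..., BlockAdapter]`) rather than an adapter instance, and the `supported` switch lets a pipeline keep its decorated factory while staying out of `_adapters`. A hypothetical registration with illustrative names, assuming cache-dit 0.2.30:

```python
from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
from cache_dit.cache_factory.block_adapters.block_registers import (
    BlockAdapterRegistry,
)

# With supported=False the decorator still runs, but nothing is registered.
@BlockAdapterRegistry.register("MyPipeline", supported=False)
def my_pipeline_adapter(pipe, **kwargs) -> BlockAdapter:
    return BlockAdapter(pipe=pipe, skip_post_init=True, **kwargs)
```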
```diff
@@ -47,15 +50,24 @@ class BlockAdapterRegistry:
     @classmethod
     def has_separate_cfg(
         cls,
-
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     ) -> bool:
-
-
-
-
+
+        # Prefer custom setting from block adapter.
+        if isinstance(pipe_or_adapter, BlockAdapter):
+            return pipe_or_adapter.has_separate_cfg
+
+        has_separate_cfg = False
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            has_separate_cfg = cls.get_adapter(
+                pipe_or_adapter,
+                skip_post_init=True, # check cfg setting only
+            ).has_separate_cfg
+
+        if has_separate_cfg:
             return True
 
-        pipe_cls_name =
+        pipe_cls_name = pipe_or_adapter.__class__.__name__
         for name in cls._predefined_adapters_has_spearate_cfg:
             if pipe_cls_name.startswith(name):
                 return True
```