cache-dit 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release. This version of cache-dit might be problematic.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +87 -110
- cache_dit/cache_factory/block_adapters/block_adapters.py +190 -122
- cache_dit/cache_factory/block_adapters/block_registers.py +15 -6
- cache_dit/cache_factory/cache_adapters.py +38 -25
- cache_dit/cache_factory/cache_blocks/utils.py +16 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +14 -0
- cache_dit/cache_factory/cache_interface.py +81 -11
- cache_dit/cache_factory/patch_functors/functor_chroma.py +1 -1
- cache_dit/cache_factory/patch_functors/functor_flux.py +1 -1
- cache_dit/utils.py +164 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.29.dist-info}/METADATA +33 -14
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.29.dist-info}/RECORD +19 -19
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.29.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.29.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.29.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.29.dist-info}/top_level.txt +0 -0
--- a/cache_dit/cache_factory/block_adapters/block_adapters.py
+++ b/cache_dit/cache_factory/block_adapters/block_adapters.py
@@ -2,6 +2,7 @@ import torch
 
 import inspect
 import dataclasses
+from collections.abc import Iterable
 
 from typing import Any, Tuple, List, Optional, Union
 
@@ -82,9 +83,6 @@ class BlockAdapter:
     # Flags for separate cfg
     has_separate_cfg: bool = False
 
-    # Other Flags
-    disable_patch: bool = False
-
     # Flags to control auto block adapter
     # NOTE: NOT support for multi-transformers.
     auto: bool = False
@@ -107,15 +105,92 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
+    # Other Flags
+    skip_post_init: bool = False
+
     def __post_init__(self):
+        if self.skip_post_init:
+            return
         assert any((self.pipe is not None, self.transformer is not None))
-        self.
+        self.maybe_fill_attrs()
+        self.maybe_patchify()
+
+    def maybe_fill_attrs(self):
+        # NOTE: This func should be call before normalize.
+        # Allow empty `blocks_names`, we will auto fill it.
+        # TODO: preprocess more empty attrs.
+        if (
+            self.transformer is not None
+            and self.blocks is not None
+            and self.blocks_name is None
+        ):
 
-
+            def _find(transformer, blocks):
+                attr_names = dir(transformer)
+                assert isinstance(blocks, torch.nn.ModuleList)
+                blocks_name = None
+                for attr_name in attr_names:
+                    if attr := getattr(transformer, attr_name, None):
+                        if isinstance(attr, torch.nn.ModuleList) and id(
+                            attr
+                        ) == id(blocks):
+                            blocks_name = attr_name
+                            break
+                assert (
+                    blocks_name is not None
+                ), "No blocks_name match, please set it manually!"
+                return blocks_name
+
+            if self.nested_depth(self.transformer) == 0:
+                if self.nested_depth(self.blocks) == 0:  # str
+                    self.blocks_name = _find(self.transformer, self.blocks)
+                elif self.nested_depth(self.blocks) == 1:
+                    self.blocks_name = [
+                        _find(self.transformer, blocks)
+                        for blocks in self.blocks
+                    ]
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can't more than 1 if transformer "
+                        f"is not a list, current is: {self.nested_depth(self.blocks)}"
+                    )
+            elif self.nested_depth(self.transformer) == 1:  # List[str]
+                if self.nested_depth(self.blocks) == 1:  # List[str]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = [
+                        _find(transformer, blocks)
+                        for transformer, blocks in zip(
+                            self.transformer, self.blocks
+                        )
+                    ]
+                elif self.nested_depth(self.blocks) == 2:  # List[List[str]]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = []
+                    for i in range(len(self.blocks)):
+                        self.blocks_name.append(
+                            [
+                                _find(self.transformer[i], blocks)
+                                for blocks in self.blocks[i]
+                            ]
+                        )
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can only be 1 or 2 "
+                        "if transformer is a list, current is: "
+                        f"{self.nested_depth(self.blocks)}"
+                    )
+            else:
+                raise ValueError(
+                    "transformer nested depth can't more than 1, "
+                    f"current is: {self.nested_depth(self.transformer)}"
+                )
+            logger.info(f"Auto fill blocks_name: {self.blocks_name}.")
+
+    def maybe_patchify(self, *args, **kwargs):
         # Process some specificial cases, specific for transformers
         # that has different forward patterns between single_transformer_blocks
         # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None
+        if self.patch_functor is not None:
             if self.transformer is not None:
                 self.patch_functor.apply(self.transformer, *args, **kwargs)
             else:
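The new `maybe_fill_attrs()` hook makes `blocks_name` optional: when both `transformer` and `blocks` are set, `__post_init__` resolves the name by `id()`-matching the given `ModuleList` against the transformer's attributes. A minimal sketch of the idea; the import path and the `ForwardPattern` value are assumptions, not taken from this diff:

```python
import torch
from cache_dit.cache_factory import BlockAdapter, ForwardPattern  # import path assumed

# Toy stand-in for a diffusers transformer; a real pipe.transformer
# would be used in practice.
class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList(
            [torch.nn.Linear(8, 8) for _ in range(2)]
        )

transformer = ToyTransformer()
adapter = BlockAdapter(
    transformer=transformer,
    blocks=transformer.transformer_blocks,
    # blocks_name omitted: maybe_fill_attrs() resolves it to
    # "transformer_blocks" by id()-matching the ModuleList.
    forward_pattern=ForwardPattern.Pattern_0,  # pattern value assumed
)
```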
@@ -141,7 +216,7 @@ class BlockAdapter:
         transformer = pipe.transformer
 
         # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.
+        blocks, blocks_name = BlockAdapter.find_match_blocks(
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
@@ -164,6 +239,10 @@ class BlockAdapter:
     def check_block_adapter(
         adapter: "BlockAdapter",
     ) -> bool:
+
+        if getattr(adapter, "_is_normlized", False):
+            return True
+
         def _check_warning(attr: str):
             if getattr(adapter, attr, None) is None:
                 logger.warning(f"{attr} is None!")
@@ -185,24 +264,23 @@ class BlockAdapter:
         if not _check_warning("forward_pattern"):
             return False
 
-        if
-
-            if not isinstance(blocks, torch.nn.ModuleList):
-                logger.warning(f"blocks[{i}] is not ModuleList.")
-                return False
+        if BlockAdapter.nested_depth(adapter.blocks) == 0:
+            blocks = adapter.blocks
         else:
-
-
-
+            blocks = BlockAdapter.flatten(adapter.blocks)[0]
+
+        if not isinstance(blocks, torch.nn.ModuleList):
+            logger.warning("blocks is not ModuleList.")
+            return False
 
         return True
 
     @staticmethod
-    def
+    def find_match_blocks(
         transformer: torch.nn.Module,
         allow_prefixes: List[str] = [
-            "
-            "
+            "transformer_blocks",
+            "single_transformer_blocks",
             "blocks",
             "layers",
             "single_stream_blocks",
@@ -230,10 +308,10 @@ class BlockAdapter:
         valid_count = []
         forward_pattern = kwargs.pop("forward_pattern", None)
         for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
+            if (blocks := getattr(transformer, blocks_name, None)) is not None:
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
-                    block_cls_name = block.__class__.__name__
+                    block_cls_name: str = block.__class__.__name__
                     # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
@@ -293,6 +371,18 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+    ) -> List[torch.nn.ModuleList]:
+        total_blocks = []
+        for attr in dir(transformer):
+            if (blocks := getattr(transformer, attr, None)) is not None:
+                if isinstance(blocks, torch.nn.ModuleList):
+                    if isinstance(blocks[0], torch.nn.Module):
+                        total_blocks.append(blocks)
+        return total_blocks
+
     @staticmethod
     def match_block_pattern(
         block: torch.nn.Module,
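`find_blocks()` complements `find_match_blocks()`: rather than filtering by name prefix and forward pattern, it simply collects every `ModuleList` attribute whose first entry is an `nn.Module`. A small sketch; the toy module and the import path are illustrative assumptions:

```python
import torch
from cache_dit.cache_factory import BlockAdapter  # import path assumed

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Two candidate stacks; both are plain ModuleLists of Modules.
        self.layers = torch.nn.ModuleList([torch.nn.Identity() for _ in range(3)])
        self.blocks = torch.nn.ModuleList([torch.nn.Identity() for _ in range(2)])

found = BlockAdapter.find_blocks(TinyModel())
assert len(found) == 2  # both ModuleLists are collected
```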
@@ -373,103 +463,51 @@ class BlockAdapter:
         if getattr(adapter, "_is_normalized", False):
             return adapter
 
-        if
+        if BlockAdapter.nested_depth(adapter.transformer) == 0:
             adapter.transformer = [adapter.transformer]
 
-
-
-
-
-
-
-
-
-
-
-
-        # [[blocks_0, blocks_1],[blocks_2, blocks_3],] -> match [TRN_0, TRN_1,]
-        pass
-
-        if isinstance(adapter.blocks_name, str):
-            adapter.blocks_name = [[adapter.blocks_name]]
-        elif isinstance(adapter.blocks_name, list):
-            if isinstance(adapter.blocks_name[0], str):
-                if len(adapter.blocks_name) == len(adapter.transformer):
-                    adapter.blocks_name = [
-                        [blocks_name] for blocks_name in adapter.blocks_name
-                    ]
-                else:
-                    adapter.blocks_name = [adapter.blocks_name]
-            elif isinstance(adapter.blocks_name[0], list):
-                pass
-
-        if isinstance(adapter.forward_pattern, ForwardPattern):
-            adapter.forward_pattern = [[adapter.forward_pattern]]
-        elif isinstance(adapter.forward_pattern, list):
-            if isinstance(adapter.forward_pattern[0], ForwardPattern):
-                if len(adapter.forward_pattern) == len(adapter.transformer):
-                    adapter.forward_pattern = [
-                        [forward_pattern]
-                        for forward_pattern in adapter.forward_pattern
-                    ]
-                else:
-                    adapter.forward_pattern = [adapter.forward_pattern]
-            elif isinstance(adapter.forward_pattern[0], list):
-                pass
-
-        if isinstance(adapter.dummy_blocks_names, list):
-            if len(adapter.dummy_blocks_names) > 0:
-                if isinstance(adapter.dummy_blocks_names[0], str):
-                    if len(adapter.dummy_blocks_names) == len(
-                        adapter.transformer
-                    ):
-                        adapter.dummy_blocks_names = [
-                            [dummy_blocks_names]
-                            for dummy_blocks_names in adapter.dummy_blocks_names
-                        ]
-                    else:
-                        adapter.dummy_blocks_names = [
-                            adapter.dummy_blocks_names
-                        ]
-                elif isinstance(adapter.dummy_blocks_names[0], list):
-                    pass
-            else:
-                # Empty dummy_blocks_names
-                adapter.dummy_blocks_names = [
-                    [] for _ in range(len(adapter.transformer))
-                ]
-
-        if adapter.params_modifiers is not None:
-            if isinstance(adapter.params_modifiers, ParamsModifier):
-                adapter.params_modifiers = [[adapter.params_modifiers]]
-            elif isinstance(adapter.params_modifiers, list):
-                if isinstance(adapter.params_modifiers[0], ParamsModifier):
-                    if len(adapter.params_modifiers) == len(
-                        adapter.transformer
-                    ):
-                        adapter.params_modifiers = [
-                            [params_modifiers]
-                            for params_modifiers in adapter.params_modifiers
-                        ]
+        def _normalize_attr(attr: Any):
+            normalized_attr = attr
+            if attr is None:
+                return normalized_attr
+
+            if BlockAdapter.nested_depth(attr) == 0:
+                normalized_attr = [[attr]]
+            elif BlockAdapter.nested_depth(attr) == 1:  # List
+                if attr:  # not-empty
+                    if len(attr) == len(adapter.transformer):
+                        normalized_attr = [[a] for a in attr]
                     else:
-
-
-
+                        normalized_attr = [attr]
+                else:  # [] empty
+                    normalized_attr = [
+                        [] for _ in range(len(adapter.transformer))
+                    ]
+
+            assert len(adapter.transformer) == len(normalized_attr)
+            return normalized_attr
+
+        adapter.blocks = _normalize_attr(adapter.blocks)
+        adapter.blocks_name = _normalize_attr(adapter.blocks_name)
+        adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
+        adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
+        adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
+        BlockAdapter.unique(adapter)
 
-
-
-
-        assert len(adapter.transformer) == len(adapter.dummy_blocks_names)
-        if adapter.params_modifiers is not None:
-            assert len(adapter.transformer) == len(adapter.params_modifiers)
+        adapter._is_normalized = True
+
+        return adapter
 
+    @classmethod
+    def unique(cls, adapter: "BlockAdapter"):
+        # NOTE: Users should never call this function
         for i in range(len(adapter.blocks)):
             assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
             assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
 
+        # Generate unique blocks names
         if len(adapter.unique_blocks_name) == 0:
             for i in range(len(adapter.transformer)):
-                # Generate unique blocks names
                 adapter.unique_blocks_name.append(
                     [
                         f"{name}_{hash(id(blocks))}"
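Four hand-rolled `isinstance` ladders collapse into one `_normalize_attr()` closure. The invariant afterwards: every attribute is `List[List[...]]`, with the outer list parallel to `adapter.transformer`. In sketch form; the enclosing classmethod's name is not visible in this hunk, so `normalize` is assumed:

```python
# depth 0: value                     -> [[value]]
# depth 1, len == len(transformer):  -> [[a], [b], ...]  one inner list per transformer
# depth 1, otherwise:                -> [[a, b, ...]]    all for a single transformer
# depth 1, empty []:                 -> [[], [], ...]    one empty list per transformer
adapter = BlockAdapter.normalize(adapter)  # classmethod name assumed
assert adapter._is_normalized
```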
@@ -479,10 +517,10 @@ class BlockAdapter:
                         )
                     ]
                 )
+        else:
+            assert len(adapter.transformer) == len(adapter.unique_blocks_name)
 
-
-
-        # Match Forward Pattern
+        # Also check Match Forward Pattern
         for i in range(len(adapter.transformer)):
             for forward_pattern, blocks in zip(
                 adapter.forward_pattern[i], adapter.blocks[i]
@@ -496,10 +534,6 @@ class BlockAdapter:
                         f"supported lists: {ForwardPattern.supported_patterns()}"
                     )
 
-        adapter._is_normalized = True
-
-        return adapter
-
     @classmethod
     def assert_normalized(cls, adapter: "BlockAdapter"):
         if not getattr(adapter, "_is_normalized", False):
@@ -527,12 +561,46 @@ class BlockAdapter:
             raise TypeError(f"Can't check this type: {adapter}!")
 
     @classmethod
-    def
-
-
-
-
-
-
-
+    def nested_depth(cls, obj: Any):
+        # str: 0; List[str]: 1; List[List[str]]: 2
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if isinstance(obj, atom_types):
+            return 0
+        if not isinstance(obj, Iterable):
+            return 0
+        if isinstance(obj, dict):
+            items = obj.values()
+        else:
+            items = obj
+
+        max_depth = 0
+        for item in items:
+            current_depth = cls.nested_depth(item)
+            if current_depth > max_depth:
+                max_depth = current_depth
+        return 1 + max_depth
+
+    @classmethod
+    def flatten(cls, attr: List[Any]) -> List[Any]:
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if not isinstance(attr, list):
+            return attr
+        flattened = []
+        for item in attr:
+            if isinstance(item, list) and not isinstance(item, atom_types):
+                flattened.extend(cls.flatten(item))
+            else:
+                flattened.append(item)
+        return flattened
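The previously elided helpers are now concrete: `nested_depth()` treats strings, bytes, tensors, `Module`s, and `ModuleList`s as atoms, and `flatten()` recurses into plain lists only. Their behavior follows directly from the code above (import path assumed, as before):

```python
from cache_dit.cache_factory import BlockAdapter  # import path assumed

assert BlockAdapter.nested_depth("transformer_blocks") == 0  # str is atomic
assert BlockAdapter.nested_depth(["a", "b"]) == 1            # List[str]
assert BlockAdapter.nested_depth([["a"], ["b"]]) == 2        # List[List[str]]

# flatten() preserves order and stops at atoms:
assert BlockAdapter.flatten([["a"], ["b", "c"]]) == ["a", "b", "c"]
```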
--- a/cache_dit/cache_factory/block_adapters/block_registers.py
+++ b/cache_dit/cache_factory/block_adapters/block_registers.py
@@ -47,15 +47,24 @@ class BlockAdapterRegistry:
     @classmethod
     def has_separate_cfg(
         cls,
-
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     ) -> bool:
-
-
-
-
+
+        # Prefer custom setting from block adapter.
+        if isinstance(pipe_or_adapter, BlockAdapter):
+            return pipe_or_adapter.has_separate_cfg
+
+        has_separate_cfg = False
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            has_separate_cfg = cls.get_adapter(
+                pipe_or_adapter,
+                skip_post_init=True,  # check cfg setting only
+            ).has_separate_cfg
+
+        if has_separate_cfg:
             return True
 
-        pipe_cls_name =
+        pipe_cls_name = pipe_or_adapter.__class__.__name__
         for name in cls._predefined_adapters_has_spearate_cfg:
             if pipe_cls_name.startswith(name):
                 return True
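`has_separate_cfg()` now takes either object. A `BlockAdapter` answers from its own flag; a `DiffusionPipeline` goes through `get_adapter(..., skip_post_init=True)`, which is what the new `skip_post_init` dataclass flag exists for, before falling back to the class-name prefix list. Usage sketch, where `pipe` and `adapter` are pre-existing objects and the import path is assumed:

```python
from cache_dit.cache_factory import BlockAdapterRegistry  # import path assumed

# Custom adapters win outright:
BlockAdapterRegistry.has_separate_cfg(adapter)  # -> adapter.has_separate_cfg

# Pipelines consult the pre-defined adapter without running __post_init__,
# then the _predefined_adapters_has_spearate_cfg prefix list:
BlockAdapterRegistry.has_separate_cfg(pipe)
```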
--- a/cache_dit/cache_factory/cache_adapters.py
+++ b/cache_dit/cache_factory/cache_adapters.py
@@ -29,36 +29,39 @@ class CachedAdapter:
     @classmethod
     def apply(
         cls,
-
-        block_adapter: BlockAdapter = None,
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter,
         **cache_context_kwargs,
-    ) ->
+    ) -> BlockAdapter:
         assert (
-
+            pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"
 
-        if
-        if BlockAdapterRegistry.is_supported(
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
-                    f"{
-                    "Use it's pre-defined BlockAdapter
+                    f"{pipe_or_adapter.__class__.__name__} is officially "
+                    "supported by cache-dit. Use it's pre-defined BlockAdapter "
+                    "directly!"
+                )
+                block_adapter = BlockAdapterRegistry.get_adapter(
+                    pipe_or_adapter
                 )
-                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
                     block_adapter,
                     **cache_context_kwargs,
                 )
             else:
                 raise ValueError(
-                    f"{
+                    f"{pipe_or_adapter.__class__.__name__} is not officially supported "
                     "by cache-dit, please set BlockAdapter instead!"
                 )
         else:
+            assert isinstance(pipe_or_adapter, BlockAdapter)
             logger.info(
-                "Adapting
+                "Adapting Cache Acceleration using custom BlockAdapter!"
             )
             return cls.cachify(
-
+                pipe_or_adapter,
                 **cache_context_kwargs,
             )
 
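`apply()` replaces the `pipe=None, block_adapter=None` pair with a single `pipe_or_adapter` argument and now always returns the resolved `BlockAdapter`. Both call styles, sketched; the model id is illustrative only and the import path is assumed:

```python
import torch
from diffusers import DiffusionPipeline
from cache_dit.cache_factory.cache_adapters import CachedAdapter  # export assumed

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative; any registered pipeline
    torch_dtype=torch.bfloat16,
)

# Registered pipelines resolve to their pre-defined BlockAdapter:
block_adapter = CachedAdapter.apply(pipe)

# Unregistered pipelines raise ValueError; pass a custom BlockAdapter instead:
# block_adapter = CachedAdapter.apply(my_custom_block_adapter)
```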
@@ -67,7 +70,7 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
-    ) ->
+    ) -> BlockAdapter:
 
         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
@@ -93,7 +96,7 @@ class CachedAdapter:
             block_adapter,
         )
 
-        return block_adapter
+        return block_adapter
 
     @classmethod
     def patch_params(
@@ -126,18 +129,29 @@ class CachedAdapter:
             params_shift += len(blocks)
 
     @classmethod
-    def check_context_kwargs(
+    def check_context_kwargs(
+        cls,
+        block_adapter: BlockAdapter,
+        **cache_context_kwargs,
+    ):
         # Check cache_context_kwargs
         if not cache_context_kwargs["enable_spearate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-
-
-
-
-
-
-
-
+            if BlockAdapterRegistry.has_separate_cfg(block_adapter):
+                cache_context_kwargs["enable_spearate_cfg"] = True
+                logger.info(
+                    f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
+            else:
+                cache_context_kwargs["enable_spearate_cfg"] = (
+                    BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
+                )
+                logger.info(
+                    f"Use default 'enable_spearate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
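`check_context_kwargs()` now receives the whole `BlockAdapter`, so an adapter whose `has_separate_cfg` flag is set overrides the registry default whenever the caller did not enable `enable_spearate_cfg` (the kwarg keeps the library's spelling) themselves. Sketch of how the kwarg flows in from `apply()`; treating it as user-passable is an assumption based on the `**cache_context_kwargs` plumbing shown here:

```python
# Force separate-CFG caching explicitly; if it is False, the BlockAdapter
# flag and then the registry default decide instead.
CachedAdapter.apply(pipe, enable_spearate_cfg=True)  # kwarg name as spelled in cache-dit
```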
@@ -160,8 +174,7 @@ class CachedAdapter:
 
         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            block_adapter
-            **cache_context_kwargs,
+            block_adapter, **cache_context_kwargs
         )
         # Apply cache on pipeline: wrap cache context
         pipe_cls_name = block_adapter.pipe.__class__.__name__
--- a/cache_dit/cache_factory/cache_blocks/utils.py
+++ b/cache_dit/cache_factory/cache_blocks/utils.py
@@ -23,3 +23,19 @@ def patch_cached_stats(
     module._residual_diffs = cache_manager.get_residual_diffs()
     module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
     module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
+
+
+def remove_cached_stats(
+    module: torch.nn.Module | Any,
+):
+    if module is None:
+        return
+
+    if hasattr(module, "_cached_steps"):
+        del module._cached_steps
+    if hasattr(module, "_residual_diffs"):
+        del module._residual_diffs
+    if hasattr(module, "_cfg_cached_steps"):
+        del module._cfg_cached_steps
+    if hasattr(module, "_cfg_residual_diffs"):
+        del module._cfg_residual_diffs
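`remove_cached_stats()` is the inverse of `patch_cached_stats()`: it deletes whichever of the four statistics attributes are present and tolerates `None`. For instance, to reset a previously profiled transformer between runs:

```python
from cache_dit.cache_factory.cache_blocks.utils import remove_cached_stats  # export assumed

# `pipe.transformer` is assumed to be a module that patch_cached_stats()
# annotated during an earlier cached run.
remove_cached_stats(pipe.transformer)
remove_cached_stats(None)  # safe no-op via the early return
```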
--- a/cache_dit/cache_factory/cache_contexts/cache_manager.py
+++ b/cache_dit/cache_factory/cache_contexts/cache_manager.py
@@ -63,6 +63,20 @@ class CachedContextManager:
             _context = self.new_context(*args, **kwargs)
         return _context
 
+    def remove_context(self, cached_context: CachedContext | str):
+        if isinstance(cached_context, CachedContext):
+            cached_context.clear_buffers()
+            if cached_context.name in self._cached_context_manager:
+                del self._cached_context_manager[cached_context.name]
+        else:
+            if cached_context in self._cached_context_manager:
+                self._cached_context_manager[cached_context].clear_buffers()
+                del self._cached_context_manager[cached_context]
+
+    def clear_contexts(self):
+        for cached_context in self._cached_context_manager:
+            self.remove_context(cached_context)
+
     @contextlib.contextmanager
     def enter_context(self, cached_context: CachedContext | str):
         old_cached_context = self._current_context
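`CachedContextManager` gains explicit teardown: `remove_context()` accepts a `CachedContext` or a registered name, clears the context's buffers, and drops it from the manager's dict, while `clear_contexts()` applies this to every registered context. A hedged sketch; the constructor and `new_context()` arguments are assumptions:

```python
from cache_dit.cache_factory.cache_contexts.cache_manager import (
    CachedContextManager,  # export assumed
)

manager = CachedContextManager()  # constructor arguments assumed
ctx = manager.new_context()       # new_context() signature assumed

manager.remove_context(ctx)           # by object: clear_buffers(), then drop by ctx.name
manager.remove_context("stale_name")  # by name: dropped only if actually registered
```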