cache-dit 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +2 -0
- cache_dit/cache_factory/block_adapters/__init__.py +22 -5
- cache_dit/cache_factory/block_adapters/block_adapters.py +230 -25
- cache_dit/cache_factory/cache_adapters.py +209 -94
- cache_dit/cache_factory/cache_blocks/__init__.py +55 -4
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +36 -37
- cache_dit/cache_factory/cache_blocks/pattern_base.py +83 -76
- cache_dit/cache_factory/cache_blocks/utils.py +10 -8
- cache_dit/cache_factory/cache_contexts/__init__.py +4 -1
- cache_dit/cache_factory/cache_contexts/cache_context.py +14 -876
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +10 -13
- cache_dit/utils.py +7 -10
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/METADATA +30 -24
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/RECORD +21 -21
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
```diff
@@ -10,6 +10,7 @@ from cache_dit.cache_factory import cache_type
 from cache_dit.cache_factory import block_range
 from cache_dit.cache_factory import CacheType
 from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory import PatchFunctor
 from cache_dit.cache_factory import supported_pipelines
```
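With this addition, `ParamsModifier` becomes importable from the package root alongside the existing public symbols. A minimal sketch of the new top-level import surface, assuming cache-dit 0.2.28 is installed:

```python
# ParamsModifier now sits next to BlockAdapter and ForwardPattern
# in the package root (see the import added above).
from cache_dit import BlockAdapter, ForwardPattern, ParamsModifier
```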
cache_dit/_version.py
CHANGED
```diff
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.27'
-__version_tuple__ = version_tuple = (0, 2, 27)
+__version__ = version = '0.2.28'
+__version_tuple__ = version_tuple = (0, 2, 28)
 
 __commit_id__ = commit_id = None
```
cache_dit/cache_factory/__init__.py
CHANGED
```diff
@@ -7,9 +7,11 @@ from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.patch_functors import PatchFunctor
 
 from cache_dit.cache_factory.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 
 from cache_dit.cache_factory.cache_contexts import CachedContext
+from cache_dit.cache_factory.cache_contexts import CachedContextManager
 from cache_dit.cache_factory.cache_blocks import CachedBlocks
 
 from cache_dit.cache_factory.cache_adapters import CachedAdapter
```
cache_dit/cache_factory/block_adapters/__init__.py
CHANGED
```diff
@@ -1,5 +1,6 @@
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters.block_registers import (
     BlockAdapterRegistry,
 )
```
```diff
@@ -69,14 +70,30 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
         (WanTransformer3DModel, WanVACETransformer3DModel),
     )
     if getattr(pipe, "transformer_2", None):
-
+        assert isinstance(
+            pipe.transformer_2,
+            (WanTransformer3DModel, WanVACETransformer3DModel),
+        )
+        # Wan 2.2 MoE
         return BlockAdapter(
             pipe=pipe,
-            transformer=
-
-
+            transformer=[
+                pipe.transformer,
+                pipe.transformer_2,
+            ],
+            blocks=[
+                pipe.transformer.blocks,
+                pipe.transformer_2.blocks,
+            ],
+            blocks_name=[
+                "blocks",
+                "blocks",
+            ],
+            forward_pattern=[
+                ForwardPattern.Pattern_2,
+                ForwardPattern.Pattern_2,
+            ],
             dummy_blocks_names=[],
-            forward_pattern=ForwardPattern.Pattern_2,
             has_separate_cfg=True,
         )
     else:
```
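The rewritten `wan_adapter` now passes one entry per transformer as parallel lists. The sketch below mirrors that construction for a manually declared adapter; the `WanPipeline` class, the checkpoint id, and the `transformer`/`transformer_2` attribute layout are assumptions about the diffusers Wan 2.2 MoE pipeline, not something this diff confirms:

```python
import torch
from diffusers import WanPipeline  # assumed pipeline class for Wan 2.2
from cache_dit import BlockAdapter, ForwardPattern

# Hypothetical checkpoint id; any Wan 2.2 MoE checkpoint that exposes both
# `transformer` and `transformer_2` would follow the same shape.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16
)

# Mirrors wan_adapter above: parallel lists, one entry per transformer.
adapter = BlockAdapter(
    pipe=pipe,
    transformer=[pipe.transformer, pipe.transformer_2],
    blocks=[pipe.transformer.blocks, pipe.transformer_2.blocks],
    blocks_name=["blocks", "blocks"],
    forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
    has_separate_cfg=True,
)
```

Such an adapter is then consumed by the caching entry point (`cache_interface.py`, also touched in this release).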
cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED
```diff
@@ -3,7 +3,7 @@ import torch
 import inspect
 import dataclasses
 
-from typing import Any, Tuple, List, Optional
+from typing import Any, Tuple, List, Optional, Union
 
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
```
```diff
@@ -14,21 +14,79 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+class ParamsModifier:
+    def __init__(self, **kwargs):
+        self._context_kwargs = kwargs.copy()
+
+
 @dataclasses.dataclass
 class BlockAdapter:
+
     # Transformer configurations.
-    pipe:
-
+    pipe: Union[
+        DiffusionPipeline,
+        Any,
+    ] = None
+
+    # single transformer (most cases) or list of transformers (Wan2.2, etc)
+    transformer: Union[
+        torch.nn.Module,
+        List[torch.nn.Module],
+    ] = None
+
+    # Block Level Flags
+    # Each transformer contains a list of blocks-list,
+    # blocks_name-list, dummy_blocks_names-list, etc.
+    blocks: Union[
+        torch.nn.ModuleList,
+        List[torch.nn.ModuleList],
+        List[List[torch.nn.ModuleList]],
+    ] = None
 
-    # ------------ Block Level Flags ------------
-    blocks: torch.nn.ModuleList | List[torch.nn.ModuleList] = None
     # transformer_blocks, blocks, etc.
-    blocks_name:
-
-
+    blocks_name: Union[
+        str,
+        List[str],
+        List[List[str]],
+    ] = None
+
+    unique_blocks_name: Union[
+        str,
+        List[str],
+        List[List[str]],
+    ] = dataclasses.field(default_factory=list)
+
+    dummy_blocks_names: Union[
+        List[str],
+        List[List[str]],
+    ] = dataclasses.field(default_factory=list)
+
+    forward_pattern: Union[
+        ForwardPattern,
+        List[ForwardPattern],
+        List[List[ForwardPattern]],
+    ] = None
+
+    # modify cache context params for specific blocks.
+    params_modifiers: Union[
+        ParamsModifier,
+        List[ParamsModifier],
+        List[List[ParamsModifier]],
+    ] = None
+
     check_num_outputs: bool = True
 
+    # Pipeline Level Flags
+    # Patch Functor: Flux, etc.
+    patch_functor: Optional[PatchFunctor] = None
+    # Flags for separate cfg
+    has_separate_cfg: bool = False
+
+    # Other Flags
+    disable_patch: bool = False
+
     # Flags to control auto block adapter
+    # NOTE: NOT support for multi-transformers.
     auto: bool = False
     allow_prefixes: List[str] = dataclasses.field(
         default_factory=lambda: [
```
```diff
@@ -49,15 +107,6 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
-    # NOTE: Other flags.
-    disable_patch: bool = False
-
-    # ------------ Pipeline Level Flags ------------
-    # Patch Functor: Flux, etc.
-    patch_functor: Optional[PatchFunctor] = None
-    # Flags for separate cfg
-    has_separate_cfg: bool = False
-
     def __post_init__(self):
         assert any((self.pipe is not None, self.transformer is not None))
         self.patchify()
```
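Taken together, these two hunks move the pipeline-level flags up into the main field block, widen every field to nested `Union` types for multi-transformer pipelines, and introduce `ParamsModifier`, which simply snapshots keyword overrides for a cache context. A minimal sketch of that behavior; the parameter name used here is illustrative and not checked against this release:

```python
from cache_dit import ParamsModifier

# ParamsModifier.__init__ just copies its kwargs into _context_kwargs
# (see the class definition above). `residual_diff_threshold` is an
# illustrative cache-context option, not verified for 0.2.28.
aggressive = ParamsModifier(residual_diff_threshold=0.12)
conservative = ParamsModifier(residual_diff_threshold=0.03)

print(aggressive._context_kwargs)    # {'residual_diff_threshold': 0.12}
print(conservative._context_kwargs)  # {'residual_diff_threshold': 0.03}
```

Via the new `params_modifiers` field, one modifier (or list of modifiers) can be supplied per transformer, following the same nesting as `blocks` and `forward_pattern`.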
```diff
@@ -320,14 +369,170 @@ class BlockAdapter:
     def normalize(
         adapter: "BlockAdapter",
     ) -> "BlockAdapter":
-        if not isinstance(adapter.blocks, list):
-            adapter.blocks = [adapter.blocks]
-        if not isinstance(adapter.blocks_name, list):
-            adapter.blocks_name = [adapter.blocks_name]
-        if not isinstance(adapter.forward_pattern, list):
-            adapter.forward_pattern = [adapter.forward_pattern]
 
-
-
+        if getattr(adapter, "_is_normalized", False):
+            return adapter
+
+        if not isinstance(adapter.transformer, list):
+            adapter.transformer = [adapter.transformer]
+
+        if isinstance(adapter.blocks, torch.nn.ModuleList):
+            # blocks_0 = [[blocks_0,],] -> match [TRN_0,]
+            adapter.blocks = [[adapter.blocks]]
+        elif isinstance(adapter.blocks, list):
+            if isinstance(adapter.blocks[0], torch.nn.ModuleList):
+                # [blocks_0, blocks_1] -> [[blocks_0, blocks_1],] -> match [TRN_0,]
+                if len(adapter.blocks) == len(adapter.transformer):
+                    adapter.blocks = [[blocks] for blocks in adapter.blocks]
+                else:
+                    adapter.blocks = [adapter.blocks]
+            elif isinstance(adapter.blocks[0], list):
+                # [[blocks_0, blocks_1],[blocks_2, blocks_3],] -> match [TRN_0, TRN_1,]
+                pass
+
+        if isinstance(adapter.blocks_name, str):
+            adapter.blocks_name = [[adapter.blocks_name]]
+        elif isinstance(adapter.blocks_name, list):
+            if isinstance(adapter.blocks_name[0], str):
+                if len(adapter.blocks_name) == len(adapter.transformer):
+                    adapter.blocks_name = [
+                        [blocks_name] for blocks_name in adapter.blocks_name
+                    ]
+                else:
+                    adapter.blocks_name = [adapter.blocks_name]
+            elif isinstance(adapter.blocks_name[0], list):
+                pass
+
+        if isinstance(adapter.forward_pattern, ForwardPattern):
+            adapter.forward_pattern = [[adapter.forward_pattern]]
+        elif isinstance(adapter.forward_pattern, list):
+            if isinstance(adapter.forward_pattern[0], ForwardPattern):
+                if len(adapter.forward_pattern) == len(adapter.transformer):
+                    adapter.forward_pattern = [
+                        [forward_pattern]
+                        for forward_pattern in adapter.forward_pattern
+                    ]
+                else:
+                    adapter.forward_pattern = [adapter.forward_pattern]
+            elif isinstance(adapter.forward_pattern[0], list):
+                pass
+
+        if isinstance(adapter.dummy_blocks_names, list):
+            if len(adapter.dummy_blocks_names) > 0:
+                if isinstance(adapter.dummy_blocks_names[0], str):
+                    if len(adapter.dummy_blocks_names) == len(
+                        adapter.transformer
+                    ):
+                        adapter.dummy_blocks_names = [
+                            [dummy_blocks_names]
+                            for dummy_blocks_names in adapter.dummy_blocks_names
+                        ]
+                    else:
+                        adapter.dummy_blocks_names = [
+                            adapter.dummy_blocks_names
+                        ]
+                elif isinstance(adapter.dummy_blocks_names[0], list):
+                    pass
+            else:
+                # Empty dummy_blocks_names
+                adapter.dummy_blocks_names = [
+                    [] for _ in range(len(adapter.transformer))
+                ]
+
+        if adapter.params_modifiers is not None:
+            if isinstance(adapter.params_modifiers, ParamsModifier):
+                adapter.params_modifiers = [[adapter.params_modifiers]]
+            elif isinstance(adapter.params_modifiers, list):
+                if isinstance(adapter.params_modifiers[0], ParamsModifier):
+                    if len(adapter.params_modifiers) == len(
+                        adapter.transformer
+                    ):
+                        adapter.params_modifiers = [
+                            [params_modifiers]
+                            for params_modifiers in adapter.params_modifiers
+                        ]
+                    else:
+                        adapter.params_modifiers = [adapter.params_modifiers]
+                elif isinstance(adapter.params_modifiers[0], list):
+                    pass
+
+        assert len(adapter.transformer) == len(adapter.blocks)
+        assert len(adapter.transformer) == len(adapter.blocks_name)
+        assert len(adapter.transformer) == len(adapter.forward_pattern)
+        assert len(adapter.transformer) == len(adapter.dummy_blocks_names)
+        if adapter.params_modifiers is not None:
+            assert len(adapter.transformer) == len(adapter.params_modifiers)
+
+        for i in range(len(adapter.blocks)):
+            assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
+            assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
+
+        if len(adapter.unique_blocks_name) == 0:
+            for i in range(len(adapter.transformer)):
+                # Generate unique blocks names
+                adapter.unique_blocks_name.append(
+                    [
+                        f"{name}_{hash(id(blocks))}"
+                        for blocks, name in zip(
+                            adapter.blocks[i],
+                            adapter.blocks_name[i],
+                        )
+                    ]
+                )
+
+        assert len(adapter.transformer) == len(adapter.unique_blocks_name)
+
+        # Match Forward Pattern
+        for i in range(len(adapter.transformer)):
+            for forward_pattern, blocks in zip(
+                adapter.forward_pattern[i], adapter.blocks[i]
+            ):
+                assert BlockAdapter.match_blocks_pattern(
+                    blocks,
+                    forward_pattern=forward_pattern,
+                    check_num_outputs=adapter.check_num_outputs,
+                ), (
+                    "No block forward pattern matched, "
+                    f"supported lists: {ForwardPattern.supported_patterns()}"
+                )
+
+        adapter._is_normalized = True
 
         return adapter
+
+    @classmethod
+    def assert_normalized(cls, adapter: "BlockAdapter"):
+        if not getattr(adapter, "_is_normalized", False):
+            raise RuntimeError("block_adapter must be normailzed.")
+
+    @classmethod
+    def is_cached(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return all(
+                (
+                    getattr(adapter.pipe, "_is_cached", False),
+                    getattr(adapter.transformer[0], "_is_cached", False),
+                )
+            )
+        elif isinstance(
+            adapter,
+            (DiffusionPipeline, torch.nn.Module),
+        ):
+            return getattr(adapter, "_is_cached", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_cached", False)
+        else:
+            raise TypeError(f"Can't check this type: {adapter}!")
+
+    @classmethod
+    def flatten(cls, attr: List[List[Any]]):
+        if isinstance(attr, list):
+            if not isinstance(attr[0], list):
+                return attr
+            flatten_attr = []
+            for i in range(len(attr)):
+                flatten_attr.extend(attr[i])
+            return flatten_attr
+        return attr
```