cache-dit 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +91 -62
- cache_dit/cache_factory/block_adapters/block_adapters.py +8 -6
- cache_dit/cache_factory/block_adapters/block_registers.py +10 -7
- cache_dit/cache_factory/cache_adapters.py +176 -66
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +70 -67
- cache_dit/cache_factory/cache_contexts/cache_manager.py +8 -10
- cache_dit/cache_factory/cache_interface.py +19 -77
- cache_dit/cache_factory/cache_types.py +5 -5
- cache_dit/cache_factory/patch_functors/functor_chroma.py +2 -1
- cache_dit/cache_factory/patch_functors/functor_flux.py +2 -1
- cache_dit/utils.py +5 -1
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.30.dist-info}/METADATA +2 -42
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.30.dist-info}/RECORD +18 -18
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.30.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.30.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.30.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.29'
-__version_tuple__ = version_tuple = (0, 2, 29)
+__version__ = version = '0.2.30'
+__version_tuple__ = version_tuple = (0, 2, 30)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/block_adapters/__init__.py
CHANGED

@@ -123,6 +123,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
     assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
     return BlockAdapter(
         pipe=pipe,
+        transformer=pipe.transformer,
         blocks=[
             pipe.transformer.transformer_blocks,
             pipe.transformer.single_transformer_blocks,

@@ -131,6 +132,8 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_0,
         ],
+        # The type hint in diffusers is wrong
+        check_num_outputs=False,
         **kwargs,
     )
 

@@ -327,37 +330,6 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("HunyuanDiT")
-def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
-
-    assert isinstance(
-        pipe.transformer,
-        (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
-    )
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.blocks,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
-@BlockAdapterRegistry.register("HunyuanDiTPAG")
-def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import HunyuanDiT2DModel
-
-    assert isinstance(pipe.transformer, HunyuanDiT2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.blocks,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
 @BlockAdapterRegistry.register("Lumina")
 def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import LuminaNextDiT2DModel

@@ -414,10 +386,12 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("Sana")
+@BlockAdapterRegistry.register("Sana", supported=False)
 def sana_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import SanaTransformer2DModel
 
+    # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
+
     assert isinstance(pipe.transformer, SanaTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,

@@ -428,20 +402,6 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("ShapE")
-def shape_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import PriorTransformer
-
-    assert isinstance(pipe.prior, PriorTransformer)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.prior,
-        blocks=pipe.prior.transformer_blocks,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
 @BlockAdapterRegistry.register("StableAudio")
 def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import StableAudioDiTModel

@@ -459,21 +419,37 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("VisualCloze")
 def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import FluxTransformer2DModel
+    from cache_dit.utils import is_diffusers_at_least_0_3_5
 
     assert isinstance(pipe.transformer, FluxTransformer2DModel)
-    ...
+    if is_diffusers_at_least_0_3_5():
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=[
+                pipe.transformer.transformer_blocks,
+                pipe.transformer.single_transformer_blocks,
+            ],
+            forward_pattern=[
+                ForwardPattern.Pattern_1,
+                ForwardPattern.Pattern_1,
+            ],
+            **kwargs,
+        )
+    else:
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=[
+                pipe.transformer.transformer_blocks,
+                pipe.transformer.single_transformer_blocks,
+            ],
+            forward_pattern=[
+                ForwardPattern.Pattern_1,
+                ForwardPattern.Pattern_3,
+            ],
+            **kwargs,
+        )
 
 
 @BlockAdapterRegistry.register("AuraFlow")

@@ -511,8 +487,27 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("
+@BlockAdapterRegistry.register("ShapE")
+def shape_adapter(pipe, **kwargs) -> BlockAdapter:
+    from diffusers import PriorTransformer
+
+    assert isinstance(pipe.prior, PriorTransformer)
+    return BlockAdapter(
+        pipe=pipe,
+        transformer=pipe.prior,
+        blocks=pipe.prior.transformer_blocks,
+        forward_pattern=ForwardPattern.Pattern_3,
+        **kwargs,
+    )
+
+
+@BlockAdapterRegistry.register("HiDream", supported=True)
 def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
+    # NOTE: Need to patch Transformer forward to fully support
+    # double_stream_blocks and single_stream_blocks, namely, need
+    # to remove the logics inside the blocks forward loop:
+    # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
+    # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
     from diffusers import HiDreamImageTransformer2DModel
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)

@@ -520,13 +515,47 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
         pipe=pipe,
         transformer=pipe.transformer,
         blocks=[
-            pipe.transformer.double_stream_blocks,
+            # pipe.transformer.double_stream_blocks,
             pipe.transformer.single_stream_blocks,
         ],
         forward_pattern=[
-            ForwardPattern.Pattern_4,
+            # ForwardPattern.Pattern_4,
             ForwardPattern.Pattern_3,
         ],
+        # The type hint in diffusers is wrong
         check_num_outputs=False,
         **kwargs,
     )
+
+
+@BlockAdapterRegistry.register("HunyuanDiT", supported=False)
+def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
+    # TODO: Patch Transformer forward
+    from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
+
+    assert isinstance(
+        pipe.transformer,
+        (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
+    )
+    return BlockAdapter(
+        pipe=pipe,
+        transformer=pipe.transformer,
+        blocks=pipe.transformer.blocks,
+        forward_pattern=ForwardPattern.Pattern_3,
+        **kwargs,
+    )
+
+
+@BlockAdapterRegistry.register("HunyuanDiTPAG", supported=False)
+def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
+    # TODO: Patch Transformer forward
+    from diffusers import HunyuanDiT2DModel
+
+    assert isinstance(pipe.transformer, HunyuanDiT2DModel)
+    return BlockAdapter(
+        pipe=pipe,
+        transformer=pipe.transformer,
+        blocks=pipe.transformer.blocks,
+        forward_pattern=ForwardPattern.Pattern_3,
+        **kwargs,
+    )
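The registered adapters above all follow the same recipe: point a BlockAdapter at the pipeline, its transformer, one or more ModuleLists of blocks, and the forward pattern each list follows. As a rough usage sketch (not part of this diff), the same structure can be built by hand and handed to `enable_cache`; the FLUX checkpoint name and the Pattern_1 choice here are illustrative assumptions, not values taken from the release.

```python
# Hedged sketch: constructing a BlockAdapter manually, mirroring the factories above.
import torch
from diffusers import FluxPipeline

import cache_dit
from cache_dit.cache_factory import ForwardPattern
from cache_dit.cache_factory.block_adapters import BlockAdapter

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Describe where the transformer blocks live and which I/O pattern each
# ModuleList follows, exactly as the registered adapters in this file do.
adapter = BlockAdapter(
    pipe=pipe,
    transformer=pipe.transformer,
    blocks=[
        pipe.transformer.transformer_blocks,
        pipe.transformer.single_transformer_blocks,
    ],
    forward_pattern=[
        ForwardPattern.Pattern_1,  # assumed pattern for illustration
        ForwardPattern.Pattern_1,
    ],
)
cache_dit.enable_cache(adapter)  # enable_cache accepts a BlockAdapter as well as a pipeline
```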
cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED

@@ -75,7 +75,7 @@ class BlockAdapter:
         List[List[ParamsModifier]],
     ] = None
 
-    check_num_outputs: bool =
+    check_num_outputs: bool = False
 
     # Pipeline Level Flags
     # Patch Functor: Flux, etc.

@@ -111,9 +111,9 @@ class BlockAdapter:
     def __post_init__(self):
        if self.skip_post_init:
            return
-        ...
+        if any((self.pipe is not None, self.transformer is not None)):
+            self.maybe_fill_attrs()
+            self.maybe_patchify()
 
    def maybe_fill_attrs(self):
        # NOTE: This func should be call before normalize.

@@ -130,7 +130,9 @@ class BlockAdapter:
         assert isinstance(blocks, torch.nn.ModuleList)
         blocks_name = None
         for attr_name in attr_names:
-            if
+            if (
+                attr := getattr(transformer, attr_name, None)
+            ) is not None:
                 if isinstance(attr, torch.nn.ModuleList) and id(
                     attr
                 ) == id(blocks):

@@ -558,7 +560,7 @@ class BlockAdapter:
             assert isinstance(adapter[0], torch.nn.Module)
             return getattr(adapter[0], "_is_cached", False)
         else:
-            raise TypeError(f"Can't check this type: {adapter}!")
+            raise TypeError(f"Can't check this type: {type(adapter)}!")
 
     @classmethod
     def nested_depth(cls, obj: Any):
cache_dit/cache_factory/block_adapters/block_registers.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Any, Tuple, List, Dict
+from typing import Any, Tuple, List, Dict, Callable
 
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

@@ -9,20 +9,23 @@ logger = init_logger(__name__)
 
 
 class BlockAdapterRegistry:
-    _adapters: Dict[str, BlockAdapter] = {}
-    _predefined_adapters_has_spearate_cfg: List[str] =
+    _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
+    _predefined_adapters_has_spearate_cfg: List[str] = [
         "QwenImage",
         "Wan",
         "CogView4",
         "Cosmos",
         "SkyReelsV2",
         "Chroma",
-
+    ]
 
     @classmethod
-    def register(cls, name):
-        def decorator(
-
+    def register(cls, name: str, supported: bool = True):
+        def decorator(
+            func: Callable[..., BlockAdapter]
+        ) -> Callable[..., BlockAdapter]:
+            if supported:
+                cls._adapters[name] = func
             return func
 
         return decorator
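The new `supported` flag above changes what the decorator does, not how it is written: factories registered with `supported=False` simply never enter `_adapters`, so those pipelines drop out of the supported list while the adapter code stays in the file. A minimal sketch of the assumed behavior (the `MyPipe` names are hypothetical):

```python
# Hedged sketch of the register(name, supported=...) semantics shown above.
from cache_dit.cache_factory.block_adapters import BlockAdapter, BlockAdapterRegistry


@BlockAdapterRegistry.register("MyPipe", supported=True)
def my_pipe_adapter(pipe, **kwargs) -> BlockAdapter:
    # normal adapter body; this factory is stored in BlockAdapterRegistry._adapters
    ...


@BlockAdapterRegistry.register("MyBrokenPipe", supported=False)
def my_broken_pipe_adapter(pipe, **kwargs) -> BlockAdapter:
    # kept in source for later fixes, but never added to the registry,
    # so "MyBrokenPipe" is not reported as a supported pipeline
    ...
```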
cache_dit/cache_factory/cache_adapters.py
CHANGED

@@ -4,7 +4,7 @@ import unittest
 import functools
 
 from contextlib import ExitStack
-from typing import Dict, List, Tuple, Any
+from typing import Dict, List, Tuple, Any, Union, Callable
 
 from diffusers import DiffusionPipeline
 

@@ -14,7 +14,10 @@ from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import BlockAdapterRegistry
 from cache_dit.cache_factory import CachedContextManager
 from cache_dit.cache_factory import CachedBlocks
-
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+    remove_cached_stats,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)

@@ -29,9 +32,15 @@ class CachedAdapter:
     @classmethod
     def apply(
         cls,
-        pipe_or_adapter:
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
         **cache_context_kwargs,
-    ) ->
+    ) -> Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ]:
         assert (
             pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"

@@ -49,7 +58,7 @@ class CachedAdapter:
             return cls.cachify(
                 block_adapter,
                 **cache_context_kwargs,
-            )
+            ).pipe
         else:
             raise ValueError(
                 f"{pipe_or_adapter.__class__.__name__} is not officially supported "

@@ -82,7 +91,7 @@ class CachedAdapter:
         # 0. Must normalize block_adapter before apply cache
         block_adapter = BlockAdapter.normalize(block_adapter)
         if BlockAdapter.is_cached(block_adapter):
-            return block_adapter
+            return block_adapter
 
         # 1. Apply cache on pipeline: wrap cache context, must
         # call create_context before mock_blocks.

@@ -98,36 +107,6 @@ class CachedAdapter:
 
         return block_adapter
 
-    @classmethod
-    def patch_params(
-        cls,
-        block_adapter: BlockAdapter,
-        contexts_kwargs: List[Dict],
-    ):
-        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
-
-        params_shift = 0
-        for i in range(len(block_adapter.transformer)):
-
-            block_adapter.transformer[i]._forward_pattern = (
-                block_adapter.forward_pattern
-            )
-            block_adapter.transformer[i]._has_separate_cfg = (
-                block_adapter.has_separate_cfg
-            )
-            block_adapter.transformer[i]._cache_context_kwargs = (
-                contexts_kwargs[params_shift]
-            )
-
-            blocks = block_adapter.blocks[i]
-            for j in range(len(blocks)):
-                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                blocks[j]._cache_context_kwargs = contexts_kwargs[
-                    params_shift + j
-                ]
-
-            params_shift += len(blocks)
-
     @classmethod
     def check_context_kwargs(
         cls,

@@ -153,7 +132,9 @@ class CachedAdapter:
             f"Pipeline: {block_adapter.pipe.__class__.__name__}."
         )
 
-        if
+        if (
+            cache_type := cache_context_kwargs.pop("cache_type", None)
+        ) is not None:
             assert (
                 cache_type == CacheType.DBCache
             ), "Custom cache setting only support for DBCache now!"

@@ -210,14 +191,14 @@ class CachedAdapter:
             )
             )
             outputs = original_call(self, *args, **kwargs)
-            cls.
+            cls.apply_stats_hooks(block_adapter)
             return outputs
 
         block_adapter.pipe.__class__.__call__ = new_call
         block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True
 
-        cls.
+        cls.apply_params_hooks(block_adapter, contexts_kwargs)
 
         return block_adapter.pipe
 

@@ -261,33 +242,6 @@ class CachedAdapter:
 
         return flatten_contexts, contexts_kwargs
 
-    @classmethod
-    def patch_stats(
-        cls,
-        block_adapter: BlockAdapter,
-    ):
-        from cache_dit.cache_factory.cache_blocks.utils import (
-            patch_cached_stats,
-        )
-
-        cache_manager = block_adapter.pipe._cache_manager
-
-        for i in range(len(block_adapter.transformer)):
-            patch_cached_stats(
-                block_adapter.transformer[i],
-                cache_context=block_adapter.unique_blocks_name[i][-1],
-                cache_manager=cache_manager,
-            )
-            for blocks, unique_name in zip(
-                block_adapter.blocks[i],
-                block_adapter.unique_blocks_name[i],
-            ):
-                patch_cached_stats(
-                    blocks,
-                    cache_context=unique_name,
-                    cache_manager=cache_manager,
-                )
-
     @classmethod
     def mock_blocks(
         cls,

@@ -405,3 +359,159 @@ class CachedAdapter:
         total_cached_blocks.append(cached_blocks_bind_context)
 
         return total_cached_blocks
+
+    @classmethod
+    def apply_params_hooks(
+        cls,
+        block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
+    ):
+        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+        params_shift = 0
+        for i in range(len(block_adapter.transformer)):
+
+            block_adapter.transformer[i]._forward_pattern = (
+                block_adapter.forward_pattern
+            )
+            block_adapter.transformer[i]._has_separate_cfg = (
+                block_adapter.has_separate_cfg
+            )
+            block_adapter.transformer[i]._cache_context_kwargs = (
+                contexts_kwargs[params_shift]
+            )
+
+            blocks = block_adapter.blocks[i]
+            for j in range(len(blocks)):
+                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                blocks[j]._cache_context_kwargs = contexts_kwargs[
+                    params_shift + j
+                ]
+
+            params_shift += len(blocks)
+
+    @classmethod
+    def apply_stats_hooks(
+        cls,
+        block_adapter: BlockAdapter,
+    ):
+        cache_manager = block_adapter.pipe._cache_manager
+
+        for i in range(len(block_adapter.transformer)):
+            patch_cached_stats(
+                block_adapter.transformer[i],
+                cache_context=block_adapter.unique_blocks_name[i][-1],
+                cache_manager=cache_manager,
+            )
+            for blocks, unique_name in zip(
+                block_adapter.blocks[i],
+                block_adapter.unique_blocks_name[i],
+            ):
+                patch_cached_stats(
+                    blocks,
+                    cache_context=unique_name,
+                    cache_manager=cache_manager,
+                )
+
+    @classmethod
+    def maybe_release_hooks(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
+    ):
+        # release model hooks
+        def _release_blocks_hooks(blocks):
+            return
+
+        def _release_transformer_hooks(transformer):
+            if hasattr(transformer, "_original_forward"):
+                original_forward = transformer._original_forward
+                transformer.forward = original_forward.__get__(transformer)
+                del transformer._original_forward
+            if hasattr(transformer, "_is_cached"):
+                del transformer._is_cached
+
+        def _release_pipeline_hooks(pipe):
+            if hasattr(pipe, "_original_call"):
+                original_call = pipe.__class__._original_call
+                pipe.__class__.__call__ = original_call
+                del pipe.__class__._original_call
+            if hasattr(pipe, "_cache_manager"):
+                cache_manager = pipe._cache_manager
+                if isinstance(cache_manager, CachedContextManager):
+                    cache_manager.clear_contexts()
+                del pipe._cache_manager
+            if hasattr(pipe, "_is_cached"):
+                del pipe.__class__._is_cached
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            _release_blocks_hooks,
+            _release_transformer_hooks,
+            _release_pipeline_hooks,
+        )
+
+        # release params hooks
+        def _release_blocks_params(blocks):
+            if hasattr(blocks, "_forward_pattern"):
+                del blocks._forward_pattern
+            if hasattr(blocks, "_cache_context_kwargs"):
+                del blocks._cache_context_kwargs
+
+        def _release_transformer_params(transformer):
+            if hasattr(transformer, "_forward_pattern"):
+                del transformer._forward_pattern
+            if hasattr(transformer, "_has_separate_cfg"):
+                del transformer._has_separate_cfg
+            if hasattr(transformer, "_cache_context_kwargs"):
+                del transformer._cache_context_kwargs
+            for blocks in BlockAdapter.find_blocks(transformer):
+                _release_blocks_params(blocks)
+
+        def _release_pipeline_params(pipe):
+            if hasattr(pipe, "_cache_context_kwargs"):
+                del pipe._cache_context_kwargs
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            _release_blocks_params,
+            _release_transformer_params,
+            _release_pipeline_params,
+        )
+
+        # release stats hooks
+        cls.release_hooks(
+            pipe_or_adapter,
+            remove_cached_stats,
+            remove_cached_stats,
+            remove_cached_stats,
+        )
+
+    @classmethod
+    def release_hooks(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
+        _release_blocks: Callable,
+        _release_transformer: Callable,
+        _release_pipeline: Callable,
+    ):
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            pipe = pipe_or_adapter
+            _release_pipeline(pipe)
+            if hasattr(pipe, "transformer"):
+                _release_transformer(pipe.transformer)
+            if hasattr(pipe, "transformer_2"):  # Wan 2.2
+                _release_transformer(pipe.transformer_2)
+        elif isinstance(pipe_or_adapter, BlockAdapter):
+            adapter = pipe_or_adapter
+            BlockAdapter.assert_normalized(adapter)
+            _release_pipeline(adapter.pipe)
+            for transformer in BlockAdapter.flatten(adapter.transformer):
+                _release_transformer(transformer)
+            for blocks in BlockAdapter.flatten(adapter.blocks):
+                _release_blocks(blocks)
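The hook lifecycle this file now centralizes (apply_params_hooks / apply_stats_hooks on enable, maybe_release_hooks on disable) is the classic monkey-patch-and-restore idiom: swap `__call__` on the pipeline class, remember the original, and later put it back and delete the markers. A self-contained sketch of that pattern, with a toy class standing in for the real DiffusionPipeline and all state-walking omitted:

```python
# Hedged, simplified sketch of the hook install/release pattern used above.
class TinyPipe:
    def __call__(self, x):
        return x


def install_hook(pipe_cls):
    original_call = pipe_cls.__call__

    def new_call(self, *args, **kwargs):
        # the real hook enters cache contexts here before delegating
        return original_call(self, *args, **kwargs)

    pipe_cls.__call__ = new_call
    pipe_cls._original_call = original_call  # remembered for release
    pipe_cls._is_cached = True


def release_hook(pipe_cls):
    # mirrors _release_pipeline_hooks above: restore __call__ and drop markers
    if hasattr(pipe_cls, "_original_call"):
        pipe_cls.__call__ = pipe_cls._original_call
        del pipe_cls._original_call
    if hasattr(pipe_cls, "_is_cached"):
        del pipe_cls._is_cached


install_hook(TinyPipe)
release_hook(TinyPipe)
assert TinyPipe()(1) == 1  # behaves exactly like the unpatched class again
```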
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
CHANGED

@@ -1,5 +1,6 @@
 import torch
 
+from typing import Dict, Any
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,

@@ -31,7 +32,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
         # encoder_hidden_states: None Pattern 3, else 4, 5
-        hidden_states,
+        hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
             hidden_states,
             *args,
             **kwargs,

@@ -60,11 +61,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         if can_use_cache:
             self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
-            hidden_states,
+            hidden_states, new_encoder_hidden_states = (
                 self.cache_manager.apply_cache(
                     hidden_states,
-                    #
-                    encoder_hidden_states,
+                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
                         if self.cache_manager.is_cache_residual()

@@ -80,12 +80,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            ...
-            hidden_states,
-            ...
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
         else:
             self.cache_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,

@@ -99,19 +99,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
+            old_encoder_hidden_states = new_encoder_hidden_states
             (
                 hidden_states,
-
+                new_encoder_hidden_states,
                 hidden_states_residual,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states_residual,
             ) = self.call_Mn_blocks(  # middle
                 hidden_states,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states,
                 *args,
                 **kwargs,
             )
+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
             torch._dynamo.graph_break()
             if self.cache_manager.is_cache_residual():
                 self.cache_manager.set_Bn_buffer(

@@ -119,34 +120,32 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
+
             if self.cache_manager.is_encoder_cache_residual():
-                ...
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_residual",
+                    )
             else:
-                ...
-                )
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                    )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            ...
-            hidden_states,
-            ...
-            )
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
 
         torch._dynamo.graph_break()
 

@@ -154,12 +153,21 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             hidden_states
             if self.forward_pattern.Return_H_Only
             else (
-                (hidden_states,
+                (hidden_states, new_encoder_hidden_states)
                 if self.forward_pattern.Return_H_First
-                else (
+                else (new_encoder_hidden_states, hidden_states)
             )
         )
 
+    @torch.compiler.disable
+    def maybe_update_kwargs(
+        self, encoder_hidden_states, kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        # if "encoder_hidden_states" in kwargs:
+        #     kwargs["encoder_hidden_states"] = encoder_hidden_states
+        #     return kwargs
+        return kwargs
+
     def call_Fn_blocks(
         self,
         hidden_states: torch.Tensor,

@@ -172,7 +180,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-
+        new_encoder_hidden_states = None
         for block in self._Fn_blocks():
             hidden_states = block(
                 hidden_states,

@@ -180,25 +188,27 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
-        return hidden_states,
+        return hidden_states, new_encoder_hidden_states
 
     def call_Mn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
         original_hidden_states = hidden_states
-
+        new_encoder_hidden_states = None
         for block in self._Mn_blocks():
             hidden_states = block(
                 hidden_states,

@@ -206,44 +216,33 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
         hidden_states_residual = hidden_states - original_hidden_states
-        if (
-            original_encoder_hidden_states is not None
-            and encoder_hidden_states is not None
-        ):  # Pattern 4, 5
-            encoder_hidden_states_residual = (
-                encoder_hidden_states - original_encoder_hidden_states
-            )
-        else:
-            encoder_hidden_states_residual = None  # Pattern 3
 
         return (
             hidden_states,
-
+            new_encoder_hidden_states,
             hidden_states_residual,
-            encoder_hidden_states_residual,
         )
 
     def call_Bn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
         assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (

@@ -264,11 +263,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )
 
-        return hidden_states,
+        return hidden_states, new_encoder_hidden_states
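The repeated idiom in call_Fn/Mn/Bn_blocks above is output normalization: Pattern 3 blocks return a bare hidden-states tensor, while Pattern 4/5 blocks return a 2-tuple whose ordering depends on `Return_H_First`. A standalone sketch of that logic (a simplification under those assumptions, not library code):

```python
# Hedged sketch of the per-block output normalization used in the file above.
import torch


def normalize_block_output(out, return_h_first: bool):
    """Return (hidden_states, new_encoder_hidden_states) regardless of pattern."""
    new_encoder_hidden_states = None
    if not isinstance(out, torch.Tensor):  # Pattern 4, 5: block returned a tuple
        hidden_states, new_encoder_hidden_states = out
        if not return_h_first:
            # tuple was (encoder_hidden_states, hidden_states); swap it back
            hidden_states, new_encoder_hidden_states = (
                new_encoder_hidden_states,
                hidden_states,
            )
    else:  # Pattern 3: block returned only hidden_states
        hidden_states = out
    return hidden_states, new_encoder_hidden_states


h = torch.randn(2, 16, 8)
e = torch.randn(2, 4, 8)
assert normalize_block_output(h, True)[1] is None          # Pattern 3
assert normalize_block_output((e, h), False)[0] is h        # encoder-first ordering
```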
cache_dit/cache_factory/cache_contexts/cache_manager.py
CHANGED

@@ -733,17 +733,15 @@ class CachedContextManager:
             encoder_prefix
         )
 
-
-            encoder_hidden_states_prev is not None
-        ), f"{prefix}_encoder_buffer must be set before"
+        if encoder_hidden_states_prev is not None:
 
-        ...
+            if self.is_encoder_cache_residual():
+                encoder_hidden_states = (
+                    encoder_hidden_states_prev + encoder_hidden_states
+                )
+            else:
+                # If encoder cache is not residual, we use the encoder hidden states directly
+                encoder_hidden_states = encoder_hidden_states_prev
 
         encoder_hidden_states = encoder_hidden_states.contiguous()
 
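The change above replaces a hard assertion with a graceful fallback and keeps the two encoder-cache modes: in residual mode the cached value is a delta added back onto the current encoder hidden states; otherwise the cached states are reused directly. A tiny sketch of that decision (a simplification under those assumptions):

```python
# Hedged sketch of the encoder-cache application logic reconstructed above.
import torch


def apply_encoder_cache(encoder_hidden_states, cached_prev, residual_mode: bool):
    if cached_prev is None:
        # nothing cached yet: fall through and keep the freshly computed states
        return encoder_hidden_states
    if residual_mode:
        return cached_prev + encoder_hidden_states  # cached delta added back
    return cached_prev                               # cached states reused directly


x = torch.ones(2, 3)
delta = 0.5 * torch.ones(2, 3)
assert torch.allclose(apply_encoder_cache(x, delta, True), x + delta)
assert torch.allclose(apply_encoder_cache(x, delta, False), delta)
assert apply_encoder_cache(x, None, True) is x
```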
cache_dit/cache_factory/cache_interface.py
CHANGED

@@ -1,11 +1,9 @@
-import
-from typing import Any, Tuple, List
+from typing import Any, Tuple, List, Union
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 from cache_dit.cache_factory.cache_adapters import CachedAdapter
-from cache_dit.cache_factory.cache_contexts import CachedContextManager
 
 from cache_dit.logger import init_logger
 

@@ -14,7 +12,10 @@ logger = init_logger(__name__)
 
 def enable_cache(
     # DiffusionPipeline or BlockAdapter
-    pipe_or_adapter:
+    pipe_or_adapter: Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ],
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
     Bn_compute_blocks: int = 0,

@@ -32,7 +33,10 @@ def enable_cache(
     taylorseer_cache_type: str = "residual",
     taylorseer_order: int = 2,
     **other_cache_context_kwargs,
-) ->
+) -> Union[
+    DiffusionPipeline,
+    BlockAdapter,
+]:
     r"""
     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
     that match the specific Input and Output patterns).

@@ -102,11 +106,11 @@ def enable_cache(
     >>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
     >>> output = pipe(...) # Just call the pipe as normal.
     >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
+    >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
     """
-
     # Collect cache context kwargs
     cache_context_kwargs = other_cache_context_kwargs.copy()
-    if cache_type := cache_context_kwargs.get("cache_type", None):
+    if (cache_type := cache_context_kwargs.get("cache_type", None)) is not None:
         if cache_type == CacheType.NONE:
             return pipe_or_adapter
 

@@ -145,79 +149,17 @@ def enable_cache(
 
 
 def disable_cache(
-    ...
+    pipe_or_adapter: Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ],
 ):
-    ...
+    CachedAdapter.maybe_release_hooks(pipe_or_adapter)
+    logger.warning(
+        f"Cache Acceleration is disabled for: "
+        f"{pipe_or_adapter.__class__.__name__}."
     )
 
-    def _disable_blocks(blocks: torch.nn.ModuleList):
-        if blocks is None:
-            return
-        if hasattr(blocks, "_forward_pattern"):
-            del blocks._forward_pattern
-        if hasattr(blocks, "_cache_context_kwargs"):
-            del blocks._cache_context_kwargs
-        remove_cached_stats(blocks)
-
-    def _disable_transformer(transformer: torch.nn.Module):
-        if transformer is None or not BlockAdapter.is_cached(transformer):
-            return
-        if original_forward := getattr(transformer, "_original_forward"):
-            transformer.forward = original_forward.__get__(transformer)
-            del transformer._original_forward
-        if hasattr(transformer, "_is_cached"):
-            del transformer._is_cached
-        if hasattr(transformer, "_forward_pattern"):
-            del transformer._forward_pattern
-        if hasattr(transformer, "_has_separate_cfg"):
-            del transformer._has_separate_cfg
-        if hasattr(transformer, "_cache_context_kwargs"):
-            del transformer._cache_context_kwargs
-        remove_cached_stats(transformer)
-        for blocks in BlockAdapter.find_blocks(transformer):
-            _disable_blocks(blocks)
-
-    def _disable_pipe(pipe: DiffusionPipeline):
-        if pipe is None or not BlockAdapter.is_cached(pipe):
-            return
-        if original_call := getattr(pipe, "_original_call"):
-            pipe.__class__.__call__ = original_call
-            del pipe.__class__._original_call
-        if cache_manager := getattr(pipe, "_cache_manager"):
-            assert isinstance(cache_manager, CachedContextManager)
-            cache_manager.clear_contexts()
-            del pipe._cache_manager
-        if hasattr(pipe, "_is_cached"):
-            del pipe.__class__._is_cached
-        if hasattr(pipe, "_cache_context_kwargs"):
-            del pipe._cache_context_kwargs
-        remove_cached_stats(pipe)
-
-    if isinstance(pipe_or_adapter, DiffusionPipeline):
-        pipe = pipe_or_adapter
-        _disable_pipe(pipe)
-        if hasattr(pipe, "transformer"):
-            _disable_transformer(pipe.transformer)
-        if hasattr(pipe, "transformer_2"):  # Wan 2.2
-            _disable_transformer(pipe.transformer_2)
-        pipe_cls_name = pipe.__class__.__name__
-        logger.warning(f"Cache Acceleration is disabled for: {pipe_cls_name}")
-    elif isinstance(pipe_or_adapter, BlockAdapter):
-        # BlockAdapter
-        adapter = pipe_or_adapter
-        BlockAdapter.assert_normalized(adapter)
-        _disable_pipe(adapter.pipe)
-        for transformer in BlockAdapter.flatten(adapter.transformer):
-            _disable_transformer(transformer)
-        for blocks in BlockAdapter.flatten(adapter.blocks):
-            _disable_blocks(blocks)
-        pipe_cls_name = adapter.pipe.__class__.__name__
-        logger.warning(f"Cache Acceleration is disabled for: {pipe_cls_name}")
-    else:
-        pass  # do nothing
-
 
 def supported_pipelines(
     **kwargs,
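With `disable_cache` now delegating to `CachedAdapter.maybe_release_hooks`, the end-to-end workflow matches the docstring added in this diff. The snippet below follows that docstring; the FLUX checkpoint name and the prompt are only examples, not values from the release.

```python
# Usage following the enable_cache docstring shown in this diff.
import torch
import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

cache_dit.enable_cache(pipe)        # one-line cache acceleration with default options
output = pipe("a photo of a cat")   # call the pipe as normal
stats = cache_dit.summary(pipe)     # cache acceleration stats
cache_dit.disable_cache(pipe)       # new in 0.2.30: restore the original, uncached pipe
```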
cache_dit/cache_factory/cache_types.py
CHANGED

@@ -22,11 +22,11 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
     if isinstance(type_hint, CacheType):
         return type_hint
 
-    elif type_hint.
-        "
-        "
-        "
-        "
+    elif type_hint.upper() in (
+        "DUAL_BLOCK_CACHE",
+        "DB_CACHE",
+        "DBCACHE",
+        "DB",
     ):
         return CacheType.DBCache
     return CacheType.NONE
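The alias tuple above comes straight from the hunk; routing the comparison through `.upper()` implies the string hints are now matched case-insensitively. A sketch of the behavior this suggests (assumed, not verified against the released wheel):

```python
# Hedged sketch of the expected cache_type() resolution after this change.
from cache_dit.cache_factory.cache_types import CacheType, cache_type

assert cache_type(CacheType.DBCache) is CacheType.DBCache   # CacheType passes through

# string hints from the alias tuple, in any casing, resolve to DBCache
for hint in ("DUAL_BLOCK_CACHE", "dbcache", "Db_Cache", "db"):
    assert cache_type(hint) is CacheType.DBCache

assert cache_type("something_else") is CacheType.NONE       # unknown hints fall back to NONE
```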
cache_dit/cache_factory/patch_functors/functor_chroma.py
CHANGED

@@ -56,7 +56,8 @@ class ChromaPatchFunctor(PatchFunctor):
         transformer.forward = __patch_transformer_forward__.__get__(
             transformer
         )
-
+
+        transformer._is_patched = is_patched  # True or False
 
         cls_name = transformer.__class__.__name__
         logger.info(
cache_dit/cache_factory/patch_functors/functor_flux.py
CHANGED

@@ -57,7 +57,8 @@ class FluxPatchFunctor(PatchFunctor):
         transformer.forward = __patch_transformer_forward__.__get__(
             transformer
         )
-
+
+        transformer._is_patched = is_patched  # True or False
 
         cls_name = transformer.__class__.__name__
         logger.info(
cache_dit/utils.py
CHANGED
@@ -52,6 +52,9 @@ def summary(
     if hasattr(adapter_or_others, "transformer_2"):
         transformer_2 = adapter_or_others.transformer_2
 
+    if not BlockAdapter.is_cached(transformer):
+        return [CacheStats()]
+
     blocks_stats: List[CacheStats] = []
     for blocks in BlockAdapter.find_blocks(transformer):
         blocks_stats.append(

@@ -212,7 +215,8 @@ def _summary(
         if logging:
             print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
     else:
-
+        if logging:
+            logger.warning(f"Can't find Cache Options for: {cls_name}")
 
     if hasattr(module, "_cached_steps"):
         cached_steps: list[int] = module._cached_steps
{cache_dit-0.2.29.dist-info → cache_dit-0.2.30.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.29
+Version: 0.2.30
 Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc

@@ -59,7 +59,7 @@ Dynamic: requires-python
 🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
 </p>
 <p align="center">
-🎉Now, <b>cache-dit</b> covers <b>
+🎉Now, <b>cache-dit</b> covers <b>mainstream</b> Diffusers' <b>DiT-based</b> Pipelines🎉<br>
 🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
 </p>
 </div>

@@ -87,7 +87,6 @@ Dynamic: requires-python
 <summary> Previous News </summary>
 
 - [2025-09-01] 📚[**Hybird Forward Pattern**](#unified) is supported! Please check [FLUX.1-dev](./examples/run_flux_adapter.py) as an example.
-- [2025-08-29] 🔥</b>Covers <b>100%</b> Diffusers' <b>DiT-based</b> Pipelines: **[BlockAdapter](#unified) + [Pattern Matching](#unified).**
 - [2025-08-10] 🔥[**FLUX.1-Kontext-dev**](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please refer [run_flux_kontext.py](./examples/pipeline/run_flux_kontext.py) as an example.
 - [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).

@@ -130,19 +129,8 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 
 <div id="supported"></div>
 
-```python
->>> import cache_dit
->>> cache_dit.supported_pipelines()
-(31, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTXVideo*',
-'Allegro*', 'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'SD3*',
-'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'HunyuanDiT*', 'HunyuanDiTPAG*', 'Lumina*', 'Lumina2*',
-'OmniGen*', 'PixArt*', 'Sana*', 'ShapE*', 'StableAudio*', 'VisualCloze*', 'AuraFlow*',
-'Chroma*', 'HiDream*'])
-```
-
 Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models listed:
 
-
 - [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)

@@ -154,35 +142,7 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
 
-<details>
-<summary> More Pipelines </summary>
-
-- [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀LTXVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀Allegro](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀CogView3Plus](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀CogView4](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀Cosmos](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀EasyAnimate](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀SkyReelsV2](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀SD3](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀ConsisID](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀DiT](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀Amused](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀HunyuanDiTPAG](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀Lumina](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀Lumina2](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀OmniGen](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀PixArt](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀Sana](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀StableAudio](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀VisualCloze](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀AuraFlow](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀Chroma](https://github.com/vipshop/cache-dit/raw/main/examples)
-- [🚀HiDream](https://github.com/vipshop/cache-dit/raw/main/examples)
-
 </details>
 
 ## 🎉Unified Cache APIs
{cache_dit-0.2.29.dist-info → cache_dit-0.2.30.dist-info}/RECORD
CHANGED

@@ -1,30 +1,30 @@
 cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=6uKAYeE03adIcUS0SDwp52AaQx0KO8z_-07D_lPHrz8,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/utils.py,sha256=
+cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=Iw6-iJLFbdzCsIDZXXOw371L-HPmoeZO_P9a3sDjP5s,1103
-cache_dit/cache_factory/cache_adapters.py,sha256=
-cache_dit/cache_factory/cache_interface.py,sha256=
-cache_dit/cache_factory/cache_types.py,sha256=
+cache_dit/cache_factory/cache_adapters.py,sha256=TA_0mEHMdSQDrt4rYASeX4-BD8pJOznSJfMV3hkrGuk,17851
+cache_dit/cache_factory/cache_interface.py,sha256=y1nY6R3MucRmAnG2UJRI_tIKrRk27FktGWLbfckf3zE,8543
+cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
-cache_dit/cache_factory/block_adapters/__init__.py,sha256=
-cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=
-cache_dit/cache_factory/block_adapters/block_registers.py,sha256=
+cache_dit/cache_factory/block_adapters/__init__.py,sha256=EA-4mEVy-JJ5vRDo6C3nJIOXu0ZDNc6FQ-ZLAKHDtB0,17251
+cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=OrKhuNdcGBCgSsPchdf4h32Ad-bQVUXNigMhPJ4cCvk,21069
+cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
 cache_dit/cache_factory/cache_blocks/__init__.py,sha256=OWjnpJxA8EJVoRzuyb5miuiRphUFj831-bbtWsTDjnM,2750
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
-cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=
+cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=nf2f5wdxp6tfq9AhFyMyBeKiZfxh63WG1g8q-c2BBSg,10182
 cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=XSDy3hsaKbAZPGZY92YgGA0qLgjQyIX8irQkb2R5T2c,20331
 cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
 cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
 cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=N88WLdd4KE9DuMWmpX8URcF55E2zWNwcKMxgVYkxMJY,13691
-cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=
+cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=_NUXcMYYEIVfDHpc4HJr9RUjU5RUEkZmAgFGE8bh5Wc,34883
 cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
 cache_dit/cache_factory/patch_functors/__init__.py,sha256=yK05iONMGILsTZ83ynrUUJtiJKJ_FDjxmVIzRLy416s,252
 cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
-cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=
-cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=
+cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=N3UzszCM55g3GHeVdyXkid2_n72VJrfqBM2gdtD52gw,10042
+cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=rJsbGEIxWPGnZyGI4ZwLLBdg8u6ZItsOeh0UoD_bVwk,9551
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
 cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0

@@ -39,9 +39,9 @@ cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70
 cache_dit/quantize/quantize_ao.py,sha256=mGspqYgQtenl3QnKPtsSYsSD7LbVX93f1M940bhXKLU,6066
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
 cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit-0.2.30.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.30.dist-info/METADATA,sha256=8Ln_X5fw14U3greCM7cSukrei1SRiMDpksFalg5ZBAU,22130
+cache_dit-0.2.30.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.30.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.30.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.30.dist-info/RECORD,,
|