cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
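The bulk of this release is a package rename: modules under cache_dit/cache_factory/ move to cache_dit/caching/, and cache_dit/custom_ops/ moves to cache_dit/kernels/. Downstream code that imported these private paths directly needs the new module names. A minimal migration sketch follows; only the new-path imports appear verbatim in this diff, and the old-path fallbacks are assumptions about how 1.0.3 laid out the same modules.

# Hedged sketch of the cache_factory -> caching rename shown in this diff.
try:
    # cache-dit >= 1.0.14 layout
    from cache_dit.caching.cache_types import CacheType
    from cache_dit.caching.block_adapters import BlockAdapter
    from cache_dit.caching.cache_contexts import BasicCacheConfig
except ImportError:
    # cache-dit 1.0.3 layout (assumed pre-rename equivalents)
    from cache_dit.cache_factory.cache_types import CacheType
    from cache_dit.cache_factory.block_adapters import BlockAdapter
    from cache_dit.cache_factory.cache_contexts import BasicCacheConfig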
cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py

@@ -1,19 +1,21 @@
+import copy
 import torch
 import unittest
 import functools
 from contextlib import ExitStack
 from typing import Dict, List, Tuple, Any, Union, Callable, Optional

-from diffusers import DiffusionPipeline
-
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
+from diffusers import DiffusionPipeline, ModelMixin
+
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters import FakeDiffusionPipeline
+from cache_dit.caching.block_adapters import ParamsModifier
+from cache_dit.caching.block_adapters import BlockAdapterRegistry
+from cache_dit.caching.cache_contexts import ContextManager
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import CalibratorConfig
+from cache_dit.caching.cache_blocks import UnifiedBlocks
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)
@@ -31,8 +33,11 @@ class CachedAdapter:
         pipe_or_adapter: Union[
             DiffusionPipeline,
             BlockAdapter,
+            # Transformer-only
+            torch.nn.Module,
+            ModelMixin,
         ],
-        **
+        **context_kwargs,
     ) -> Union[
         DiffusionPipeline,
         BlockAdapter,
@@ -41,7 +46,9 @@ class CachedAdapter:
             pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"

-        if isinstance(
+        if isinstance(
+            pipe_or_adapter, (DiffusionPipeline, torch.nn.Module, ModelMixin)
+        ):
             if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
                     f"{pipe_or_adapter.__class__.__name__} is officially "
@@ -51,16 +58,22 @@ class CachedAdapter:
                 block_adapter = BlockAdapterRegistry.get_adapter(
                     pipe_or_adapter
                 )
-
+                assert block_adapter is not None, (
+                    f"BlockAdapter for {pipe_or_adapter.__class__.__name__} "
+                    "should not be None!"
+                )
+                if params_modifiers := context_kwargs.pop(
                     "params_modifiers",
                     None,
                 ):
                     block_adapter.params_modifiers = params_modifiers

-
-
-
-
+                block_adapter = cls.cachify(block_adapter, **context_kwargs)
+                if isinstance(pipe_or_adapter, DiffusionPipeline):
+                    return block_adapter.pipe
+
+                return block_adapter.transformer
+
             else:
                 raise ValueError(
                     f"{pipe_or_adapter.__class__.__name__} is not officially supported "
@@ -72,21 +85,21 @@ class CachedAdapter:
                 "Adapting Cache Acceleration using custom BlockAdapter!"
             )
             if pipe_or_adapter.params_modifiers is None:
-                if params_modifiers :=
+                if params_modifiers := context_kwargs.pop(
                     "params_modifiers", None
                 ):
                     pipe_or_adapter.params_modifiers = params_modifiers

             return cls.cachify(
                 pipe_or_adapter,
-                **
+                **context_kwargs,
             )

     @classmethod
     def cachify(
         cls,
         block_adapter: BlockAdapter,
-        **
+        **context_kwargs,
     ) -> BlockAdapter:

         if block_adapter.auto:
@@ -103,14 +116,15 @@ class CachedAdapter:

         # 1. Apply cache on pipeline: wrap cache context, must
         # call create_context before mock_blocks.
-        cls.create_context(
+        _, contexts_kwargs = cls.create_context(
             block_adapter,
-            **
+            **context_kwargs,
         )

         # 2. Apply cache on transformer: mock cached blocks
         cls.mock_blocks(
             block_adapter,
+            contexts_kwargs,
         )

         return block_adapter
@@ -119,12 +133,10 @@ class CachedAdapter:
     def check_context_kwargs(
         cls,
         block_adapter: BlockAdapter,
-        **
+        **context_kwargs,
     ):
-        # Check
-        cache_config: BasicCacheConfig =
-            "cache_config"
-        ] # ref
+        # Check context_kwargs
+        cache_config: BasicCacheConfig = context_kwargs["cache_config"] # ref
         assert cache_config is not None, "cache_config can not be None."
         if cache_config.enable_separate_cfg is None:
             # Check cfg for some specific case if users don't set it as True
@@ -150,19 +162,23 @@ class CachedAdapter:
                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
             )

-        cache_type =
+        cache_type = context_kwargs.pop("cache_type", None)
         if cache_type is not None:
-            assert (
-                cache_type
-            ), "
+            assert isinstance(
+                cache_type, CacheType
+            ), f"cache_type must be CacheType Enum, but got {type(cache_type)}."
+            assert cache_type == cache_config.cache_type, (
+                f"cache_type from context_kwargs ({cache_type}) must be the same "
+                f"as that from cache_config ({cache_config.cache_type})."
+            )

-        return
+        return context_kwargs

     @classmethod
     def create_context(
         cls,
         block_adapter: BlockAdapter,
-        **
+        **context_kwargs,
     ) -> Tuple[List[str], List[Dict[str, Any]]]:

         BlockAdapter.assert_normalized(block_adapter)
@@ -170,49 +186,71 @@ class CachedAdapter:
         if BlockAdapter.is_cached(block_adapter.pipe):
             return block_adapter.pipe

-        # Check
-
-            block_adapter, **
+        # Check context_kwargs
+        context_kwargs = cls.check_context_kwargs(
+            block_adapter, **context_kwargs
         )
-        # Apply cache on pipeline: wrap cache context
-        pipe_cls_name = block_adapter.pipe.__class__.__name__

         # Each Pipeline should have it's own context manager instance.
         # Different transformers (Wan2.2, etc) should shared the same
         # cache manager but with different cache context (according
         # to their unique instance id).
-
+        cache_config: BasicCacheConfig = context_kwargs.get(
+            "cache_config", None
+        )
+        assert cache_config is not None, "cache_config can not be None."
+        # Apply cache on pipeline: wrap cache context
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
+        context_manager = ContextManager(
             name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+            cache_type=cache_config.cache_type,
+            # Force use persistent_context for FakeDiffusionPipeline
+            persistent_context=isinstance(
+                block_adapter.pipe, FakeDiffusionPipeline
+            ),
         )
-        block_adapter.pipe._cache_manager = cache_manager # instance level
-
         flatten_contexts, contexts_kwargs = cls.modify_context_params(
-            block_adapter, **
+            block_adapter, **context_kwargs
         )

-
-
-
-
-
-
-
-
-                ):
-
-
-
-
-
-
+        block_adapter.pipe._context_manager = context_manager # instance level
+
+        if not context_manager.persistent_context:
+
+            original_call = block_adapter.pipe.__class__.__call__
+
+            @functools.wraps(original_call)
+            def new_call(self, *args, **kwargs):
+                with ExitStack() as stack:
+                    # cache context will be reset for each pipe inference
+                    for context_name, context_kwargs in zip(
+                        flatten_contexts, contexts_kwargs
+                    ):
+                        stack.enter_context(
+                            context_manager.enter_context(
+                                context_manager.reset_context(
+                                    context_name,
+                                    **context_kwargs,
+                                ),
+                            )
                         )
-            )
-
-
-
+                    outputs = original_call(self, *args, **kwargs)
+                    cls.apply_stats_hooks(block_adapter)
+                    return outputs
+
+            block_adapter.pipe.__class__.__call__ = new_call
+            block_adapter.pipe.__class__._original_call = original_call
+
+        else:
+            # Init persistent cache context for transformer
+            for context_name, context_kwargs in zip(
+                flatten_contexts, contexts_kwargs
+            ):
+                context_manager.reset_context(
+                    context_name,
+                    **context_kwargs,
+                )

-        block_adapter.pipe.__class__.__call__ = new_call
-        block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True

         cls.apply_params_hooks(block_adapter, contexts_kwargs)
@@ -223,14 +261,14 @@ class CachedAdapter:
     def modify_context_params(
         cls,
         block_adapter: BlockAdapter,
-        **
+        **context_kwargs,
     ) -> Tuple[List[str], List[Dict[str, Any]]]:

         flatten_contexts = BlockAdapter.flatten(
             block_adapter.unique_blocks_name
         )
         contexts_kwargs = [
-
+            copy.deepcopy(context_kwargs) # must deep copy
             for _ in range(
                 len(flatten_contexts),
             )
@@ -251,9 +289,41 @@ class CachedAdapter:
         for i in range(
             min(len(contexts_kwargs), len(flatten_modifiers)),
         ):
-
-            flatten_modifiers[
-
+            if "cache_config" in flatten_modifiers[i]._context_kwargs:
+                modifier_cache_config = flatten_modifiers[
+                    i
+                ]._context_kwargs.get("cache_config", None)
+                modifier_calibrator_config = flatten_modifiers[
+                    i
+                ]._context_kwargs.get("calibrator_config", None)
+                if modifier_cache_config is not None:
+                    assert isinstance(
+                        modifier_cache_config, BasicCacheConfig
+                    ), (
+                        f"cache_config must be BasicCacheConfig, but got "
+                        f"{type(modifier_cache_config)}."
+                    )
+                    contexts_kwargs[i]["cache_config"].update(
+                        **modifier_cache_config.as_dict()
+                    )
+                if modifier_calibrator_config is not None:
+                    assert isinstance(
+                        modifier_calibrator_config, CalibratorConfig
+                    ), (
+                        f"calibrator_config must be CalibratorConfig, but got "
+                        f"{type(modifier_calibrator_config)}."
+                    )
+                    if (
+                        contexts_kwargs[i].get("calibrator_config", None)
+                        is None
+                    ):
+                        contexts_kwargs[i][
+                            "calibrator_config"
+                        ] = modifier_calibrator_config
+                    else:
+                        contexts_kwargs[i]["calibrator_config"].update(
+                            **modifier_calibrator_config.as_dict()
+                        )
             cls._config_messages(**contexts_kwargs[i])

         return flatten_contexts, contexts_kwargs
@@ -267,7 +337,7 @@ class CachedAdapter:
             "calibrator_config", None
         )
         if cache_config is not None:
-            message = f"Collected
+            message = f"Collected Context Config: {cache_config.strify()}"
             if calibrator_config is not None:
                 message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
         else:
@@ -278,6 +348,7 @@ class CachedAdapter:
     def mock_blocks(
         cls,
         block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
     ) -> List[torch.nn.Module]:

         BlockAdapter.assert_normalized(block_adapter)
@@ -287,24 +358,28 @@ class CachedAdapter:

         # Apply cache on transformer: mock cached transformer blocks
         for (
-
+            unified_blocks,
             transformer,
             blocks_name,
             unique_blocks_name,
             dummy_blocks_names,
         ) in zip(
-            cls.
+            cls.collect_unified_blocks(
+                block_adapter,
+                contexts_kwargs,
+            ),
             block_adapter.transformer,
             block_adapter.blocks_name,
             block_adapter.unique_blocks_name,
             block_adapter.dummy_blocks_names,
         ):
             cls.mock_transformer(
-
+                unified_blocks,
                 transformer,
                 blocks_name,
                 unique_blocks_name,
                 dummy_blocks_names,
+                block_adapter,
             )

         return block_adapter.transformer
@@ -312,11 +387,12 @@ class CachedAdapter:
     @classmethod
     def mock_transformer(
         cls,
-
+        unified_blocks: Dict[str, torch.nn.ModuleList],
         transformer: torch.nn.Module,
         blocks_name: List[str],
         unique_blocks_name: List[str],
         dummy_blocks_names: List[str],
+        block_adapter: BlockAdapter,
     ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()

@@ -343,6 +419,8 @@ class CachedAdapter:
         # re-apply hooks to transformer after cache applied.
         # from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
         # from diffusers.hooks.group_offloading import apply_group_offloading
+        context_manager: ContextManager = block_adapter.pipe._context_manager
+        assert isinstance(context_manager, ContextManager._supported_managers)

         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
@@ -352,7 +430,7 @@ class CachedAdapter:
                 ):
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self, name,
+                            self, name, unified_blocks[context_name]
                         )
                     )
                 for dummy_name in dummy_blocks_names:
@@ -362,6 +440,13 @@ class CachedAdapter:
                         )
                     )
                 outputs = original_forward(*args, **kwargs)
+
+                if (
+                    context_manager.persistent_context
+                    and context_manager.is_pre_refreshed()
+                ):
+                    cls.apply_stats_hooks(block_adapter)
+
                 return outputs

         def new_forward_with_hf_hook(self, *args, **kwargs):
@@ -388,46 +473,51 @@ class CachedAdapter:
         return transformer

     @classmethod
-    def
+    def collect_unified_blocks(
         cls,
         block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
     ) -> List[Dict[str, torch.nn.ModuleList]]:

         BlockAdapter.assert_normalized(block_adapter)

         total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
-        assert hasattr(block_adapter.pipe, "
+        assert hasattr(block_adapter.pipe, "_context_manager")
         assert isinstance(
-            block_adapter.pipe.
-
+            block_adapter.pipe._context_manager,
+            ContextManager._supported_managers,
         )

         for i in range(len(block_adapter.transformer)):

-
+            unified_blocks_bind_context = {}
             for j in range(len(block_adapter.blocks[i])):
-
+                cache_config: BasicCacheConfig = contexts_kwargs[
+                    i * len(block_adapter.blocks[i]) + j
+                ]["cache_config"]
+                unified_blocks_bind_context[
                    block_adapter.unique_blocks_name[i][j]
                ] = torch.nn.ModuleList(
                    [
-
+                        UnifiedBlocks(
                            # 0. Transformer blocks configuration
                            block_adapter.blocks[i][j],
                            transformer=block_adapter.transformer[i],
                            forward_pattern=block_adapter.forward_pattern[i][j],
                            check_forward_pattern=block_adapter.check_forward_pattern,
                            check_num_outputs=block_adapter.check_num_outputs,
-                            # 1. Cache context configuration
+                            # 1. Cache/Prune context configuration
                            cache_prefix=block_adapter.blocks_name[i][j],
                            cache_context=block_adapter.unique_blocks_name[i][
                                j
                            ],
-
+                            context_manager=block_adapter.pipe._context_manager,
+                            cache_type=cache_config.cache_type,
                        )
                    ]
                )

-            total_cached_blocks.append(
+            total_cached_blocks.append(unified_blocks_bind_context)

         return total_cached_blocks

@@ -437,7 +527,7 @@ class CachedAdapter:
         block_adapter: BlockAdapter,
         contexts_kwargs: List[Dict],
     ):
-        block_adapter.pipe.
+        block_adapter.pipe._context_kwargs = contexts_kwargs[0]

         params_shift = 0
         for i in range(len(block_adapter.transformer)):
@@ -448,44 +538,43 @@ class CachedAdapter:
             block_adapter.transformer[i]._has_separate_cfg = (
                 block_adapter.has_separate_cfg
             )
-            block_adapter.transformer[i].
-
-
+            block_adapter.transformer[i]._context_kwargs = contexts_kwargs[
+                params_shift
+            ]

             blocks = block_adapter.blocks[i]
             for j in range(len(blocks)):
                 blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                blocks[j].
-                    params_shift + j
-                ]
+                blocks[j]._context_kwargs = contexts_kwargs[params_shift + j]

             params_shift += len(blocks)

     @classmethod
+    @torch.compiler.disable
     def apply_stats_hooks(
         cls,
         block_adapter: BlockAdapter,
     ):
-        from cache_dit.
-
+        from cache_dit.caching.cache_blocks import (
+            apply_stats,
         )

-
+        context_manager = block_adapter.pipe._context_manager

         for i in range(len(block_adapter.transformer)):
-
+            apply_stats(
                block_adapter.transformer[i],
                cache_context=block_adapter.unique_blocks_name[i][-1],
-
+                context_manager=context_manager,
            )
            for blocks, unique_name in zip(
                block_adapter.blocks[i],
                block_adapter.unique_blocks_name[i],
            ):
-
+                apply_stats(
                    blocks,
                    cache_context=unique_name,
-
+                    context_manager=context_manager,
                )

     @classmethod
@@ -513,11 +602,13 @@ class CachedAdapter:
             original_call = pipe.__class__._original_call
             pipe.__class__.__call__ = original_call
             del pipe.__class__._original_call
-        if hasattr(pipe, "
-
-            if isinstance(
-
-
+        if hasattr(pipe, "_context_manager"):
+            context_manager = pipe._context_manager
+            if isinstance(
+                context_manager, ContextManager._supported_managers
+            ):
+                context_manager.clear_contexts()
+            del pipe._context_manager
         if hasattr(pipe, "_is_cached"):
             del pipe.__class__._is_cached

@@ -532,22 +623,22 @@ class CachedAdapter:
         def _release_blocks_params(blocks):
             if hasattr(blocks, "_forward_pattern"):
                 del blocks._forward_pattern
-            if hasattr(blocks, "
-                del blocks.
+            if hasattr(blocks, "_context_kwargs"):
+                del blocks._context_kwargs

         def _release_transformer_params(transformer):
             if hasattr(transformer, "_forward_pattern"):
                 del transformer._forward_pattern
             if hasattr(transformer, "_has_separate_cfg"):
                 del transformer._has_separate_cfg
-            if hasattr(transformer, "
-                del transformer.
+            if hasattr(transformer, "_context_kwargs"):
+                del transformer._context_kwargs
             for blocks in BlockAdapter.find_blocks(transformer):
                 _release_blocks_params(blocks)

         def _release_pipeline_params(pipe):
-            if hasattr(pipe, "
-                del pipe.
+            if hasattr(pipe, "_context_kwargs"):
+                del pipe._context_kwargs

         cls.release_hooks(
             pipe_or_adapter,
@@ -557,15 +648,24 @@ class CachedAdapter:
         )

         # release stats hooks
-        from cache_dit.
-
+        from cache_dit.caching.cache_blocks import (
+            remove_stats,
+        )
+
+        cls.release_hooks(
+            pipe_or_adapter, remove_stats, remove_stats, remove_stats
+        )
+
+        # maybe release parallelism stats
+        from cache_dit.parallelism.parallel_interface import (
+            remove_parallelism_stats,
         )

         cls.release_hooks(
             pipe_or_adapter,
-
-
-
+            remove_parallelism_stats,
+            remove_parallelism_stats,
+            remove_parallelism_stats,
         )

     @classmethod