cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
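The headline change in 1.0.x is the package reorganization: `cache_dit/cache_factory/` becomes `cache_dit/caching/` and `cache_dit/custom_ops/` becomes `cache_dit/kernels/`, with new `parallelism/` and `quantize/backends/` trees. A minimal import-migration sketch, assuming the moved modules keep their symbol names (the old paths are inferred from the renames above; the new paths appear verbatim in the `cache_adapter.py` diff below):

```python
# Import migration sketch for the cache_factory -> caching rename.

# cache-dit 0.3.x (old layout, inferred):
# from cache_dit.cache_factory.cache_types import CacheType
# from cache_dit.cache_factory.block_adapters import BlockAdapter

# cache-dit 1.0.x (new layout, as imported in the diff below):
from cache_dit.caching.cache_types import CacheType
from cache_dit.caching.block_adapters import BlockAdapter
from cache_dit.caching.cache_contexts import BasicCacheConfig
```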
cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py (the registry viewer truncates the tails of several removed lines; they are reproduced below as rendered):

```diff
@@ -1,25 +1,21 @@
+import copy
 import torch
-
 import unittest
 import functools
-
 from contextlib import ExitStack
-from typing import Dict, List, Tuple, Any, Union, Callable
-
-from diffusers import DiffusionPipeline
-
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-from cache_dit.
-    patch_cached_stats,
-    remove_cached_stats,
-)
+from typing import Dict, List, Tuple, Any, Union, Callable, Optional
+
+from diffusers import DiffusionPipeline, ModelMixin
+
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters import FakeDiffusionPipeline
+from cache_dit.caching.block_adapters import ParamsModifier
+from cache_dit.caching.block_adapters import BlockAdapterRegistry
+from cache_dit.caching.cache_contexts import ContextManager
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import CalibratorConfig
+from cache_dit.caching.cache_blocks import UnifiedBlocks
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
```
```diff
@@ -37,8 +33,11 @@ class CachedAdapter:
         pipe_or_adapter: Union[
             DiffusionPipeline,
             BlockAdapter,
+            # Transformer-only
+            torch.nn.Module,
+            ModelMixin,
         ],
-        **
+        **context_kwargs,
     ) -> Union[
         DiffusionPipeline,
         BlockAdapter,
@@ -47,7 +46,9 @@ class CachedAdapter:
             pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"
 
-        if isinstance(
+        if isinstance(
+            pipe_or_adapter, (DiffusionPipeline, torch.nn.Module, ModelMixin)
+        ):
             if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
                     f"{pipe_or_adapter.__class__.__name__} is officially "
```
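The widened `Union` means `CachedAdapter.apply` now accepts a bare transformer (`torch.nn.Module` / `ModelMixin`) as well as a pipeline or a `BlockAdapter`, and, as the next hunk shows, returns `block_adapter.transformer` in that case. A hedged usage sketch, assuming the public `cache_dit.enable_cache()` helper forwards to `CachedAdapter.apply`:

```python
import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# 0.3.x and 1.0.x: pass the whole pipeline, get the cached pipeline back.
cache_dit.enable_cache(pipe)

# New in 1.0.x (transformer-only path): pass just the transformer module;
# the cached transformer is returned instead of a pipeline.
# cache_dit.enable_cache(pipe.transformer)
```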
```diff
@@ -57,16 +58,22 @@ class CachedAdapter:
                 block_adapter = BlockAdapterRegistry.get_adapter(
                     pipe_or_adapter
                 )
-
+                assert block_adapter is not None, (
+                    f"BlockAdapter for {pipe_or_adapter.__class__.__name__} "
+                    "should not be None!"
+                )
+                if params_modifiers := context_kwargs.pop(
                     "params_modifiers",
                     None,
                 ):
                     block_adapter.params_modifiers = params_modifiers
 
-
-
-
-
+                block_adapter = cls.cachify(block_adapter, **context_kwargs)
+                if isinstance(pipe_or_adapter, DiffusionPipeline):
+                    return block_adapter.pipe
+
+                return block_adapter.transformer
+
         else:
             raise ValueError(
                 f"{pipe_or_adapter.__class__.__name__} is not officially supported "
@@ -78,21 +85,21 @@ class CachedAdapter:
                 "Adapting Cache Acceleration using custom BlockAdapter!"
             )
             if pipe_or_adapter.params_modifiers is None:
-                if params_modifiers :=
+                if params_modifiers := context_kwargs.pop(
                     "params_modifiers", None
                 ):
                     pipe_or_adapter.params_modifiers = params_modifiers
 
             return cls.cachify(
                 pipe_or_adapter,
-                **
+                **context_kwargs,
             )
 
     @classmethod
     def cachify(
         cls,
         block_adapter: BlockAdapter,
-        **
+        **context_kwargs,
    ) -> BlockAdapter:
 
         if block_adapter.auto:
@@ -109,14 +116,15 @@ class CachedAdapter:
 
         # 1. Apply cache on pipeline: wrap cache context, must
         # call create_context before mock_blocks.
-        cls.create_context(
+        _, contexts_kwargs = cls.create_context(
             block_adapter,
-            **
+            **context_kwargs,
         )
 
         # 2. Apply cache on transformer: mock cached blocks
         cls.mock_blocks(
             block_adapter,
+            contexts_kwargs,
         )
 
         return block_adapter
@@ -125,12 +133,10 @@ class CachedAdapter:
     def check_context_kwargs(
         cls,
         block_adapter: BlockAdapter,
-        **
+        **context_kwargs,
     ):
-        # Check
-        cache_config: BasicCacheConfig =
-            "cache_config"
-        ] # ref
+        # Check context_kwargs
+        cache_config: BasicCacheConfig = context_kwargs["cache_config"] # ref
         assert cache_config is not None, "cache_config can not be None."
         if cache_config.enable_separate_cfg is None:
             # Check cfg for some specific case if users don't set it as True
```
```diff
@@ -156,87 +162,113 @@ class CachedAdapter:
                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
             )
 
-        cache_type =
+        cache_type = context_kwargs.pop("cache_type", None)
         if cache_type is not None:
-            assert (
-                cache_type
-            ), "
+            assert isinstance(
+                cache_type, CacheType
+            ), f"cache_type must be CacheType Enum, but got {type(cache_type)}."
+            assert cache_type == cache_config.cache_type, (
+                f"cache_type from context_kwargs ({cache_type}) must be the same "
+                f"as that from cache_config ({cache_config.cache_type})."
+            )
 
-        return
+        return context_kwargs
 
     @classmethod
     def create_context(
         cls,
         block_adapter: BlockAdapter,
-        **
-    ) ->
+        **context_kwargs,
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:
 
         BlockAdapter.assert_normalized(block_adapter)
 
         if BlockAdapter.is_cached(block_adapter.pipe):
             return block_adapter.pipe
 
-        # Check
-
-            block_adapter, **
+        # Check context_kwargs
+        context_kwargs = cls.check_context_kwargs(
+            block_adapter, **context_kwargs
         )
-        # Apply cache on pipeline: wrap cache context
-        pipe_cls_name = block_adapter.pipe.__class__.__name__
 
         # Each Pipeline should have it's own context manager instance.
         # Different transformers (Wan2.2, etc) should shared the same
         # cache manager but with different cache context (according
         # to their unique instance id).
-
+        cache_config: BasicCacheConfig = context_kwargs.get(
+            "cache_config", None
+        )
+        assert cache_config is not None, "cache_config can not be None."
+        # Apply cache on pipeline: wrap cache context
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
+        context_manager = ContextManager(
             name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+            cache_type=cache_config.cache_type,
+            # Force use persistent_context for FakeDiffusionPipeline
+            persistent_context=isinstance(
+                block_adapter.pipe, FakeDiffusionPipeline
+            ),
         )
-        block_adapter.pipe._cache_manager = cache_manager # instance level
-
         flatten_contexts, contexts_kwargs = cls.modify_context_params(
-            block_adapter, **
+            block_adapter, **context_kwargs
         )
 
-
-
-
-
-
-
-
-
-                ):
-
-
-
-
-
-
+        block_adapter.pipe._context_manager = context_manager # instance level
+
+        if not context_manager.persistent_context:
+
+            original_call = block_adapter.pipe.__class__.__call__
+
+            @functools.wraps(original_call)
+            def new_call(self, *args, **kwargs):
+                with ExitStack() as stack:
+                    # cache context will be reset for each pipe inference
+                    for context_name, context_kwargs in zip(
+                        flatten_contexts, contexts_kwargs
+                    ):
+                        stack.enter_context(
+                            context_manager.enter_context(
+                                context_manager.reset_context(
+                                    context_name,
+                                    **context_kwargs,
+                                ),
+                            )
                         )
-                    )
-
-
-
+                    outputs = original_call(self, *args, **kwargs)
+                    cls.apply_stats_hooks(block_adapter)
+                    return outputs
+
+            block_adapter.pipe.__class__.__call__ = new_call
+            block_adapter.pipe.__class__._original_call = original_call
+
+        else:
+            # Init persistent cache context for transformer
+            for context_name, context_kwargs in zip(
+                flatten_contexts, contexts_kwargs
+            ):
+                context_manager.reset_context(
+                    context_name,
+                    **context_kwargs,
+                )
 
-        block_adapter.pipe.__class__.__call__ = new_call
-        block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True
 
         cls.apply_params_hooks(block_adapter, contexts_kwargs)
 
-        return
+        return flatten_contexts, contexts_kwargs
 
     @classmethod
     def modify_context_params(
         cls,
         block_adapter: BlockAdapter,
-        **
+        **context_kwargs,
     ) -> Tuple[List[str], List[Dict[str, Any]]]:
 
         flatten_contexts = BlockAdapter.flatten(
             block_adapter.unique_blocks_name
         )
         contexts_kwargs = [
-
+            copy.deepcopy(context_kwargs) # must deep copy
             for _ in range(
                 len(flatten_contexts),
             )
```
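`create_context` now builds a `ContextManager` (replacing the old `_cache_manager`) and returns `(flatten_contexts, contexts_kwargs)` so `mock_blocks` can consume the same per-blocks kwargs. On the non-persistent path it wraps `__class__.__call__` so that every inference resets and re-enters one cache context per blocks group through an `ExitStack`. A self-contained toy sketch of that wrap-and-reset pattern (not the library code; `TinyContextManager` is invented for illustration):

```python
import functools
from contextlib import ExitStack, contextmanager

class TinyContextManager:
    """Toy stand-in for cache-dit's ContextManager: named, resettable contexts."""

    def __init__(self):
        self._contexts = {}

    def reset_context(self, name, **kwargs):
        self._contexts[name] = dict(kwargs)  # fresh state for this inference
        return name

    @contextmanager
    def enter_context(self, name):
        # Make `name` the active context for the duration of the call.
        yield self._contexts[name]

def wrap_call(pipe_cls, manager, context_names):
    original_call = pipe_cls.__call__

    @functools.wraps(original_call)
    def new_call(self, *args, **kwargs):
        with ExitStack() as stack:
            # One reset + enter per blocks group, all torn down together
            # when the inference returns.
            for name in context_names:
                stack.enter_context(
                    manager.enter_context(manager.reset_context(name))
                )
            return original_call(self, *args, **kwargs)

    pipe_cls.__call__ = new_call
    pipe_cls._original_call = original_call
```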
```diff
@@ -257,9 +289,41 @@ class CachedAdapter:
         for i in range(
             min(len(contexts_kwargs), len(flatten_modifiers)),
         ):
-
-            flatten_modifiers[
-
+            if "cache_config" in flatten_modifiers[i]._context_kwargs:
+                modifier_cache_config = flatten_modifiers[
+                    i
+                ]._context_kwargs.get("cache_config", None)
+                modifier_calibrator_config = flatten_modifiers[
+                    i
+                ]._context_kwargs.get("calibrator_config", None)
+                if modifier_cache_config is not None:
+                    assert isinstance(
+                        modifier_cache_config, BasicCacheConfig
+                    ), (
+                        f"cache_config must be BasicCacheConfig, but got "
+                        f"{type(modifier_cache_config)}."
+                    )
+                    contexts_kwargs[i]["cache_config"].update(
+                        **modifier_cache_config.as_dict()
+                    )
+                if modifier_calibrator_config is not None:
+                    assert isinstance(
+                        modifier_calibrator_config, CalibratorConfig
+                    ), (
+                        f"calibrator_config must be CalibratorConfig, but got "
+                        f"{type(modifier_calibrator_config)}."
+                    )
+                    if (
+                        contexts_kwargs[i].get("calibrator_config", None)
+                        is None
+                    ):
+                        contexts_kwargs[i][
+                            "calibrator_config"
+                        ] = modifier_calibrator_config
+                    else:
+                        contexts_kwargs[i]["calibrator_config"].update(
+                            **modifier_calibrator_config.as_dict()
+                        )
             cls._config_messages(**contexts_kwargs[i])
 
         return flatten_contexts, contexts_kwargs
```
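The new merge logic lets a per-blocks `ParamsModifier` override only the fields it sets: its `cache_config` is `update()`-ed into a deep copy of the shared config, and its `calibrator_config` either fills an empty slot or updates the existing one. A hedged sketch of what this enables (the class names come from the imports in this diff; the constructor fields used here are assumptions):

```python
# Hypothetical per-blocks overrides, mirroring the merge logic above.
from cache_dit.caching.block_adapters import ParamsModifier
from cache_dit.caching.cache_contexts import BasicCacheConfig

params_modifiers = [
    ParamsModifier(),  # first blocks group: keep the global cache_config
    # Second blocks group: only the fields set here are merged into its
    # deep-copied context kwargs; everything else stays at the global value.
    ParamsModifier(cache_config=BasicCacheConfig(residual_diff_threshold=0.08)),
]
# Passed through context_kwargs, e.g.:
# cache_dit.enable_cache(pipe_or_adapter, params_modifiers=params_modifiers, ...)
```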
```diff
@@ -273,7 +337,7 @@ class CachedAdapter:
             "calibrator_config", None
         )
         if cache_config is not None:
-            message = f"Collected
+            message = f"Collected Context Config: {cache_config.strify()}"
             if calibrator_config is not None:
                 message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
         else:
@@ -284,6 +348,7 @@ class CachedAdapter:
     def mock_blocks(
         cls,
         block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
     ) -> List[torch.nn.Module]:
 
         BlockAdapter.assert_normalized(block_adapter)
@@ -293,24 +358,28 @@ class CachedAdapter:
 
         # Apply cache on transformer: mock cached transformer blocks
         for (
-
+            unified_blocks,
             transformer,
             blocks_name,
             unique_blocks_name,
             dummy_blocks_names,
         ) in zip(
-            cls.
+            cls.collect_unified_blocks(
+                block_adapter,
+                contexts_kwargs,
+            ),
             block_adapter.transformer,
             block_adapter.blocks_name,
             block_adapter.unique_blocks_name,
             block_adapter.dummy_blocks_names,
         ):
             cls.mock_transformer(
-
+                unified_blocks,
                 transformer,
                 blocks_name,
                 unique_blocks_name,
                 dummy_blocks_names,
+                block_adapter,
             )
 
         return block_adapter.transformer
@@ -318,11 +387,12 @@ class CachedAdapter:
     @classmethod
     def mock_transformer(
         cls,
-
+        unified_blocks: Dict[str, torch.nn.ModuleList],
         transformer: torch.nn.Module,
         blocks_name: List[str],
         unique_blocks_name: List[str],
         dummy_blocks_names: List[str],
+        block_adapter: BlockAdapter,
     ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()
 
@@ -330,7 +400,28 @@ class CachedAdapter:
 
         assert isinstance(dummy_blocks_names, list)
 
-
+        from accelerate import hooks
+
+        _hf_hook: Optional[hooks.ModelHook] = None
+
+        if getattr(transformer, "_hf_hook", None) is not None:
+            _hf_hook = transformer._hf_hook # hooks from accelerate.hooks
+            if hasattr(transformer, "_old_forward"):
+                logger.warning(
+                    "_hf_hook is not None, so, we have to re-direct transformer's "
+                    f"original_forward({id(original_forward)}) to transformer's "
+                    f"_old_forward({id(transformer._old_forward)})"
+                )
+                original_forward = transformer._old_forward
+
+        # TODO: remove group offload hooks the re-apply after cache applied.
+        # hooks = _diffusers_hook.hooks.copy(); _diffusers_hook.hooks.clear()
+        # re-apply hooks to transformer after cache applied.
+        # from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
+        # from diffusers.hooks.group_offloading import apply_group_offloading
+        context_manager: ContextManager = block_adapter.pipe._context_manager
+        assert isinstance(context_manager, ContextManager._supported_managers)
+
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
                 for name, context_name in zip(
```
```diff
@@ -339,7 +430,7 @@ class CachedAdapter:
                 ):
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self, name,
+                            self, name, unified_blocks[context_name]
                         )
                     )
                 for dummy_name in dummy_blocks_names:
```
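The one-line change above swaps the per-context unified blocks into place via `unittest.mock.patch.object`, scoped to a single forward call. A standalone illustration of that trick (toy modules, not cache-dit classes):

```python
import unittest.mock
from contextlib import ExitStack

import torch

class Wrapped(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

model = Wrapped()
fast_blocks = torch.nn.ModuleList([torch.nn.Identity()])  # stand-in "cached" blocks

with ExitStack() as stack:
    # Inside this scope, model.blocks is the replacement list; the original
    # attribute is restored on exit -- exactly the lifetime of one forward.
    stack.enter_context(unittest.mock.patch.object(model, "blocks", fast_blocks))
    out = model(torch.randn(2, 4))
```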
```diff
@@ -348,55 +439,85 @@ class CachedAdapter:
                             self, dummy_name, dummy_blocks
                         )
                     )
-
+                outputs = original_forward(*args, **kwargs)
+
+                if (
+                    context_manager.persistent_context
+                    and context_manager.is_pre_refreshed()
+                ):
+                    cls.apply_stats_hooks(block_adapter)
+
+                return outputs
+
+        def new_forward_with_hf_hook(self, *args, **kwargs):
+            # Compatible with model cpu offload
+            if _hf_hook is not None and hasattr(_hf_hook, "pre_forward"):
+                args, kwargs = _hf_hook.pre_forward(self, *args, **kwargs)
+
+            outputs = new_forward(self, *args, **kwargs)
+
+            if _hf_hook is not None and hasattr(_hf_hook, "post_forward"):
+                outputs = _hf_hook.post_forward(self, outputs)
+
+            return outputs
+
+        # NOTE: Still can't fully compatible with group offloading
+        transformer.forward = functools.update_wrapper(
+            functools.partial(new_forward_with_hf_hook, transformer),
+            new_forward_with_hf_hook,
+        )
 
-        transformer.forward = new_forward.__get__(transformer)
         transformer._original_forward = original_forward
         transformer._is_cached = True
 
         return transformer
 
     @classmethod
-    def
+    def collect_unified_blocks(
         cls,
         block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
     ) -> List[Dict[str, torch.nn.ModuleList]]:
 
         BlockAdapter.assert_normalized(block_adapter)
 
         total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
-        assert hasattr(block_adapter.pipe, "
+        assert hasattr(block_adapter.pipe, "_context_manager")
         assert isinstance(
-            block_adapter.pipe.
-
+            block_adapter.pipe._context_manager,
+            ContextManager._supported_managers,
         )
 
         for i in range(len(block_adapter.transformer)):
 
-
+            unified_blocks_bind_context = {}
             for j in range(len(block_adapter.blocks[i])):
-
+                cache_config: BasicCacheConfig = contexts_kwargs[
+                    i * len(block_adapter.blocks[i]) + j
+                ]["cache_config"]
+                unified_blocks_bind_context[
                     block_adapter.unique_blocks_name[i][j]
                 ] = torch.nn.ModuleList(
                     [
-
+                        UnifiedBlocks(
                            # 0. Transformer blocks configuration
                             block_adapter.blocks[i][j],
                             transformer=block_adapter.transformer[i],
                             forward_pattern=block_adapter.forward_pattern[i][j],
                             check_forward_pattern=block_adapter.check_forward_pattern,
                             check_num_outputs=block_adapter.check_num_outputs,
-                            # 1. Cache context configuration
+                            # 1. Cache/Prune context configuration
                             cache_prefix=block_adapter.blocks_name[i][j],
                             cache_context=block_adapter.unique_blocks_name[i][
                                 j
                             ],
-
+                            context_manager=block_adapter.pipe._context_manager,
+                            cache_type=cache_config.cache_type,
                         )
                     ]
                 )
 
-            total_cached_blocks.append(
+            total_cached_blocks.append(unified_blocks_bind_context)
 
         return total_cached_blocks
 
```
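`new_forward_with_hf_hook` keeps `enable_model_cpu_offload` working: when accelerate has attached an `_hf_hook`, its `pre_forward`/`post_forward` are called around the mocked forward, and the wrapper is bound with `functools.partial` plus `update_wrapper` instead of the old `__get__` binding. A minimal sketch of that hook sandwich with a toy hook (the real one is an `accelerate.hooks.ModelHook`):

```python
import functools

import torch

class ToyHook:
    """Mimics the two accelerate ModelHook methods used above."""

    def pre_forward(self, module, *args, **kwargs):
        # Real hook: moves the module and inputs to the execution device.
        return args, kwargs

    def post_forward(self, module, output):
        # Real hook: offloads the module again and post-processes outputs.
        return output

def attach_hooked_forward(module: torch.nn.Module, inner_forward, hook=None):
    def forward_with_hook(self, *args, **kwargs):
        if hook is not None and hasattr(hook, "pre_forward"):
            args, kwargs = hook.pre_forward(self, *args, **kwargs)
        outputs = inner_forward(self, *args, **kwargs)
        if hook is not None and hasattr(hook, "post_forward"):
            outputs = hook.post_forward(self, outputs)
        return outputs

    # Same binding trick as the diff: partial(self=module) + update_wrapper,
    # which keeps working on modules whose forward was already hook-wrapped.
    module.forward = functools.update_wrapper(
        functools.partial(forward_with_hook, module), forward_with_hook
    )
```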
```diff
@@ -406,7 +527,7 @@ class CachedAdapter:
         block_adapter: BlockAdapter,
         contexts_kwargs: List[Dict],
     ):
-        block_adapter.pipe.
+        block_adapter.pipe._context_kwargs = contexts_kwargs[0]
 
         params_shift = 0
         for i in range(len(block_adapter.transformer)):
@@ -417,40 +538,43 @@ class CachedAdapter:
             block_adapter.transformer[i]._has_separate_cfg = (
                 block_adapter.has_separate_cfg
             )
-            block_adapter.transformer[i].
-
-
+            block_adapter.transformer[i]._context_kwargs = contexts_kwargs[
+                params_shift
+            ]
 
             blocks = block_adapter.blocks[i]
             for j in range(len(blocks)):
                 blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                blocks[j].
-                    params_shift + j
-                ]
+                blocks[j]._context_kwargs = contexts_kwargs[params_shift + j]
 
             params_shift += len(blocks)
 
     @classmethod
+    @torch.compiler.disable
     def apply_stats_hooks(
         cls,
         block_adapter: BlockAdapter,
     ):
-
+        from cache_dit.caching.cache_blocks import (
+            apply_stats,
+        )
+
+        context_manager = block_adapter.pipe._context_manager
 
         for i in range(len(block_adapter.transformer)):
-
+            apply_stats(
                 block_adapter.transformer[i],
                 cache_context=block_adapter.unique_blocks_name[i][-1],
-
+                context_manager=context_manager,
             )
             for blocks, unique_name in zip(
                 block_adapter.blocks[i],
                 block_adapter.unique_blocks_name[i],
             ):
-
+                apply_stats(
                     blocks,
                     cache_context=unique_name,
-
+                    context_manager=context_manager,
                 )
 
     @classmethod
```
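`apply_stats_hooks` is now decorated with `@torch.compiler.disable`, so the Python-side stats bookkeeping always runs eagerly and is never traced into a compiled graph. A small sketch of the effect:

```python
import torch

@torch.compiler.disable
def record_stats(n: int) -> None:
    # Python-side bookkeeping; torch.compile skips tracing this function
    # and falls back to eager execution for the call.
    print(f"processed {n} elements")

@torch.compile
def fused(x: torch.Tensor) -> torch.Tensor:
    y = torch.relu(x) * 2
    record_stats(y.numel())  # excluded from the captured graph
    return y

print(fused(torch.randn(8)))
```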
```diff
@@ -478,11 +602,13 @@ class CachedAdapter:
             original_call = pipe.__class__._original_call
             pipe.__class__.__call__ = original_call
             del pipe.__class__._original_call
-        if hasattr(pipe, "
-
-        if isinstance(
-
-
+        if hasattr(pipe, "_context_manager"):
+            context_manager = pipe._context_manager
+            if isinstance(
+                context_manager, ContextManager._supported_managers
+            ):
+                context_manager.clear_contexts()
+            del pipe._context_manager
         if hasattr(pipe, "_is_cached"):
             del pipe.__class__._is_cached
 
@@ -497,22 +623,22 @@ class CachedAdapter:
         def _release_blocks_params(blocks):
             if hasattr(blocks, "_forward_pattern"):
                 del blocks._forward_pattern
-            if hasattr(blocks, "
-                del blocks.
+            if hasattr(blocks, "_context_kwargs"):
+                del blocks._context_kwargs
 
         def _release_transformer_params(transformer):
             if hasattr(transformer, "_forward_pattern"):
                 del transformer._forward_pattern
             if hasattr(transformer, "_has_separate_cfg"):
                 del transformer._has_separate_cfg
-            if hasattr(transformer, "
-                del transformer.
+            if hasattr(transformer, "_context_kwargs"):
+                del transformer._context_kwargs
             for blocks in BlockAdapter.find_blocks(transformer):
                 _release_blocks_params(blocks)
 
         def _release_pipeline_params(pipe):
-            if hasattr(pipe, "
-                del pipe.
+            if hasattr(pipe, "_context_kwargs"):
+                del pipe._context_kwargs
 
         cls.release_hooks(
             pipe_or_adapter,
@@ -522,11 +648,24 @@ class CachedAdapter:
         )
 
         # release stats hooks
+        from cache_dit.caching.cache_blocks import (
+            remove_stats,
+        )
+
+        cls.release_hooks(
+            pipe_or_adapter, remove_stats, remove_stats, remove_stats
+        )
+
+        # maybe release parallelism stats
+        from cache_dit.parallelism.parallel_interface import (
+            remove_parallelism_stats,
+        )
+
         cls.release_hooks(
             pipe_or_adapter,
-
-
-
+            remove_parallelism_stats,
+            remove_parallelism_stats,
+            remove_parallelism_stats,
         )
 
     @classmethod
```
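Taken together, 1.0.x enables caching by swapping `__call__`/`forward` wrappers in, and on release restores everything: the original call, the transformer forward, the cache contexts, and the stats and parallelism hooks. A hedged end-to-end sketch, assuming the public helpers `enable_cache` / `disable_cache` / `summary` wrap the `CachedAdapter` paths shown above (`summary.py` is new in this release):

```python
import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

cache_dit.enable_cache(pipe)                 # CachedAdapter.apply(...)
image = pipe("a cat", num_inference_steps=28).images[0]
cache_dit.summary(pipe)                      # cached-steps statistics

cache_dit.disable_cache(pipe)                # restores __call__ / forward,
                                             # clears contexts and hooks
```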