cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
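The headline change in this range is the package rename from `cache_factory` to `caching` (and `custom_ops` to `kernels`), alongside the new `parallelism`, `summary`, and quantize-backend modules. A minimal sketch of the resulting import migration; the old 1.0.3 paths are inferred from the rename mapping in the file list above, the new ones are taken from the diffs below:

```python
# 1.0.3 layout (inferred from the {cache_factory -> caching} renames above):
# from cache_dit.cache_factory.cache_types import CacheType
# from cache_dit.cache_factory.block_adapters import BlockAdapter

# 1.0.14 layout, matching the import lines shown in the diffs below:
from cache_dit.caching.cache_types import CacheType
from cache_dit.caching.block_adapters import BlockAdapter
from cache_dit.caching.cache_contexts import BasicCacheConfig, DBCacheConfig
from cache_dit.parallelism import ParallelismConfig, enable_parallelism
```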
cache_dit/{cache_factory → caching}/cache_interface.py
CHANGED

@@ -1,12 +1,17 @@
+import torch
 from typing import Any, Tuple, List, Union, Optional
-from diffusers import DiffusionPipeline
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
+from diffusers import DiffusionPipeline, ModelMixin
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters import BlockAdapterRegistry
+from cache_dit.caching.cache_adapters import CachedAdapter
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import DBCacheConfig
+from cache_dit.caching.cache_contexts import DBPruneConfig
+from cache_dit.caching.cache_contexts import CalibratorConfig
+from cache_dit.caching.params_modifier import ParamsModifier
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.parallelism import enable_parallelism

 from cache_dit.logger import init_logger

@@ -18,9 +23,18 @@ def enable_cache(
     pipe_or_adapter: Union[
         DiffusionPipeline,
         BlockAdapter,
+        # Transformer-only
+        torch.nn.Module,
+        ModelMixin,
     ],
-    # …
-    cache_config: …
+    # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
+    cache_config: Optional[
+        Union[
+            BasicCacheConfig,
+            DBCacheConfig,
+            DBPruneConfig,
+        ]
+    ] = None,
     # Calibrator config: TaylorSeerCalibratorConfig, etc.
     calibrator_config: Optional[CalibratorConfig] = None,
     # Modify cache context params for specific blocks.
@@ -31,10 +45,15 @@ def enable_cache(
             List[List[ParamsModifier]],
         ]
     ] = None,
+    # Config for Parallelism
+    parallelism_config: Optional[ParallelismConfig] = None,
    # Other cache context kwargs: Deprecated cache kwargs
     **kwargs,
 ) -> Union[
     DiffusionPipeline,
+    # Transformer-only
+    torch.nn.Module,
+    ModelMixin,
     BlockAdapter,
 ]:
     r"""
@@ -64,10 +83,9 @@ def enable_cache(
     with minimal code changes.

     Args:
-        pipe_or_adapter (`DiffusionPipeline` or `…
+        pipe_or_adapter (`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
-            For example: cache_dit.enable_cache(FluxPipeline(...)).
-            for the usgae of BlockAdapter.
+            For example: cache_dit.enable_cache(FluxPipeline(...)).

         cache_config (`BasicCacheConfig`, *required*, defaults to BasicCacheConfig()):
             Basic DBCache config for cache context, defaults to BasicCacheConfig(). The configurable params listed belows:
@@ -107,6 +125,10 @@ def enable_cache(
             Whether to compute separate difference values for CFG and non-CFG steps, default is True.
             If False, we will use the computed difference from the current non-CFG transformer step
             for the current CFG step.
+        num_inference_steps (`int`, *optional*, defaults to None):
+            num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+            for better caching performance. For example, we will refresh the cache once the
+            executed steps exceed num_inference_steps if num_inference_steps is provided.

         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
             Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
@@ -121,8 +143,29 @@ def enable_cache(
             **kwargs: (`dict`, *optional*, defaults to {}):
                 The same as 'kwargs' param in cache_dit.enable_cache() interface.

+        parallelism_config (`ParallelismConfig`, *optional*, defaults to None):
+            Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
+            parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
+            for more details of ParallelismConfig.
+            backend: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+                Parallelism backend, currently only NATIVE_DIFFUSER and NVTIVE_PYTORCH are supported.
+                For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
+                only NATIVE_PYTORCH backend is supported.
+            ulysses_size: (`int`, *optional*, defaults to None):
+                The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
+            ring_size: (`int`, *optional*, defaults to None):
+                The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
+            tp_size: (`int`, *optional*, defaults to None):
+                The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+                This setting is only valid when backend is NATIVE_PYTORCH.
+            parallel_kwargs: (`dict`, *optional*, defaults to {}):
+                Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
+                it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.
+
         kwargs (`dict`, *optional*, defaults to {})
-            Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/…
+            Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_context.py
             for more details.

     Examples:
@@ -135,15 +178,29 @@ def enable_cache(
     >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
     >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
     """
+    # Precheck for compatibility of different configurations
+    if cache_config is None:
+        if parallelism_config is None:
+            # Set default cache config only when parallelism is not enabled
+            logger.info("cache_config is None, using default DBCacheConfig")
+            cache_config = DBCacheConfig()
+        else:
+            # Allow empty cache_config when parallelism is enabled
+            logger.warning(
+                "Parallelism is enabled and cache_config is None. Please manually "
+                "set cache_config to avoid potential compatibility issues. "
+                "Otherwise, cache will not be enabled."
+            )
+
     # Collect cache context kwargs
-…
-    if (cache_type := …
+    context_kwargs = {}
+    if (cache_type := context_kwargs.get("cache_type", None)) is not None:
         if cache_type == CacheType.NONE:
             return pipe_or_adapter

-    # …
+    # NOTE: Deprecated cache config params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
-…
+    deprecated_kwargs = {
         "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
         "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
         "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -159,27 +216,27 @@ def enable_cache(
         ),
     }

-…
-        k: v for k, v in …
+    deprecated_kwargs = {
+        k: v for k, v in deprecated_kwargs.items() if v is not None
     }

-    if …
+    if deprecated_kwargs:
         logger.warning(
             "Manually settup DBCache context without BasicCacheConfig is "
             "deprecated and will be removed in the future, please use "
             "`cache_config` parameter instead!"
         )
         if cache_config is not None:
-            cache_config.update(**…
+            cache_config.update(**deprecated_kwargs)
         else:
-            cache_config = BasicCacheConfig(**…
+            cache_config = BasicCacheConfig(**deprecated_kwargs)

     if cache_config is not None:
-        …
+        context_kwargs["cache_config"] = cache_config

-    # …
+    # NOTE: Deprecated taylorseer params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
-    if (…
+    if cache_config is not None and (
         kwargs.get("enable_taylorseer", None) is not None
         or kwargs.get("enable_encoder_taylorseer", None) is not None
     ):
@@ -188,7 +245,7 @@ def enable_cache(
             "deprecated and will be removed in the future, please use "
             "`calibrator_config` parameter instead!"
         )
-        from cache_dit.…
+        from cache_dit.caching.cache_contexts.calibrators import (
             TaylorSeerCalibratorConfig,
         )

@@ -202,23 +259,79 @@ def enable_cache(
         )

     if calibrator_config is not None:
-        …
+        context_kwargs["calibrator_config"] = calibrator_config

     if params_modifiers is not None:
-        …
+        context_kwargs["params_modifiers"] = params_modifiers

-    if …
-…
+    if cache_config is not None:
+        if isinstance(
             pipe_or_adapter,
-…
-    ) …
+            (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin),
+        ):
+            pipe_or_adapter = CachedAdapter.apply(
+                pipe_or_adapter,
+                **context_kwargs,
+            )
+        else:
+            raise ValueError(
+                f"type: {type(pipe_or_adapter)} is not valid, "
+                "Please pass DiffusionPipeline or BlockAdapter"
+                "for the 1's position param: pipe_or_adapter"
+            )
     else:
-        …
-        …
-        "…
-        "for the 1's position param: pipe_or_adapter"
+        logger.warning(
+            "cache_config is None, skip enabling cache for "
+            f"{pipe_or_adapter.__class__.__name__}."
         )

+    # NOTE: Users should always enable parallelism after applying
+    # cache to avoid hooks conflict.
+    if parallelism_config is not None:
+        assert isinstance(
+            parallelism_config, ParallelismConfig
+        ), "parallelism_config should be of type ParallelismConfig."
+
+        transformers = []
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            adapter = BlockAdapterRegistry.get_adapter(pipe_or_adapter)
+            if adapter is None:
+                assert hasattr(pipe_or_adapter, "transformer"), (
+                    "The given DiffusionPipeline does not have "
+                    "a 'transformer' attribute, cannot enable "
+                    "parallelism."
+                )
+                transformers = [pipe_or_adapter.transformer]
+            else:
+                adapter = BlockAdapter.normalize(adapter, unique=False)
+                transformers = BlockAdapter.flatten(adapter.transformer)
+        else:
+            if not BlockAdapter.is_normalized(pipe_or_adapter):
+                pipe_or_adapter = BlockAdapter.normalize(
+                    pipe_or_adapter, unique=False
+                )
+            transformers = BlockAdapter.flatten(pipe_or_adapter.transformer)
+
+        if len(transformers) == 0:
+            logger.warning(
+                "No transformer is detected in the "
+                "BlockAdapter, skip enabling parallelism."
+            )
+            return pipe_or_adapter
+
+        if len(transformers) > 1:
+            logger.warning(
+                "Multiple transformers are detected in the "
+                "BlockAdapter, all transfomers will be "
+                "enabled for parallelism."
+            )
+        for i, transformer in enumerate(transformers):
+            # Enable parallelism for the transformer inplace
+            transformers[i] = enable_parallelism(
+                transformer, parallelism_config
+            )
+    return pipe_or_adapter
+

 def disable_cache(
     pipe_or_adapter: Union[
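Putting the new signature together: `enable_cache` now accepts transformer-only modules, an optional cache config, and a `parallelism_config` that it applies after caching (per the NOTE about hook conflicts). A hedged usage sketch; the model id, sizes, and `ParallelismConfig` field names are illustrative, taken from the docstring above rather than verified against the released API:

```python
import cache_dit
from diffusers import DiffusionPipeline
from cache_dit.caching.cache_contexts import DBCacheConfig
from cache_dit.parallelism import ParallelismConfig

# Assumed: launched under torchrun so the parallel backends can form a
# process group; "Qwen/Qwen-Image" is just an illustrative pipeline.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

# One call does both: cache is applied first, then parallelism, matching
# the ordering the implementation enforces above.
pipe = cache_dit.enable_cache(
    pipe,
    cache_config=DBCacheConfig(),  # or BasicCacheConfig() / DBPruneConfig()
    parallelism_config=ParallelismConfig(
        ulysses_size=2,  # Ulysses-style context parallelism (NATIVE_DIFFUSER)
    ),
)

stats = cache_dit.summary(pipe)  # as in the docstring's Examples section
```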
cache_dit/{cache_factory → caching}/cache_types.py
CHANGED

@@ -6,7 +6,8 @@ logger = init_logger(__name__)

 class CacheType(Enum):
     NONE = "NONE"
-    DBCache = "Dual_Block_Cache"
+    DBCache = "DBCache" # "Dual_Block_Cache"
+    DBPrune = "DBPrune" # "Dynamic_Block_Prune"

     @staticmethod
     def type(type_hint: "CacheType | str") -> "CacheType":
@@ -14,6 +15,9 @@ class CacheType(Enum):
            return type_hint
        return cache_type(type_hint)

+    def __str__(self) -> str:
+        return self.value
+

 def cache_type(type_hint: "CacheType | str") -> "CacheType":
     if type_hint is None:
@@ -21,7 +25,6 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":

     if isinstance(type_hint, CacheType):
         return type_hint
-
     elif type_hint.upper() in (
         "DUAL_BLOCK_CACHE",
         "DB_CACHE",
@@ -29,6 +32,20 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
         "DB",
     ):
         return CacheType.DBCache
+    elif type_hint.upper() in (
+        "DYNAMIC_BLOCK_PRUNE",
+        "DB_PRUNE",
+        "DBPRUNE",
+        "DBP",
+    ):
+        return CacheType.DBPrune
+    elif type_hint.upper() in (
+        "NONE",
+        "NO_CACHE",
+        "NOCACHE",
+        "NC",
+    ):
+        return CacheType.NONE
     return CacheType.NONE

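With the widened matcher, config strings can use any of several aliases per type. A small sketch exercising only the branches visible in this hunk:

```python
from cache_dit.caching.cache_types import CacheType, cache_type

assert cache_type("db_cache") is CacheType.DBCache    # .upper() -> "DB_CACHE"
assert cache_type("db_prune") is CacheType.DBPrune    # new DBPrune branch
assert cache_type("no_cache") is CacheType.NONE       # new NONE aliases
assert cache_type("anything_else") is CacheType.NONE  # falls through

# The new __str__ makes the enum print as its short value.
assert str(CacheType.DBPrune) == "DBPrune"
```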
cache_dit/{cache_factory → caching}/forward_pattern.py
CHANGED

@@ -20,33 +20,33 @@ class ForwardPattern(Enum):

     Pattern_0 = (
         True, # Return_H_First
-        False,
-        False,
+        False, # Return_H_Only
+        False, # Forward_H_only
         ("hidden_states", "encoder_hidden_states"), # In
         ("hidden_states", "encoder_hidden_states"), # Out
         True, # Supported
     )

     Pattern_1 = (
-        False,
-        False,
-        False,
+        False, # Return_H_First
+        False, # Return_H_Only
+        False, # Forward_H_only
         ("hidden_states", "encoder_hidden_states"), # In
         ("encoder_hidden_states", "hidden_states"), # Out
         True, # Supported
     )

     Pattern_2 = (
-        False,
+        False, # Return_H_First
         True, # Return_H_Only
-        False,
+        False, # Forward_H_only
         ("hidden_states", "encoder_hidden_states"), # In
-        ("hidden_states",),
+        ("hidden_states",), # Out
         True, # Supported
     )

     Pattern_3 = (
-        False,
+        False, # Return_H_First
         True, # Return_H_Only
         True, # Forward_H_only
         ("hidden_states",), # In
@@ -56,18 +56,18 @@ class ForwardPattern(Enum):

     Pattern_4 = (
         True, # Return_H_First
-        False,
+        False, # Return_H_Only
         True, # Forward_H_only
-        ("hidden_states",),
+        ("hidden_states",), # In
         ("hidden_states", "encoder_hidden_states"), # Out
         True, # Supported
     )

     Pattern_5 = (
-        False,
-        False,
+        False, # Return_H_First
+        False, # Return_H_Only
         True, # Forward_H_only
-        ("hidden_states",),
+        ("hidden_states",), # In
         ("encoder_hidden_states", "hidden_states"), # Out
         True, # Supported
     )
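The added comments document the tuple layout shared by every pattern: (Return_H_First, Return_H_Only, Forward_H_only, In, Out, Supported). A sketch of reading one member, assuming the enum keeps the plain tuple as its `.value`; the unpacked names are mine, taken from those comments:

```python
from cache_dit.caching.forward_pattern import ForwardPattern

# Layout per the comments: flags first, then In/Out param names, then Supported.
(
    return_h_first,   # True if hidden_states is returned first
    return_h_only,    # True if only hidden_states is returned
    forward_h_only,   # True if the block forwards hidden_states only
    in_params,        # input tensor names
    out_params,       # output tensor names
    supported,        # whether cache-dit supports this pattern
) = ForwardPattern.Pattern_1.value

assert out_params == ("encoder_hidden_states", "hidden_states")
assert supported is True
```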
cache_dit/{cache_factory → caching}/params_modifier.py
CHANGED

@@ -1,7 +1,7 @@
 from typing import Optional

-from cache_dit.…
-from cache_dit.…
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import CalibratorConfig

 from cache_dit.logger import init_logger

@@ -11,7 +11,7 @@ logger = init_logger(__name__)
 class ParamsModifier:
     def __init__(
         self,
-        # …
+        # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
         cache_config: BasicCacheConfig = None,
         # Calibrator config: TaylorSeerCalibratorConfig, etc.
         calibrator_config: Optional[CalibratorConfig] = None,
@@ -22,7 +22,7 @@ class ParamsModifier:

         # WARNING: Deprecated cache config params. These parameters are now retained
         # for backward compatibility but will be removed in the future.
-        …
+        deprecated_kwargs = {
             "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
             "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
             "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -40,20 +40,20 @@ class ParamsModifier:
             ),
         }

-        …
-            k: v for k, v in …
+        deprecated_kwargs = {
+            k: v for k, v in deprecated_kwargs.items() if v is not None
         }

-        if …
+        if deprecated_kwargs:
             logger.warning(
                 "Manually settup DBCache context without BasicCacheConfig is "
                 "deprecated and will be removed in the future, please use "
                 "`cache_config` parameter instead!"
             )
             if cache_config is not None:
-                cache_config.update(**…
+                cache_config.update(**deprecated_kwargs)
             else:
-                cache_config = BasicCacheConfig(**…
+                cache_config = BasicCacheConfig(**deprecated_kwargs)

         if cache_config is not None:
             self._context_kwargs["cache_config"] = cache_config
@@ -68,7 +68,7 @@ class ParamsModifier:
             "deprecated and will be removed in the future, please use "
             "`calibrator_config` parameter instead!"
         )
-        from cache_dit.…
+        from cache_dit.caching.cache_contexts.calibrators import (
             TaylorSeerCalibratorConfig,
         )

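ParamsModifier mirrors the deprecated-kwargs migration in `enable_cache`: raw DBCache kwargs are filtered for `None`, trigger the warning, and are folded into a `BasicCacheConfig`. A short sketch of the two equivalent call styles (values illustrative):

```python
from cache_dit.caching.cache_contexts import BasicCacheConfig
from cache_dit.caching.params_modifier import ParamsModifier

# New style: pass an explicit config object.
m_new = ParamsModifier(cache_config=BasicCacheConfig())

# Old style (deprecated): raw kwargs are collected into deprecated_kwargs,
# a warning is logged, and they become BasicCacheConfig(**deprecated_kwargs)
# exactly as in the branch above.
m_old = ParamsModifier(Fn_compute_blocks=8, max_warmup_steps=4)
```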
cache_dit/caching/patch_functors/__init__.py
ADDED

@@ -0,0 +1,15 @@
+from cache_dit.caching.patch_functors.functor_base import PatchFunctor
+from cache_dit.caching.patch_functors.functor_dit import DiTPatchFunctor
+from cache_dit.caching.patch_functors.functor_flux import FluxPatchFunctor
+from cache_dit.caching.patch_functors.functor_chroma import (
+    ChromaPatchFunctor,
+)
+from cache_dit.caching.patch_functors.functor_hidream import (
+    HiDreamPatchFunctor,
+)
+from cache_dit.caching.patch_functors.functor_hunyuan_dit import (
+    HunyuanDiTPatchFunctor,
+)
+from cache_dit.caching.patch_functors.functor_qwen_image_controlnet import (
+    QwenImageControlNetPatchFunctor,
+)

cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py
CHANGED

@@ -6,7 +6,7 @@ from diffusers.models.transformers.dit_transformer_2d import (
     DiTTransformer2DModel,
     Transformer2DModelOutput,
 )
-from cache_dit.…
+from cache_dit.caching.patch_functors.functor_base import (
     PatchFunctor,
 )
 from cache_dit.logger import init_logger

cache_dit/{cache_factory → caching}/patch_functors/functor_….py
CHANGED

@@ -13,7 +13,7 @@ from diffusers.utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
-from cache_dit.…
+from cache_dit.caching.patch_functors.functor_base import (
     PatchFunctor,
 )
 from cache_dit.logger import init_logger

cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py
CHANGED

@@ -5,7 +5,7 @@ from diffusers.models.transformers.hunyuan_transformer_2d import (
     HunyuanDiTBlock,
     Transformer2DModelOutput,
 )
-from cache_dit.…
+from cache_dit.caching.patch_functors.functor_base import (
     PatchFunctor,
 )
 from cache_dit.logger import init_logger

cache_dit/{cache_factory → caching}/patch_functors/functor_….py
CHANGED

@@ -11,7 +11,7 @@ from diffusers.utils import (
     scale_lora_layers,
     unscale_lora_layers,
 )
-from cache_dit.…
+from cache_dit.caching.patch_functors.functor_base import (
     PatchFunctor,
 )
 from cache_dit.logger import init_logger
cache_dit/{cache_factory → caching}/utils.py
CHANGED

@@ -7,10 +7,6 @@ def load_cache_options_from_yaml(yaml_file_path):
         kwargs: dict = yaml.safe_load(f)

     required_keys = [
-        "max_warmup_steps",
-        "max_cached_steps",
-        "Fn_compute_blocks",
-        "Bn_compute_blocks",
         "residual_diff_threshold",
     ]
     for key in required_keys:
@@ -21,7 +17,7 @@ def load_cache_options_from_yaml(yaml_file_path):

     cache_context_kwargs = {}
     if kwargs.get("enable_taylorseer", False):
-        from cache_dit.…
+        from cache_dit.caching.cache_contexts.calibrators import (
             TaylorSeerCalibratorConfig,
         )

@@ -38,10 +34,25 @@ def load_cache_options_from_yaml(yaml_file_path):
             )
         )

-…
+    if "cache_type" not in kwargs:
+        from cache_dit.caching.cache_contexts import BasicCacheConfig

-…
-…
+        cache_context_kwargs["cache_config"] = BasicCacheConfig()
+        cache_context_kwargs["cache_config"].update(**kwargs)
+    else:
+        cache_type = kwargs.pop("cache_type")
+        if cache_type == "DBCache":
+            from cache_dit.caching.cache_contexts import DBCacheConfig
+
+            cache_context_kwargs["cache_config"] = DBCacheConfig()
+            cache_context_kwargs["cache_config"].update(**kwargs)
+        elif cache_type == "DBPrune":
+            from cache_dit.caching.cache_contexts import DBPruneConfig
+
+            cache_context_kwargs["cache_config"] = DBPruneConfig()
+            cache_context_kwargs["cache_config"].update(**kwargs)
+        else:
+            raise ValueError(f"Unsupported cache_type: {cache_type}.")

     return cache_context_kwargs

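After this change, a YAML options file needs only `residual_diff_threshold`, and an optional `cache_type` key selects `DBCacheConfig` or `DBPruneConfig` (anything else raises). A round-trip sketch; the file name and option values are illustrative:

```python
import yaml
from cache_dit.caching.utils import load_cache_options_from_yaml

# cache_type routes to DBCacheConfig/DBPruneConfig; residual_diff_threshold
# is the single remaining required key per the diff above.
with open("cache_options.yaml", "w") as f:
    yaml.safe_dump({"cache_type": "DBCache", "residual_diff_threshold": 0.08}, f)

kwargs = load_cache_options_from_yaml("cache_options.yaml")
print(kwargs["cache_config"])  # a DBCacheConfig updated from the YAML keys
```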
cache_dit/metrics/__init__.py
CHANGED

@@ -1,3 +1,14 @@
+try:
+    import ImageReward
+    import lpips
+    import skimage
+    import scipy
+except ImportError:
+    raise ImportError(
+        "Metrics functionality requires the 'metrics' extra dependencies. "
+        "Install with:\npip install cache-dit[metrics]"
+    )
+
 from cache_dit.metrics.metrics import compute_psnr
 from cache_dit.metrics.metrics import compute_ssim
 from cache_dit.metrics.metrics import compute_mse