cache-dit 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +13 -1
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/block_adapters.py +24 -1
- cache_dit/cache_factory/block_adapters/block_registers.py +2 -2
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +4 -0
- cache_dit/cache_factory/cache_interface.py +75 -27
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/{parallel_difffusers.py → native_diffusers/parallel_difffusers.py} +28 -8
- cache_dit/parallelism/backends/native_pytorch/__init__.py +0 -0
- cache_dit/parallelism/parallel_config.py +8 -0
- cache_dit/parallelism/parallel_interface.py +6 -2
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/utils.py +34 -18
- {cache_dit-1.0.7.dist-info → cache_dit-1.0.9.dist-info}/METADATA +18 -20
- {cache_dit-1.0.7.dist-info → cache_dit-1.0.9.dist-info}/RECORD +20 -18
- {cache_dit-1.0.7.dist-info → cache_dit-1.0.9.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.7.dist-info → cache_dit-1.0.9.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.7.dist-info → cache_dit-1.0.9.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.7.dist-info → cache_dit-1.0.9.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
@@ -26,12 +26,24 @@ from cache_dit.cache_factory import FoCaCalibratorConfig
 from cache_dit.cache_factory import supported_pipelines
 from cache_dit.cache_factory import get_adapter
 from cache_dit.compile import set_compile_configs
-from cache_dit.quantize import quantize
 from cache_dit.parallelism import ParallelismBackend
 from cache_dit.parallelism import ParallelismConfig
 from cache_dit.utils import summary
 from cache_dit.utils import strify
 
+try:
+    from cache_dit.quantize import quantize
+except ImportError as e: # noqa: F841
+    err_msg = str(e)
+
+    def quantize(*args, **kwargs):
+        raise ImportError(
+            "Quantization requires additional dependencies. "
+            "Please install cache-dit[quantization] or cache-dit[all] "
+            f"to use this feature. Error message: {err_msg}"
+        )
+
+
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
 DBPrune = CacheType.DBPrune
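With the guarded import above, `import cache_dit` no longer fails when the optional quantization dependencies are missing; `cache_dit.quantize` becomes a stub that raises a descriptive `ImportError` only when it is actually called. A minimal sketch of that behavior, assuming `quantize` accepts a `torch.nn.Module` as its first argument (the `Linear` module here is just a stand-in, not part of this diff):

```python
import torch
import cache_dit  # succeeds even without the quantization extra installed

module = torch.nn.Linear(8, 8)  # stand-in for a diffusion transformer

try:
    cache_dit.quantize(module)
except ImportError as err:
    # Raised only when the cache-dit[quantization] extra (torchao) is absent.
    print(f"Quantization extras not installed: {err}")
```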
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.0.7'
-__version_tuple__ = version_tuple = (1, 0, 7)
+__version__ = version = '1.0.9'
+__version_tuple__ = version_tuple = (1, 0, 9)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED
@@ -489,6 +489,7 @@ class BlockAdapter:
     @staticmethod
     def normalize(
         adapter: "BlockAdapter",
+        unique: bool = True,
     ) -> "BlockAdapter":
 
         if getattr(adapter, "_is_normalized", False):
@@ -523,7 +524,10 @@ class BlockAdapter:
         adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
         adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
         adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
-
+        # Some times, the cache_config will be None.
+        # So we do not perform unique check here.
+        if unique:
+            BlockAdapter.unique(adapter)
 
         adapter._is_normalized = True
 
@@ -571,6 +575,10 @@ class BlockAdapter:
         if not getattr(adapter, "_is_normalized", False):
             raise RuntimeError("block_adapter must be normailzed.")
 
+    @classmethod
+    def is_normalized(cls, adapter: "BlockAdapter") -> bool:
+        return getattr(adapter, "_is_normalized", False)
+
     @classmethod
     def is_cached(cls, adapter: Any) -> bool:
         if isinstance(adapter, cls):
@@ -592,6 +600,21 @@ class BlockAdapter:
         else:
             return getattr(adapter, "_is_cached", False)
 
+    @classmethod
+    def is_parallelized(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return getattr(adapter.transformer[0], "_is_parallelized", False)
+        elif isinstance(adapter, DiffusionPipeline):
+            return getattr(adapter.transformer, "_is_parallelized", False)
+        elif isinstance(adapter, torch.nn.Module):
+            return getattr(adapter, "_is_parallelized", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_parallelized", False)
+        else:
+            return getattr(adapter, "_is_parallelized", False)
+
     @classmethod
     def nested_depth(cls, obj: Any):
         # str: 0; List[str]: 1; List[List[str]]: 2
cache_dit/cache_factory/block_adapters/block_registers.py
CHANGED
@@ -37,7 +37,7 @@ class BlockAdapterRegistry:
         cls,
         pipe: DiffusionPipeline | str | Any,
         **kwargs,
-    ) -> BlockAdapter:
+    ) -> BlockAdapter | None:
         if not isinstance(pipe, str):
             pipe_cls_name: str = pipe.__class__.__name__
         else:
@@ -47,7 +47,7 @@ class BlockAdapterRegistry:
             if pipe_cls_name.startswith(name):
                 return cls._adapters[name](pipe, **kwargs)
 
-        return
+        return None
 
     @classmethod
     def has_separate_cfg(
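Since `get_adapter` is now annotated `BlockAdapter | None` and returns `None` explicitly for pipelines with no registered adapter, callers are expected to handle the miss case themselves, which is what the parallelism path in `enable_cache` does further below. A small hedged sketch (the module path is taken from the RECORD listing at the end of this diff; the fallback is illustrative):

```python
from cache_dit.cache_factory.block_adapters.block_registers import (
    BlockAdapterRegistry,
)

def resolve(pipe):
    adapter = BlockAdapterRegistry.get_adapter(pipe)
    if adapter is None:
        # No registered BlockAdapter matches this pipeline class; fall back
        # to the pipeline's own `transformer` attribute, if any.
        return getattr(pipe, "transformer", None)
    return adapter
```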
cache_dit/cache_factory/cache_adapters/cache_adapter.py
CHANGED
@@ -52,6 +52,10 @@ class CachedAdapter:
         block_adapter = BlockAdapterRegistry.get_adapter(
             pipe_or_adapter
         )
+        assert block_adapter is not None, (
+            f"BlockAdapter for {pipe_or_adapter.__class__.__name__} "
+            "should not be None!"
+        )
         if params_modifiers := context_kwargs.pop(
             "params_modifiers",
             None,
cache_dit/cache_factory/cache_interface.py
CHANGED
@@ -24,11 +24,13 @@ def enable_cache(
         BlockAdapter,
     ],
     # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
-    cache_config:
-
-
-
-
+    cache_config: Optional[
+        Union[
+            BasicCacheConfig,
+            DBCacheConfig,
+            DBPruneConfig,
+        ]
+    ] = None,
     # Calibrator config: TaylorSeerCalibratorConfig, etc.
     calibrator_config: Optional[CalibratorConfig] = None,
     # Modify cache context params for specific blocks.
@@ -154,13 +156,27 @@ def enable_cache(
     >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
     >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
     """
+    # Precheck for compatibility of different configurations
+    if cache_config is None:
+        if parallelism_config is None:
+            # Set default cache config only when parallelism is not enabled
+            logger.info("cache_config is None, using default DBCacheConfig")
+            cache_config = DBCacheConfig()
+        else:
+            # Allow empty cache_config when parallelism is enabled
+            logger.warning(
+                "Parallelism is enabled and cache_config is None. Please manually "
+                "set cache_config to avoid potential compatibility issues. "
+                "Otherwise, cache will not be enabled."
+            )
+
     # Collect cache context kwargs
     context_kwargs = {}
     if (cache_type := context_kwargs.get("cache_type", None)) is not None:
         if cache_type == CacheType.NONE:
             return pipe_or_adapter
 
-    #
+    # NOTE: Deprecated cache config params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
     deprecated_kwargs = {
         "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
@@ -196,9 +212,9 @@ def enable_cache(
     if cache_config is not None:
         context_kwargs["cache_config"] = cache_config
 
-    #
+    # NOTE: Deprecated taylorseer params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
-    if (
+    if cache_config is not None and (
         kwargs.get("enable_taylorseer", None) is not None
         or kwargs.get("enable_encoder_taylorseer", None) is not None
     ):
@@ -226,16 +242,22 @@ def enable_cache(
     if params_modifiers is not None:
         context_kwargs["params_modifiers"] = params_modifiers
 
-    if
-        pipe_or_adapter
-        pipe_or_adapter
-
-
+    if cache_config is not None:
+        if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
+            pipe_or_adapter = CachedAdapter.apply(
+                pipe_or_adapter,
+                **context_kwargs,
+            )
+        else:
+            raise ValueError(
+                f"type: {type(pipe_or_adapter)} is not valid, "
+                "Please pass DiffusionPipeline or BlockAdapter"
+                "for the 1's position param: pipe_or_adapter"
+            )
     else:
-
-
-            "
-            "for the 1's position param: pipe_or_adapter"
+        logger.warning(
+            "cache_config is None, skip enabling cache for "
+            f"{pipe_or_adapter.__class__.__name__}."
        )
 
     # NOTE: Users should always enable parallelism after applying
@@ -244,19 +266,45 @@ def enable_cache(
         assert isinstance(
             parallelism_config, ParallelismConfig
         ), "parallelism_config should be of type ParallelismConfig."
+
+        transformers = []
         if isinstance(pipe_or_adapter, DiffusionPipeline):
-
+            adapter = BlockAdapterRegistry.get_adapter(pipe_or_adapter)
+            if adapter is None:
+                assert hasattr(pipe_or_adapter, "transformer"), (
+                    "The given DiffusionPipeline does not have "
+                    "a 'transformer' attribute, cannot enable "
+                    "parallelism."
+                )
+                transformers = [pipe_or_adapter.transformer]
+            else:
+                adapter = BlockAdapter.normalize(adapter, unique=False)
+                transformers = BlockAdapter.flatten(adapter.transformer)
         else:
-
-
-
-
-
-
+            if not BlockAdapter.is_normalized(pipe_or_adapter):
+                pipe_or_adapter = BlockAdapter.normalize(
+                    pipe_or_adapter, unique=False
+                )
+            transformers = BlockAdapter.flatten(pipe_or_adapter.transformer)
+
+        if len(transformers) == 0:
+            logger.warning(
+                "No transformer is detected in the "
+                "BlockAdapter, skip enabling parallelism."
+            )
+            return pipe_or_adapter
+
+        if len(transformers) > 1:
+            logger.warning(
+                "Multiple transformers are detected in the "
+                "BlockAdapter, all transfomers will be "
+                "enabled for parallelism."
+            )
+        for i, transformer in enumerate(transformers):
+            # Enable parallelism for the transformer inplace
+            transformers[i] = enable_parallelism(
+                transformer, parallelism_config
             )
-        transformer = BlockAdapter.flatten(pipe_or_adapter.transformer)[0]
-        # Enable parallelism for the transformer inplace
-        transformer = enable_parallelism(transformer, parallelism_config)
     return pipe_or_adapter
 
 
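Taken together, these hunks make `enable_cache` fall back to a default `DBCacheConfig()` when `cache_config` is omitted, unless a `parallelism_config` is supplied, in which case caching is skipped with a warning and every detected transformer is passed through `enable_parallelism` instead. A hedged sketch of the two call patterns (the checkpoint name is illustrative; `ring_size` is the only parallel-degree field visible in this diff, and the parallel branch additionally assumes a `torch.distributed` launch such as `torchrun`):

```python
import cache_dit
from diffusers import DiffusionPipeline

# Illustrative checkpoint; any Diffusers pipeline with a `transformer` works.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# Pattern 1 - cache only: leaving cache_config=None now defaults to DBCacheConfig().
cache_dit.enable_cache(pipe)

# Pattern 2 (alternative) - parallelism without caching: with parallelism_config set
# and cache_config still None, caching is skipped (with a warning) and each detected
# transformer is parallelized in place.
cache_dit.enable_cache(
    pipe,
    parallelism_config=cache_dit.ParallelismConfig(ring_size=2),
)
```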
cache_dit/metrics/__init__.py
CHANGED
@@ -1,3 +1,14 @@
+try:
+    import ImageReward
+    import lpips
+    import skimage
+    import scipy
+except ImportError:
+    raise ImportError(
+        "Metrics functionality requires the 'metrics' extra dependencies. "
+        "Install with:\npip install cache-dit[metrics]"
+    )
+
 from cache_dit.metrics.metrics import compute_psnr
 from cache_dit.metrics.metrics import compute_ssim
 from cache_dit.metrics.metrics import compute_mse
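With this guard, `cache_dit.metrics` fails fast at import time when the optional metric libraries are missing, instead of failing later inside an individual metric call. A minimal hedged sketch:

```python
# Requires: pip install "cache-dit[metrics]"
# (pulls in ImageReward, lpips, scikit-image and scipy, per the guard above)
try:
    from cache_dit.metrics import compute_psnr, compute_ssim, compute_mse
except ImportError as err:
    raise SystemExit(f"Install the metrics extra first: {err}")
```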
cache_dit/parallelism/backends/{parallel_difffusers.py → native_diffusers/parallel_difffusers.py}
RENAMED
@@ -54,17 +54,37 @@ def maybe_enable_parallelism(
             ring_degree=parallelism_config.ring_size,
         )
     if cp_config is not None:
+        attention_backend = parallelism_config.parallel_kwargs.get(
+            "attention_backend", None
+        )
         if hasattr(transformer, "enable_parallelism"):
             if hasattr(transformer, "set_attention_backend"):
-                #
-
-
-
-                "
-
+                # _native_cudnn, flash, etc.
+                if attention_backend is None:
+                    # Now only _native_cudnn is supported for parallelism
+                    # issue: https://github.com/huggingface/diffusers/pull/12443
+                    transformer.set_attention_backend("_native_cudnn")
+                    logger.warning(
+                        "attention_backend is None, set default attention backend "
+                        "to _native_cudnn for parallelism because of the issue: "
+                        "https://github.com/huggingface/diffusers/pull/12443"
+                    )
+                else:
+                    transformer.set_attention_backend(attention_backend)
+                    logger.info(
+                        "Found attention_backend from config, set attention "
+                        f"backend to: {attention_backend}"
+                    )
+            cp_plan = parallelism_config.parallel_kwargs.get(
+                "cp_plan", None
+            )
+            if cp_plan is not None:
+                logger.info(
+                    f"Using custom context parallelism plan: {cp_plan}"
                 )
-
-
+            transformer.enable_parallelism(
+                config=cp_config, cp_plan=cp_plan
+            )
         else:
             raise ValueError(
                 f"{transformer.__class__.__name__} does not support context parallelism."
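The renamed backend now consults `ParallelismConfig.parallel_kwargs` for two optional entries: `attention_backend`, which defaults to `_native_cudnn` because of huggingface/diffusers#12443, and `cp_plan`, which is forwarded to `transformer.enable_parallelism`. A hedged sketch of a config carrying these kwargs; `backend` and `ring_size` are the only other fields visible in this diff, and the set of valid attention backends comes from diffusers, not from cache-dit:

```python
from cache_dit.parallelism import ParallelismBackend, ParallelismConfig

config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    ring_size=2,  # context-parallel ring degree read by the backend above
    parallel_kwargs={
        # Omit this key to get the `_native_cudnn` default described above.
        "attention_backend": "_native_cudnn",
        # "cp_plan": ...,  # optional custom context-parallelism plan
    },
)
```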
cache_dit/parallelism/backends/native_pytorch/__init__.py
File without changes
cache_dit/parallelism/parallel_config.py
CHANGED
@@ -1,4 +1,5 @@
 import dataclasses
+from typing import Optional, Dict, Any
 from cache_dit.parallelism.parallel_backend import ParallelismBackend
 from cache_dit.logger import init_logger
 
@@ -20,6 +21,13 @@ class ParallelismConfig:
     # tp_size (`int`, *optional*):
     # The degree of tensor parallelism.
     tp_size: int = None
+    # parallel_kwargs (`dict`, *optional*):
+    # Additional kwargs for parallelism backends. For example, for
+    # NATIVE_DIFFUSER backend, it can include `cp_plan` and
+    # `attention_backend` arguments for `Context Parallelism`.
+    parallel_kwargs: Optional[Dict[str, Any]] = dataclasses.field(
+        default_factory=dict
+    )
 
     def __post_init__(self):
         assert ParallelismBackend.is_supported(self.backend), (
cache_dit/parallelism/parallel_interface.py
CHANGED
@@ -22,7 +22,7 @@ def enable_parallelism(
         return transformer
 
     if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
-        from cache_dit.parallelism.backends.
+        from cache_dit.parallelism.backends.native_diffusers import (
             maybe_enable_parallelism,
             native_diffusers_parallelism_available,
         )
@@ -40,8 +40,12 @@ def enable_parallelism(
     )
 
     transformer._is_parallelized = True  # type: ignore[attr-defined]
+    # Use `parallelism` not `parallel` to avoid name conflict with diffusers.
    transformer._parallelism_config = parallelism_config  # type: ignore[attr-defined]
-    logger.info(
+    logger.info(
+        f"Enabled parallelism: {parallelism_config.strify(True)}, "
+        f"transformer id:{id(transformer)}"
+    )
     return transformer
 
 
cache_dit/quantize/__init__.py
CHANGED
cache_dit/utils.py
CHANGED
@@ -79,25 +79,31 @@ def summary(
         transformer_2 = None
     else:
         transformer = adapter_or_others.transformer
-        transformer_2 = None
+        transformer_2 = None  # Only for Wan2.2
         if hasattr(adapter_or_others, "transformer_2"):
             transformer_2 = adapter_or_others.transformer_2
 
-    if
+    if all(
+        (
+            not BlockAdapter.is_cached(transformer),
+            not BlockAdapter.is_parallelized(transformer),
+        )
+    ):
         return [CacheStats()]
 
     blocks_stats: List[CacheStats] = []
-
-
-
-
-
-
-
+    if BlockAdapter.is_cached(transformer):
+        for blocks in BlockAdapter.find_blocks(transformer):
+            blocks_stats.append(
+                _summary(
+                    blocks,
+                    details=details,
+                    logging=logging,
+                    **kwargs,
+                )
             )
-            )
 
-    if transformer_2 is not None:
+    if transformer_2 is not None and BlockAdapter.is_cached(transformer_2):
         for blocks in BlockAdapter.find_blocks(transformer_2):
             blocks_stats.append(
                 _summary(
@@ -126,7 +132,11 @@ def summary(
                 )
             )
 
-    blocks_stats = [
+    blocks_stats = [
+        stats
+        for stats in blocks_stats
+        if (stats.cache_options or stats.parallelism_config)
+    ]
 
     return blocks_stats if len(blocks_stats) else [CacheStats()]
 
@@ -160,6 +170,8 @@ def strify(
         Dict[str, Any],
     ],
 ) -> str:
+
+    parallelism_config: ParallelismConfig = None
     if isinstance(adapter_or_others, BlockAdapter):
         stats = summary(adapter_or_others, logging=False)[-1]
         cache_options = stats.cache_options
@@ -182,8 +194,8 @@ def strify(
         cache_options = adapter_or_others
         cached_steps = None
         cache_type = cache_options.get("cache_type", CacheType.NONE)
-
         stats = None
+        parallelism_config = cache_options.get("parallelism_config", None)
 
     if cache_type == CacheType.NONE:
         return "NONE"
@@ -193,7 +205,10 @@ def strify(
             "DiffusionPipeline | CacheStats | Dict[str, Any]"
         )
 
-    if not
+    if stats is not None:
+        parallelism_config = stats.parallelism_config
+
+    if not cache_options and parallelism_config is None:
         return "NONE"
 
     def cache_str():
@@ -219,14 +234,14 @@ def strify(
             return "T0O0"
 
     def parallelism_str():
-        if stats is None:
-            return ""
-        parallelism_config: ParallelismConfig = stats.parallelism_config
         if parallelism_config is not None:
            return f"_{parallelism_config.strify()}"
        return ""
 
-    cache_type_str = f"{cache_str()}
+    cache_type_str = f"{cache_str()}"
+    if cache_type_str != "NONE":
+        cache_type_str += f"_{calibrator_str()}"
+    cache_type_str += f"{parallelism_str()}"
 
     if cached_steps:
         cache_type_str += f"_S{cached_steps}"
@@ -245,6 +260,7 @@ def _summary(
 ) -> CacheStats:
     cache_stats = CacheStats()
 
+    # Get stats from transformer
     if not isinstance(pipe_or_module, torch.nn.Module):
         assert hasattr(pipe_or_module, "transformer")
         module = pipe_or_module.transformer
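After these changes, `summary()` also reports transformers that are only parallelized (not cached), and `strify()` appends a parallelism suffix to the cache tag. A hedged sketch of inspecting a pipeline that has been through `enable_cache` (both helpers are re-exported from the top-level package, per the `__init__.py` diff above; the exact tag format comes from `strify`'s helpers and is not fully shown in this diff):

```python
import cache_dit

def report(pipe) -> None:
    # `pipe` is a DiffusionPipeline previously passed to cache_dit.enable_cache(...).
    stats = cache_dit.summary(pipe, logging=False)  # list of per-blocks CacheStats
    tag = cache_dit.strify(pipe)  # cache/calibrator tag, plus "_<parallelism>" when enabled
    print(tag, len(stats))
```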
{cache_dit-1.0.7.dist-info → cache_dit-1.0.9.dist-info}/METADATA
CHANGED
@@ -1,37 +1,33 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 1.0.7
+Version: 1.0.9
 Summary: A Unified, Flexible and Training-free Cache Acceleration Framework for 🤗Diffusers.
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
-Project-URL: Repository, https://github.com/vipshop/cache-dit
-Project-URL: Homepage, https://github.com/vipshop/cache-dit
+Project-URL: Repository, https://github.com/vipshop/cache-dit
+Project-URL: Homepage, https://github.com/vipshop/cache-dit
+Project-URL: GitHub, https://github.com/vipshop/cache-dit
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: packaging
 Requires-Dist: pyyaml
 Requires-Dist: torch>=2.7.1
-Requires-Dist: transformers>=4.55.2
 Requires-Dist: diffusers>=0.35.1
-Requires-Dist:
-
-Requires-Dist:
-Requires-Dist: torchao>=0.12.0
-Requires-Dist: image-reward
-Provides-Extra: all
+Requires-Dist: transformers>=4.55.2
+Provides-Extra: quantization
+Requires-Dist: torchao>=0.12.0; extra == "quantization"
 Provides-Extra: metrics
+Requires-Dist: scipy; extra == "metrics"
+Requires-Dist: scikit-image; extra == "metrics"
 Requires-Dist: image-reward; extra == "metrics"
-Requires-Dist: pytorch-fid; extra == "metrics"
 Requires-Dist: lpips==0.1.4; extra == "metrics"
 Provides-Extra: dev
+Requires-Dist: packaging; extra == "dev"
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest<8.0.0,>=7.0.0; extra == "dev"
 Requires-Dist: pytest-html; extra == "dev"
 Requires-Dist: expecttest; extra == "dev"
 Requires-Dist: hypothesis; extra == "dev"
-Requires-Dist: transformers; extra == "dev"
-Requires-Dist: diffusers; extra == "dev"
 Requires-Dist: accelerate; extra == "dev"
 Requires-Dist: peft; extra == "dev"
 Requires-Dist: protobuf; extra == "dev"
@@ -39,10 +35,10 @@ Requires-Dist: sentencepiece; extra == "dev"
 Requires-Dist: opencv-python-headless; extra == "dev"
 Requires-Dist: ftfy; extra == "dev"
 Requires-Dist: scikit-image; extra == "dev"
-
+Provides-Extra: all
+Requires-Dist: cache-dit[quantization]; extra == "all"
+Requires-Dist: cache-dit[metrics]; extra == "all"
 Dynamic: license-file
-Dynamic: provides-extra
-Dynamic: requires-dist
 Dynamic: requires-python
 
 📚English | <a href="./README_CN.md">📚中文阅读 </a>
@@ -52,8 +48,9 @@ Dynamic: requires-python
 <p align="center">
 A <b>Unified</b>, Flexible and Training-free <b>Cache Acceleration</b> Framework for <b>🤗Diffusers</b> <br>
 ♥️ Cache Acceleration with <b>One-line</b> Code ~ ♥️ <br>
-🔥<
-🔥<
+🔥<a href="./docs/User_Guide.md">Forward Pattern Matching</a> | <a href="./docs/User_Guide.md">Automatic Block Adapter</a>🔥 <br>
+🔥<a href="./docs/User_Guide.md"><b>DBCache</b></a> | <a href="./docs/User_Guide.md"><b>DBPrune</b></a> | <a href="./docs/User_Guide.md">Hybrid <b>TaylorSeer</b> Calibrator</a> | <a href="./docs/User_Guide.md">Cache CFG</a>🔥<br>
+🔥<a href="./docs/User_Guide.md"><b>Context Parallelism</b></a> | <a href="./docs/User_Guide.md">Torch Compile Compatible</a> | <a href="./docs/User_Guide.md">SOTA</a>🔥
 </p>
 <div align='center'>
 <img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
@@ -173,7 +170,7 @@
 ## 🔥Hightlight <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>
 
 We are excited to announce that the **first API-stable version (v1.0.0)** of cache-dit has finally been released!
-**[cache-dit](https://github.com/vipshop/cache-dit)** is a **Unified**, **Flexible**, and **Training-free** cache acceleration framework for 🤗 Diffusers, enabling cache acceleration with just **one line** of code. Key features: **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid
+**[cache-dit](https://github.com/vipshop/cache-dit)** is a **Unified**, **Flexible**, and **Training-free** cache acceleration framework for 🤗 Diffusers, enabling cache acceleration with just **one line** of code. Key features: **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **DBCache**, **DBPrune**, **Hybrid TaylorSeer Calibrator**, **Hybrid Cache CFG**, **Context Parallelism**, **Torch Compile Compatible** and **🎉SOTA** performance.
 
 ```bash
 pip3 install -U cache-dit # pip3 install git+https://github.com/vipshop/cache-dit.git
@@ -204,6 +201,7 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
 
 ## 🔥Important News
 
+- 2025.10.23: 🎉Now cache-dit supported the [Kandinsky5 T2V](https://github.com/ai-forever/Kandinsky-5) and [Photoroom/PRX](https://github.com/huggingface/diffusers/pull/12456) pipelines.
 - 2025.10.20: 🔥Now cache-dit supported the **[Hybrid Cache + Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism)** scheme!🔥
 - 2025.10.16: 🎉cache-dit + [**🔥nunchaku 4-bits**](https://github.com/nunchaku-tech/nunchaku) supported: [Qwen-Image-Lightning 4/8 steps](./examples/quantize/).
 - 2025.10.15: 🎉cache-dit now supported [**🔥nunchaku**](https://github.com/nunchaku-tech/nunchaku): Qwen-Image/FLUX.1 [4-bits examples](./examples/quantize/)
{cache_dit-1.0.7.dist-info → cache_dit-1.0.9.dist-info}/RECORD
CHANGED
@@ -1,19 +1,19 @@
-cache_dit/__init__.py,sha256=
-cache_dit/_version.py,sha256=
+cache_dit/__init__.py,sha256=Azqj-3QMQK4HZDTGgyUtAfatUwuU-YQ4w8erJSyrsbE,2082
+cache_dit/_version.py,sha256=JXTThZsIEQNG8lSfLsQqv8iVrLso3IkPevWFvCathJU,704
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/utils.py,sha256=
+cache_dit/utils.py,sha256=rjVXUyr7JUabO9bY2puXrfPvHl3Sp4eX3MHLY90Cau8,18432
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=5UjrpxLVlmjHttTL0O14fD5oU5uKI3FKYevL613ibFQ,1848
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_interface.py,sha256=_7RSugGxNArLP2i3qmfq-hon_OTPCz3DSZbwQoCemcc,16558
 cache_dit/cache_factory/cache_types.py,sha256=QnWfaS52UOXQtnoCUOwwz4ziY0dyBta6vQ6hvgtdV44,1404
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/params_modifier.py,sha256=2T98IbepAolWW6GwQsqUDsRzu0k65vo7BOrN3V8mKog,3606
 cache_dit/cache_factory/utils.py,sha256=S3SD6Zhexzhkqnmfo830v6oNLm8stZe32nF4VdxD_bA,2497
 cache_dit/cache_factory/block_adapters/__init__.py,sha256=eeBcWUMIvS-x3GcD1LNesW2SuB9V5mtwG9MoUBWHsL8,19765
-cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=
-cache_dit/cache_factory/block_adapters/block_registers.py,sha256=
+cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=hnHZbM3UCIk1fb8HS8Z42w7kJ76xNIl36thONSjkT4g,23267
+cache_dit/cache_factory/block_adapters/block_registers.py,sha256=NvzeeBM32pxuUymcyNibcTgX-9UnnDTRt8_zTXcci6c,2591
 cache_dit/cache_factory/cache_adapters/__init__.py,sha256=py71WGD3JztQ1uk6qdLVbzYcQ1rvqFidNNaQYo7tqTo,79
-cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=
+cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=_TfI4c-evcxP3mngiPhWYoVoOJ-q4xVGEqukGhZ7b0w,24270
 cache_dit/cache_factory/cache_blocks/__init__.py,sha256=cpxzmDcUhbXcReHqaKSnWyEEbIg1H91Pz5hE3z9Xj3k,9984
 cache_dit/cache_factory/cache_blocks/offload_utils.py,sha256=wusgcqaCrwEjvv7Guy-6VXhNOgPPUrBV2sSVuRmGuvo,3513
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=j4bTafqU5DLQhzP_X5XwOk-QUVLWkGrX-Q6JZvBGHh0,666
@@ -44,7 +44,7 @@ cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0
 cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/metrics/__init__.py,sha256=
+cache_dit/metrics/__init__.py,sha256=Y_JrBr9XE6NKXwyXc7d_-PaX9c_rk5FKms-IYgCyHmY,936
 cache_dit/metrics/clip_score.py,sha256=ERNCFQFJKzJdbIX9OAg-1LiSPuXUVHLOFxbf2gcENpc,3938
 cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
 cache_dit/metrics/fid.py,sha256=ZM_FM0XERtpnkMUfphmw2aOdljrh1uba-pnYItu0q6M,18219
@@ -54,15 +54,17 @@ cache_dit/metrics/lpips.py,sha256=hrHrmdM-f2B4TKDs0xLqJO5JFaYcCjq2qNIR8oCrVkc,81
 cache_dit/metrics/metrics.py,sha256=AZbQyoavE-djvyRUZ_EfCIrWSQbiWQFo7n2dhn7XptE,40466
 cache_dit/parallelism/__init__.py,sha256=dheBG5_TZCuwctviMslpAEgB-B3N8F816bE51qsw_fU,210
 cache_dit/parallelism/parallel_backend.py,sha256=js1soTMenLeAyPMsBgdI3gWcdXoqjWgBD-PuFEywMr0,508
-cache_dit/parallelism/parallel_config.py,sha256=
-cache_dit/parallelism/parallel_interface.py,sha256=
-cache_dit/parallelism/backends/
-cache_dit/
+cache_dit/parallelism/parallel_config.py,sha256=ZGCWsSu4LcBsjZ2h8NACHhw8WYi_oSqJbZaZIRdQl1Q,2120
+cache_dit/parallelism/parallel_interface.py,sha256=8jNzmZdExMl0aKMAJHpYRlYfz3Ex65KzF9ZrHKlHi6Y,2340
+cache_dit/parallelism/backends/native_diffusers/__init__.py,sha256=T_6GeBA7TRiVbvtqGLLH2flkRiK0o7JBREt2xhS_-YE,242
+cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py,sha256=aRwL1lJXOWl9JaJx9XRv391irkN5xFWJiOOIT_1lu0E,3476
+cache_dit/parallelism/backends/native_pytorch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit/quantize/__init__.py,sha256=rUu0V9VRjOgwXuIUHHAI-osivNjAdUsi-jpkDbFp6Gk,278
 cache_dit/quantize/quantize_ao.py,sha256=bbEUwsrMp3bMuRw8qJZREIvCHaJRQoZyfMjlu4ImRMI,6315
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
+cache_dit-1.0.9.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-1.0.9.dist-info/METADATA,sha256=ZAjv17YPgYMgNcyONGRh4fuwKIDhxjdOsWJYLJc3y18,29872
+cache_dit-1.0.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-1.0.9.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-1.0.9.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-1.0.9.dist-info/RECORD,,
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|