cache-dit 1.0.5__py3-none-any.whl → 1.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +5 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +12 -0
- cache_dit/cache_factory/cache_interface.py +35 -1
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/parallel_difffusers.py +56 -0
- cache_dit/parallelism/parallel_backend.py +18 -0
- cache_dit/parallelism/parallel_config.py +47 -0
- cache_dit/parallelism/parallel_interface.py +62 -0
- cache_dit/quantize/quantize_ao.py +27 -64
- cache_dit/utils.py +21 -1
- {cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/METADATA +11 -4
- {cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/RECORD +17 -12
- {cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -4,8 +4,7 @@ except ImportError:
     __version__ = "unknown version"
     version_tuple = (0, 0, "unknown version")
 
-
-from cache_dit.utils import strify
+
 from cache_dit.utils import disable_print
 from cache_dit.logger import init_logger
 from cache_dit.cache_factory import load_options
@@ -28,7 +27,10 @@ from cache_dit.cache_factory import supported_pipelines
 from cache_dit.cache_factory import get_adapter
 from cache_dit.compile import set_compile_configs
 from cache_dit.quantize import quantize
-
+from cache_dit.parallelism import ParallelismBackend
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.utils import summary
+from cache_dit.utils import strify
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
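Usage note: with these re-exports, the new parallelism types are available straight from the package root. A small sketch (constructing ParallelismConfig requires diffusers to be installed, since its backend check imports it):

    import cache_dit

    # New in 1.0.6: ParallelismConfig / ParallelismBackend are re-exported at the top level.
    cfg = cache_dit.ParallelismConfig(ulysses_size=2)
    print(cache_dit.ParallelismBackend.NATIVE_DIFFUSER, cfg)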
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.0.5'
-__version_tuple__ = version_tuple = (1, 0, 5)
+__version__ = version = '1.0.6'
+__version_tuple__ = version_tuple = (1, 0, 6)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/cache_adapters/cache_adapter.py CHANGED
@@ -614,6 +614,18 @@ class CachedAdapter:
             pipe_or_adapter, remove_stats, remove_stats, remove_stats
         )
 
+        # maybe release parallelism stats
+        from cache_dit.parallelism.parallel_interface import (
+            remove_parallelism_stats,
+        )
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            remove_parallelism_stats,
+            remove_parallelism_stats,
+            remove_parallelism_stats,
+        )
+
     @classmethod
     def release_hooks(
         cls,
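Usage note: this cleanup runs from the cache teardown path. A sketch, assuming disable_cache is exposed at the package root alongside enable_cache:

    import cache_dit

    cache_dit.enable_cache(pipe)   # pipe: a previously loaded DiffusionPipeline
    # ... run inference ...
    cache_dit.disable_cache(pipe)  # releases cache hooks and, since 1.0.6, parallelism stats too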
cache_dit/cache_factory/cache_interface.py CHANGED
@@ -9,6 +9,8 @@ from cache_dit.cache_factory.cache_contexts import DBCacheConfig
 from cache_dit.cache_factory.cache_contexts import DBPruneConfig
 from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.params_modifier import ParamsModifier
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.parallelism import enable_parallelism
 
 from cache_dit.logger import init_logger
 
@@ -37,6 +39,8 @@ def enable_cache(
             List[List[ParamsModifier]],
         ]
     ] = None,
+    # Config for Parallelism
+    parallelism_config: Optional[ParallelismConfig] = None,
     # Other cache context kwargs: Deprecated cache kwargs
     **kwargs,
 ) -> Union[
@@ -127,6 +131,15 @@ def enable_cache(
         **kwargs: (`dict`, *optional*, defaults to {}):
             The same as 'kwargs' param in cache_dit.enable_cache() interface.
 
+        parallelism_config (`ParallelismConfig`, *optional*, defaults to None):
+            Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
+            parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
+            for more details of ParallelismConfig.
+            ulysses_size: (`int`, *optional*, defaults to None):
+                The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+            ring_size: (`int`, *optional*, defaults to None):
+                The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+
         kwargs (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
             for more details.
@@ -214,7 +227,7 @@ def enable_cache(
     context_kwargs["params_modifiers"] = params_modifiers
 
     if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
-
+        pipe_or_adapter = CachedAdapter.apply(
             pipe_or_adapter,
             **context_kwargs,
         )
@@ -225,6 +238,27 @@
             "for the 1's position param: pipe_or_adapter"
         )
 
+    # NOTE: Users should always enable parallelism after applying
+    # cache to avoid hooks conflict.
+    if parallelism_config is not None:
+        assert isinstance(
+            parallelism_config, ParallelismConfig
+        ), "parallelism_config should be of type ParallelismConfig."
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            transformer = pipe_or_adapter.transformer
+        else:
+            assert BlockAdapter.assert_normalized(pipe_or_adapter)
+            assert (
+                len(BlockAdapter.flatten(pipe_or_adapter.transformer)) == 1
+            ), (
+                "Only single transformer is supported to enable parallelism "
+                "currently for BlockAdapter."
+            )
+            transformer = BlockAdapter.flatten(pipe_or_adapter.transformer)[0]
+        # Enable parallelism for the transformer inplace
+        transformer = enable_parallelism(transformer, parallelism_config)
+    return pipe_or_adapter
+
 
 def disable_cache(
     pipe_or_adapter: Union[
cache_dit/parallelism/backends/parallel_difffusers.py ADDED
@@ -0,0 +1,56 @@
+import torch
+
+from typing import Optional
+
+try:
+    from diffusers import ContextParallelConfig
+
+    def native_diffusers_parallelism_available() -> bool:
+        return True
+
+except ImportError:
+    ContextParallelConfig = None
+
+    def native_diffusers_parallelism_available() -> bool:
+        return False
+
+
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+
+def maybe_enable_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, ModelMixin)
+    if parallelism_config is None:
+        return transformer
+
+    if (
+        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
+        and native_diffusers_parallelism_available()
+    ):
+        cp_config = None
+        if (
+            parallelism_config.ulysses_size is not None
+            or parallelism_config.ring_size is not None
+        ):
+            cp_config = ContextParallelConfig(
+                ulysses_degree=parallelism_config.ulysses_size,
+                ring_degree=parallelism_config.ring_size,
+            )
+        if cp_config is not None:
+            if hasattr(transformer, "enable_parallelism"):
+                if hasattr(transformer, "set_attention_backend"):  # type: ignore[attr-defined]
+                    # Now only _native_cudnn is supported for parallelism
+                    # issue: https://github.com/huggingface/diffusers/pull/12443
+                    transformer.set_attention_backend("_native_cudnn")  # type: ignore[attr-defined]
+                transformer.enable_parallelism(config=cp_config)
+            else:
+                raise ValueError(
+                    f"{transformer.__class__.__name__} does not support context parallelism."
+                )
+
+    return transformer
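Usage note: this backend is a thin wrapper; for a supported diffusers transformer it boils down to the following sketch (the _native_cudnn attention backend requirement comes from the diffusers issue linked above, and diffusers>=0.36.dev0 is assumed):

    from diffusers import ContextParallelConfig

    # `transformer` stands for e.g. pipe.transformer on a recent diffusers install.
    transformer.set_attention_backend("_native_cudnn")
    transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))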
cache_dit/parallelism/parallel_backend.py ADDED
@@ -0,0 +1,18 @@
+from enum import Enum
+
+
+class ParallelismBackend(Enum):
+    NATIVE_DIFFUSER = "Native_Diffuser"
+    NATIVE_PYTORCH = "Native_PyTorch"
+    NONE = "None"
+
+    @classmethod
+    def is_supported(cls, backend: "ParallelismBackend") -> bool:
+        # Now, only Native_Diffuser backend is supported
+        if backend in [cls.NATIVE_DIFFUSER]:
+            try:
+                import diffusers  # noqa: F401
+            except ImportError:
+                return False
+            return True
+        return False
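Usage note: a quick capability check against this enum, as a sketch:

    from cache_dit.parallelism import ParallelismBackend

    # True only when diffusers is importable; the other backends currently return False.
    print(ParallelismBackend.is_supported(ParallelismBackend.NATIVE_DIFFUSER))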
cache_dit/parallelism/parallel_config.py ADDED
@@ -0,0 +1,47 @@
+import dataclasses
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class ParallelismConfig:
+    # Parallelism backend, defaults to NATIVE_DIFFUSER
+    backend: ParallelismBackend = ParallelismBackend.NATIVE_DIFFUSER
+    # Context parallelism config
+    # ulysses_size (`int`, *optional*):
+    #   The degree of ulysses parallelism.
+    ulysses_size: int = None
+    # ring_size (`int`, *optional*):
+    #   The degree of ring parallelism.
+    ring_size: int = None
+    # Tensor parallelism config
+    # tp_size (`int`, *optional*):
+    #   The degree of tensor parallelism.
+    tp_size: int = None
+
+    def __post_init__(self):
+        assert ParallelismBackend.is_supported(self.backend), (
+            f"Parallel backend {self.backend} is not supported. "
+            f"Please make sure the required packages are installed."
+        )
+        assert self.tp_size is None, "Tensor parallelism is not supported yet."
+
+    def strify(self, details: bool = False) -> str:
+        if details:
+            return (
+                f"ParallelismConfig(backend={self.backend}, "
+                f"ulysses_size={self.ulysses_size}, "
+                f"ring_size={self.ring_size}, "
+                f"tp_size={self.tp_size})"
+            )
+        else:
+            parallel_str = ""
+            if self.ulysses_size is not None:
+                parallel_str += f"Ulysses{self.ulysses_size}"
+            if self.ring_size is not None:
+                parallel_str += f"Ring{self.ring_size}"
+            if self.tp_size is not None:
+                parallel_str += f"TP{self.tp_size}"
+            return parallel_str
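Usage note: the short form produced by strify() is the tag that later gets appended to cache summary strings; a sketch:

    from cache_dit import ParallelismConfig

    cfg = ParallelismConfig(ulysses_size=4)
    print(cfg.strify())             # -> "Ulysses4"
    print(cfg.strify(details=True))
    # -> "ParallelismConfig(backend=ParallelismBackend.NATIVE_DIFFUSER, ulysses_size=4, ring_size=None, tp_size=None)"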
cache_dit/parallelism/parallel_interface.py ADDED
@@ -0,0 +1,62 @@
+import torch
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def enable_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: ParallelismConfig,
+) -> torch.nn.Module:
+    assert isinstance(transformer, torch.nn.Module), (
+        "transformer must be an instance of torch.nn.Module, "
+        f"but got {type(transformer)}"
+    )
+    if getattr(transformer, "_is_parallelized", False):
+        logger.warning(
+            "The transformer is already parallelized. "
+            "Skipping parallelism enabling."
+        )
+        return transformer
+
+    if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
+        from cache_dit.parallelism.backends.parallel_difffusers import (
+            maybe_enable_parallelism,
+            native_diffusers_parallelism_available,
+        )
+
+        assert (
+            native_diffusers_parallelism_available()
+        ), "Please install diffusers>=0.36.dev0 to use Native_Diffuser backend."
+        transformer = maybe_enable_parallelism(
+            transformer,
+            parallelism_config,
+        )
+    else:
+        raise ValueError(
+            f"Parallel backend {parallelism_config.backend} is not supported yet."
+        )
+
+    transformer._is_parallelized = True  # type: ignore[attr-defined]
+    transformer._parallelism_config = parallelism_config  # type: ignore[attr-defined]
+    logger.info(f"Enabled parallelism: {parallelism_config.strify(True)}")
+    return transformer
+
+
+def remove_parallelism_stats(
+    transformer: torch.nn.Module,
+) -> torch.nn.Module:
+    if not getattr(transformer, "_is_parallelized", False):
+        logger.warning(
+            "The transformer is not parallelized. "
+            "Skipping removing parallelism."
+        )
+        return transformer
+
+    if hasattr(transformer, "_is_parallelized"):
+        del transformer._is_parallelized  # type: ignore[attr-defined]
+    if hasattr(transformer, "_parallelism_config"):
+        del transformer._parallelism_config  # type: ignore[attr-defined]
+    return transformer
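Usage note: enable_parallelism can also be applied to a transformer directly, outside of enable_cache. A sketch, assuming pipe is an already-loaded diffusers pipeline and the script runs under a multi-process launcher such as torchrun:

    from cache_dit.parallelism import ParallelismConfig, enable_parallelism

    pipe.transformer = enable_parallelism(
        pipe.transformer, ParallelismConfig(ring_size=2)
    )
    assert getattr(pipe.transformer, "_is_parallelized", False)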
cache_dit/quantize/quantize_ao.py CHANGED
@@ -80,26 +80,19 @@ def quantize_ao(
 
             return False
 
-    def
+    def _quant_config():
         try:
             if quant_type == "fp8_w8a8_dq":
-
-
-
-
-
-                )
-            except ImportError:
-                from torchao.quantization import (
-                    Float8DynamicActivationFloat8WeightConfig as float8_dynamic_activation_float8_weight,
-                    PerTensor,
-                    PerRow,
-                )
+                from torchao.quantization import (
+                    Float8DynamicActivationFloat8WeightConfig,
+                    PerTensor,
+                    PerRow,
+                )
 
                 if per_row:  # Ensure bfloat16
                     module.to(torch.bfloat16)
 
-
+                quant_config = Float8DynamicActivationFloat8WeightConfig(
                     weight_dtype=kwargs.get(
                         "weight_dtype",
                         torch.float8_e4m3fn,
@@ -116,14 +109,9 @@ def quantize_ao(
                 )
 
             elif quant_type == "fp8_w8a16_wo":
-
-                from torchao.quantization import float8_weight_only
-            except ImportError:
-                from torchao.quantization import (
-                    Float8WeightOnlyConfig as float8_weight_only,
-                )
+                from torchao.quantization import Float8WeightOnlyConfig
 
-
+                quant_config = Float8WeightOnlyConfig(
                     weight_dtype=kwargs.get(
                         "weight_dtype",
                         torch.float8_e4m3fn,
@@ -131,69 +119,44 @@ def quantize_ao(
                 )
 
             elif quant_type == "int8_w8a8_dq":
-
-
-
-                )
-            except ImportError:
-                from torchao.quantization import (
-                    Int8DynamicActivationInt8WeightConfig as int8_dynamic_activation_int8_weight,
-                )
+                from torchao.quantization import (
+                    Int8DynamicActivationInt8WeightConfig,
+                )
 
-
+                quant_config = Int8DynamicActivationInt8WeightConfig()
 
             elif quant_type == "int8_w8a16_wo":
 
-
-                from torchao.quantization import int8_weight_only
-            except ImportError:
-                from torchao.quantization import (
-                    Int8WeightOnlyConfig as int8_weight_only,
-                )
+                from torchao.quantization import Int8WeightOnlyConfig
 
-
+                quant_config = Int8WeightOnlyConfig(
                     # group_size is None -> per_channel, else per group
                     group_size=kwargs.get("group_size", None),
                 )
 
             elif quant_type == "int4_w4a8_dq":
 
-
-
-
-                )
-            except ImportError:
-                from torchao.quantization import (
-                    Int8DynamicActivationInt4WeightConfig as int8_dynamic_activation_int4_weight,
-                )
+                from torchao.quantization import (
+                    Int8DynamicActivationInt4WeightConfig,
+                )
 
-
+                quant_config = Int8DynamicActivationInt4WeightConfig(
                     group_size=kwargs.get("group_size", 32),
                 )
 
             elif quant_type == "int4_w4a4_dq":
 
-
-
-
-                )
-            except ImportError:
-                from torchao.quantization import (
-                    Int4DynamicActivationInt4WeightConfig as int4_dynamic_activation_int4_weight,
-                )
+                from torchao.quantization import (
+                    Int4DynamicActivationInt4WeightConfig,
+                )
 
-
+                quant_config = Int4DynamicActivationInt4WeightConfig()
 
             elif quant_type == "int4_w4a16_wo":
 
-
-                from torchao.quantization import int4_weight_only
-            except ImportError:
-                from torchao.quantization import (
-                    Int4WeightOnlyConfig as int4_weight_only,
-                )
+                from torchao.quantization import Int4WeightOnlyConfig
 
-
+                quant_config = Int4WeightOnlyConfig(
                     group_size=kwargs.get("group_size", 32),
                 )
 
@@ -209,13 +172,13 @@ def quantize_ao(
             )
             raise e
 
-        return
+        return quant_config
 
     from torchao.quantization import quantize_
 
     quantize_(
         module,
-
+        _quant_config(),
        filter_fn=_filter_fn if filter_fn is None else filter_fn,
        device=kwargs.get("device", None),
    )
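Usage note: the net effect of this refactor is a single _quant_config() helper that returns the modern torchao config objects instead of the old callable-style names. A standalone sketch of the equivalent torchao call for the fp8_w8a8_dq path (torchao installed and a float8-capable GPU assumed):

    import torch
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        quantize_,
    )

    linear = torch.nn.Linear(4096, 4096, bias=False).to(torch.bfloat16).cuda()
    # Dynamic fp8 activations + fp8 weights, as quantize_ao() now configures it.
    quantize_(
        linear,
        Float8DynamicActivationFloat8WeightConfig(weight_dtype=torch.float8_e4m3fn),
    )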
cache_dit/utils.py CHANGED
@@ -13,6 +13,7 @@ from cache_dit.cache_factory import CacheType
 from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import BasicCacheConfig
 from cache_dit.cache_factory import CalibratorConfig
+from cache_dit.parallelism import ParallelismConfig
 from cache_dit.logger import init_logger
 
 
@@ -55,6 +56,8 @@ class CacheStats:
     cfg_pruned_blocks: list[int] = dataclasses.field(default_factory=list)
     cfg_actual_blocks: list[int] = dataclasses.field(default_factory=list)
     cfg_pruned_ratio: float = None
+    # Parallelism Stats
+    parallelism_config: ParallelismConfig = None
 
 
 def summary(
@@ -213,7 +216,13 @@ def strify(
             return calibrator_config.strify()
         return "T0O0"
 
-
+    def parallelism_str():
+        parallelism_config: ParallelismConfig = stats.parallelism_config
+        if parallelism_config is not None:
+            return f"_{parallelism_config.strify()}"
+        return ""
+
+    cache_type_str = f"{cache_str()}_{calibrator_str()}{parallelism_str()}"
 
     if cached_steps:
         cache_type_str += f"_S{cached_steps}"
@@ -252,6 +261,17 @@ def _summary(
         if logging:
            logger.warning(f"Can't find Context Options for: {cls_name}")
 
+    if hasattr(module, "_parallelism_config"):
+        parallelism_config: ParallelismConfig = module._parallelism_config
+        cache_stats.parallelism_config = parallelism_config
+        if logging:
+            print(
+                f"\n🤖Parallelism Config: {cls_name}\n\n{parallelism_config.strify(True)}"
+            )
+    else:
+        if logging:
+            logger.warning(f"Can't find Parallelism Config for: {cls_name}")
+
     if hasattr(module, "_cached_steps"):
         cached_steps: list[int] = module._cached_steps
         residual_diffs: dict[str, list | float] = dict(module._residual_diffs)
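Usage note: once a parallelized transformer is summarized, its config is echoed and folded into the strify tag. A sketch, assuming pipe was set up with enable_cache(..., parallelism_config=ParallelismConfig(ulysses_size=2)) and has run at least one inference (the exact tag prefix depends on the cache options):

    import cache_dit

    cache_dit.summary(pipe)        # prints a "🤖Parallelism Config: ..." block per cached module
    print(cache_dit.strify(pipe))  # cache tag now carries the parallel suffix, e.g. "..._Ulysses2"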
{cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 1.0.5
+Version: 1.0.6
 Summary: A Unified, Flexible and Training-free Cache Acceleration Framework for 🤗Diffusers.
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -53,6 +53,10 @@ Dynamic: requires-python
 A <b>Unified</b>, Flexible and Training-free <b>Cache Acceleration</b> Framework for <b>🤗Diffusers</b> <br>
 ♥️ Cache Acceleration with <b>One-line</b> Code ~ ♥️
 </p>
+<p align="center">
+🔥<b><a href="./docs/User_Guide.md">DBCache</a> | <a href="./docs/User_Guide.md">DBPrune</a> | <a href="./docs/User_Guide.md">Hybird TaylorSeer</a> | <a href="./docs/User_Guide.md">Hybird Cache CFG</a></b>🔥 <br>
+🔥<b><a href="./docs/User_Guide.md">Hybrid Context Paralleism</a> | <a href="./docs/User_Guide.md">PyTorch Native</a> | <a href="./docs/User_Guide.md">SOTA</a></b>🔥
+</p>
 <div align='center'>
 <img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
 <img src=https://img.shields.io/badge/PRs-welcome-blue.svg >
@@ -194,7 +198,7 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
 - **[🎉Easy New Model Integration](./docs/User_Guide.md#automatic-block-adapter)**: Features like **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, and **Patch Functor** make it highly functional and flexible. For example, we achieved 🎉 Day 1 support for [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with 1.7x speedup w/o precision loss—even before it was available in the Diffusers library.
 - **[🎉State-of-the-Art Performance](./bench/)**: Compared with algorithms including Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer and FoCa, cache-dit achieved the **SOTA** performance w/ **7.4x↑🎉** speedup on ClipScore!
 - **[🎉Support for 4/8-Steps Distilled Models](./bench/)**: Surprisingly, cache-dit's **DBCache** works for extremely few-step distilled models—something many other methods fail to do.
-- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile,
+- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile, Offloading, Quantization([torchao](./examples/quantize/), [🔥nunchaku](./examples/quantize/)), [🔥Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism), etc.
 - **[🎉Hybrid Cache Acceleration](./docs/User_Guide.md#taylorseer-calibrator)**: Now supports hybrid **Block-wise Cache + Calibrator** schemes (e.g., DBCache or DBPrune + TaylorSeerCalibrator). DBCache or DBPrune acts as the **Indicator** to decide *when* to cache, while the Calibrator decides *how* to cache. More mainstream cache acceleration algorithms (e.g., FoCa) will be supported in the future, along with additional benchmarks—stay tuned for updates!
 - **[🤗Diffusers Ecosystem Integration](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)**: 🔥**cache-dit** has joined the Diffusers community ecosystem as the **first** DiT-specific cache acceleration framework! Check out the documentation here: <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>
@@ -202,6 +206,8 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
 
 ## 🔥Important News
 
+- 2025.10.20: 🔥Now cache-dit supported the [Hybrid Cache + Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism) scheme!🔥
+- 2025.10.16: 🎉cache-dit + [**🔥nunchaku 4-bits**](https://github.com/nunchaku-tech/nunchaku) supported: [Qwen-Image-Lightning 4/8 steps](./examples/quantize/).
 - 2025.10.15: 🎉cache-dit now supported [**🔥nunchaku**](https://github.com/nunchaku-tech/nunchaku): Qwen-Image/FLUX.1 [4-bits examples](./examples/quantize/)
 - 2025.10.13: 🎉cache-dit achieved the **SOTA** performance w/ **7.4x↑🎉** speedup on ClipScore!
 - 2025.10.10: 🔥[**Qwen-Image-ControlNet-Inpainting**](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting) **2.3x↑🎉** speedup! Check the [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_controlnet_inpaint.py).
@@ -244,9 +250,10 @@ For more advanced features such as **Unified Cache APIs**, **Forward Pattern Mat
 - [🤖Cache Acceleration Stats](./docs/User_Guide.md#cache-acceleration-stats-summary)
 - [⚡️DBCache: Dual Block Cache](./docs/User_Guide.md#️dbcache-dual-block-cache)
 - [⚡️DBPrune: Dynamic Block Prune](./docs/User_Guide.md#️dbprune-dynamic-block-prune)
-- [🔥TaylorSeer
+- [🔥Hybrid TaylorSeer](./docs/User_Guide.md#taylorseer-calibrator)
 - [⚡️Hybrid Cache CFG](./docs/User_Guide.md#️hybrid-cache-cfg)
-- [
+- [⚡️Hybrid Context Parallelism](./docs/User_Guide.md#context-paralleism)
+- [🛠Metrics Command Line](./docs/User_Guide.md#metrics-cli)
 - [⚙️Torch Compile](./docs/User_Guide.md#️torch-compile)
 - [📚API Documents](./docs/User_Guide.md#api-documentation)
 
{cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/RECORD CHANGED
@@ -1,10 +1,10 @@
-cache_dit/__init__.py,sha256=
-cache_dit/_version.py,sha256=
+cache_dit/__init__.py,sha256=HZb04M7AHCfk9DaEAGApGJ2lCM-rsP6pbsNQxsQudi0,1743
+cache_dit/_version.py,sha256=r9csd7YQr6ubaa9S-K5iWXSr4c-RpLuQWy5uJw8f4MU,704
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/utils.py,sha256=
+cache_dit/utils.py,sha256=sLJNMARd9a3dA9dmuD6pZZg5n5FslwUeAktfsL1eO4I,17781
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=5UjrpxLVlmjHttTL0O14fD5oU5uKI3FKYevL613ibFQ,1848
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_interface.py,sha256=244uTVx83hpCpbCDgEOydi5HqG7hKHHzEoz1ApJW6lI,14627
 cache_dit/cache_factory/cache_types.py,sha256=QnWfaS52UOXQtnoCUOwwz4ziY0dyBta6vQ6hvgtdV44,1404
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/params_modifier.py,sha256=2T98IbepAolWW6GwQsqUDsRzu0k65vo7BOrN3V8mKog,3606
@@ -13,7 +13,7 @@ cache_dit/cache_factory/block_adapters/__init__.py,sha256=zs-cYacRL_hWlhUXmKc0TZ
 cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=2TVK_KqiYXC7AKZ2s07fzdOzUoeUBc9P1SzQtLVzhf4,22249
 cache_dit/cache_factory/block_adapters/block_registers.py,sha256=2L7QeM4ygnaKQpC9PoJod0QRYyxidUKU2AYpysDCUwE,2572
 cache_dit/cache_factory/cache_adapters/__init__.py,sha256=py71WGD3JztQ1uk6qdLVbzYcQ1rvqFidNNaQYo7tqTo,79
-cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256
+cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=WYrgV3DKxOxttl-wEKymyKIB1Po0eW73Q2_vOlGEKdQ,24080
 cache_dit/cache_factory/cache_blocks/__init__.py,sha256=cpxzmDcUhbXcReHqaKSnWyEEbIg1H91Pz5hE3z9Xj3k,9984
 cache_dit/cache_factory/cache_blocks/offload_utils.py,sha256=wusgcqaCrwEjvv7Guy-6VXhNOgPPUrBV2sSVuRmGuvo,3513
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=j4bTafqU5DLQhzP_X5XwOk-QUVLWkGrX-Q6JZvBGHh0,666
@@ -52,12 +52,17 @@ cache_dit/metrics/image_reward.py,sha256=N8HalJo1T1js0dsNb2V1KRv4kIdcm3nhx7iOXJu
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=hrHrmdM-f2B4TKDs0xLqJO5JFaYcCjq2qNIR8oCrVkc,811
 cache_dit/metrics/metrics.py,sha256=AZbQyoavE-djvyRUZ_EfCIrWSQbiWQFo7n2dhn7XptE,40466
+cache_dit/parallelism/__init__.py,sha256=dheBG5_TZCuwctviMslpAEgB-B3N8F816bE51qsw_fU,210
+cache_dit/parallelism/parallel_backend.py,sha256=js1soTMenLeAyPMsBgdI3gWcdXoqjWgBD-PuFEywMr0,508
+cache_dit/parallelism/parallel_config.py,sha256=bu24sRSzJMmH7FZqzUPTcT6tAzQ20-FAqAEvGV3Q1Fw,1733
+cache_dit/parallelism/parallel_interface.py,sha256=tsiIdHosTmRbeRg0z9q0eMQlx-7vefmSIlc56OWnuMg,2205
+cache_dit/parallelism/backends/parallel_difffusers.py,sha256=i57yZzYc9kGPUjLXTzoAA4j7U8EtVIAJRK1exw30Voo,1939
 cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
-cache_dit/quantize/quantize_ao.py,sha256=
+cache_dit/quantize/quantize_ao.py,sha256=bbEUwsrMp3bMuRw8qJZREIvCHaJRQoZyfMjlu4ImRMI,6315
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
-cache_dit-1.0.
+cache_dit-1.0.6.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-1.0.6.dist-info/METADATA,sha256=h_fb2Lf6XGsTofkmLJrsWh67H0LRvry1SDrMeQ9Uf5I,29476
+cache_dit-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-1.0.6.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-1.0.6.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-1.0.6.dist-info/RECORD,,
{cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/WHEEL: File without changes
{cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/entry_points.txt: File without changes
{cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/licenses/LICENSE: File without changes
{cache_dit-1.0.5.dist-info → cache_dit-1.0.6.dist-info}/top_level.txt: File without changes