cache-dit 1.0.4__py3-none-any.whl → 1.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


cache_dit/__init__.py CHANGED
@@ -4,8 +4,7 @@ except ImportError:
     __version__ = "unknown version"
     version_tuple = (0, 0, "unknown version")

-from cache_dit.utils import summary
-from cache_dit.utils import strify
+
 from cache_dit.utils import disable_print
 from cache_dit.logger import init_logger
 from cache_dit.cache_factory import load_options
@@ -28,7 +27,10 @@ from cache_dit.cache_factory import supported_pipelines
 from cache_dit.cache_factory import get_adapter
 from cache_dit.compile import set_compile_configs
 from cache_dit.quantize import quantize
-
+from cache_dit.parallelism import ParallelismBackend
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.utils import summary
+from cache_dit.utils import strify

 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
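The net effect on the public API: `summary` and `strify` are still exported, just later in the module, and two parallelism names join the package root. A quick smoke test of the 1.0.6 surface (a sketch, assuming diffusers is installed so the parallelism modules import cleanly):

```python
import cache_dit

# Unchanged helpers, now re-exported after the cache/quantize/parallelism imports.
assert callable(cache_dit.summary) and callable(cache_dit.strify)

# New in 1.0.6: parallelism types at the package root.
from cache_dit import ParallelismBackend, ParallelismConfig

print(ParallelismBackend.NATIVE_DIFFUSER.value)  # "Native_Diffuser"
```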
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '1.0.4'
-__version_tuple__ = version_tuple = (1, 0, 4)
+__version__ = version = '1.0.6'
+__version_tuple__ = version_tuple = (1, 0, 6)

 __commit_id__ = commit_id = None
cache_dit/cache_factory/block_adapters/__init__.py CHANGED
@@ -12,7 +12,10 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
     from cache_dit.utils import is_diffusers_at_least_0_3_5

     assert isinstance(pipe.transformer, FluxTransformer2DModel)
-    if is_diffusers_at_least_0_3_5():
+    transformer_cls_name: str = pipe.transformer.__class__.__name__
+    if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith(
+        "Nunchaku"
+    ):
         return BlockAdapter(
             pipe=pipe,
             transformer=pipe.transformer,
cache_dit/cache_factory/cache_adapters/cache_adapter.py CHANGED
@@ -1,3 +1,4 @@
+import copy
 import torch
 import unittest
 import functools
@@ -197,7 +198,6 @@ class CachedAdapter:
         flatten_contexts, contexts_kwargs = cls.modify_context_params(
             block_adapter, **context_kwargs
         )
-
         original_call = block_adapter.pipe.__class__.__call__

         @functools.wraps(original_call)
@@ -238,7 +238,7 @@ class CachedAdapter:
             block_adapter.unique_blocks_name
         )
         contexts_kwargs = [
-            context_kwargs.copy()
+            copy.deepcopy(context_kwargs)  # must deep copy
            for _ in range(
                len(flatten_contexts),
            )
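The `copy.deepcopy` here is load-bearing: the context kwargs now carry mutable config objects, and a shallow `dict.copy()` would leave every per-blocks context sharing the same `cache_config` instance, so a later per-context override would silently apply to all of them. A self-contained sketch of the failure mode, using a stand-in dataclass rather than the real `BasicCacheConfig`:

```python
import copy
import dataclasses

@dataclasses.dataclass
class Cfg:  # stand-in for BasicCacheConfig
    Fn: int = 8

kwargs = {"cache_config": Cfg()}

shallow = [kwargs.copy() for _ in range(2)]
shallow[0]["cache_config"].Fn = 4           # mutates the shared Cfg instance
assert shallow[1]["cache_config"].Fn == 4   # the override leaked across contexts

deep = [copy.deepcopy(kwargs) for _ in range(2)]
deep[0]["cache_config"].Fn = 4
assert deep[1]["cache_config"].Fn == 8      # isolated, hence "must deep copy"
```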
@@ -259,9 +259,41 @@ class CachedAdapter:
         for i in range(
             min(len(contexts_kwargs), len(flatten_modifiers)),
         ):
-            contexts_kwargs[i].update(
-                flatten_modifiers[i]._context_kwargs,
-            )
+            if "cache_config" in flatten_modifiers[i]._context_kwargs:
+                modifier_cache_config = flatten_modifiers[
+                    i
+                ]._context_kwargs.get("cache_config", None)
+                modifier_calibrator_config = flatten_modifiers[
+                    i
+                ]._context_kwargs.get("calibrator_config", None)
+                if modifier_cache_config is not None:
+                    assert isinstance(
+                        modifier_cache_config, BasicCacheConfig
+                    ), (
+                        f"cache_config must be BasicCacheConfig, but got "
+                        f"{type(modifier_cache_config)}."
+                    )
+                    contexts_kwargs[i]["cache_config"].update(
+                        **modifier_cache_config.as_dict()
+                    )
+                if modifier_calibrator_config is not None:
+                    assert isinstance(
+                        modifier_calibrator_config, CalibratorConfig
+                    ), (
+                        f"calibrator_config must be CalibratorConfig, but got "
+                        f"{type(modifier_calibrator_config)}."
+                    )
+                    if (
+                        contexts_kwargs[i].get("calibrator_config", None)
+                        is None
+                    ):
+                        contexts_kwargs[i][
+                            "calibrator_config"
+                        ] = modifier_calibrator_config
+                    else:
+                        contexts_kwargs[i]["calibrator_config"].update(
+                            **modifier_calibrator_config.as_dict()
+                        )
             cls._config_messages(**contexts_kwargs[i])

         return flatten_contexts, contexts_kwargs
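Instead of a blind `dict.update`, a modifier's `cache_config` is now merged field-by-field into the existing config, and since `update()` skips `None` values (see the `cache_config.py` hunk below), only the fields a modifier actually sets override the base. The new `empty()`/`reset()` helpers exist for exactly this pattern. A hedged sketch, assuming `BasicCacheConfig` and `ParamsModifier` are importable from the package root as in the project README, and using `max_warmup_steps` as an illustrative field name:

```python
from cache_dit import BasicCacheConfig, ParamsModifier

# reset() nulls every field and then applies only the named overrides, so the
# field-wise merge above leaves all other base-config values untouched.
override = BasicCacheConfig().reset(max_warmup_steps=8)

params_modifiers = [
    [],                                       # transformer 0: base config as-is
    [ParamsModifier(cache_config=override)],  # transformer 1: one-field override
]
```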
@@ -582,6 +614,18 @@ class CachedAdapter:
             pipe_or_adapter, remove_stats, remove_stats, remove_stats
         )

+        # maybe release parallelism stats
+        from cache_dit.parallelism.parallel_interface import (
+            remove_parallelism_stats,
+        )
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            remove_parallelism_stats,
+            remove_parallelism_stats,
+            remove_parallelism_stats,
+        )
+
     @classmethod
     def release_hooks(
         cls,
cache_dit/cache_factory/cache_contexts/cache_config.py CHANGED
@@ -60,9 +60,25 @@ class BasicCacheConfig:
     def update(self, **kwargs) -> "BasicCacheConfig":
         for key, value in kwargs.items():
             if hasattr(self, key):
-                setattr(self, key, value)
+                if value is not None:
+                    setattr(self, key, value)
         return self

+    def empty(self, **kwargs) -> "BasicCacheConfig":
+        # Set all fields to None
+        for field in dataclasses.fields(self):
+            if hasattr(self, field.name):
+                setattr(self, field.name, None)
+        if kwargs:
+            self.update(**kwargs)
+        return self
+
+    def reset(self, **kwargs) -> "BasicCacheConfig":
+        return self.empty(**kwargs)
+
+    def as_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
     def strify(self) -> str:
         return (
             f"{self.cache_type}_"
cache_dit/cache_factory/cache_contexts/calibrators/__init__.py CHANGED
@@ -45,6 +45,28 @@ class CalibratorConfig:
     def to_kwargs(self) -> Dict:
         return self.calibrator_kwargs.copy()

+    def as_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    def update(self, **kwargs) -> "CalibratorConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                if value is not None:
+                    setattr(self, key, value)
+        return self
+
+    def empty(self, **kwargs) -> "CalibratorConfig":
+        # Set all fields to None
+        for field in dataclasses.fields(self):
+            if hasattr(self, field.name):
+                setattr(self, field.name, None)
+        if kwargs:
+            self.update(**kwargs)
+        return self
+
+    def reset(self, **kwargs) -> "CalibratorConfig":
+        return self.empty(**kwargs)
+

 @dataclasses.dataclass
 class TaylorSeerCalibratorConfig(CalibratorConfig):
cache_dit/cache_factory/cache_contexts/prune_config.py CHANGED
@@ -50,12 +50,6 @@ class DBPruneConfig(BasicCacheConfig):
     # to at least 2 to reduce the VRAM usage of the calibrator.
     force_reduce_calibrator_vram: bool = False

-    def update(self, **kwargs) -> "DBPruneConfig":
-        for key, value in kwargs.items():
-            if hasattr(self, key):
-                setattr(self, key, value)
-        return self
-
     def strify(self) -> str:
         return (
             f"{self.cache_type}_"
cache_dit/cache_factory/cache_interface.py CHANGED
@@ -9,6 +9,8 @@ from cache_dit.cache_factory.cache_contexts import DBCacheConfig
 from cache_dit.cache_factory.cache_contexts import DBPruneConfig
 from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.params_modifier import ParamsModifier
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.parallelism import enable_parallelism

 from cache_dit.logger import init_logger

@@ -37,6 +39,8 @@ def enable_cache(
             List[List[ParamsModifier]],
         ]
     ] = None,
+    # Config for Parallelism
+    parallelism_config: Optional[ParallelismConfig] = None,
     # Other cache context kwargs: Deprecated cache kwargs
     **kwargs,
 ) -> Union[
@@ -127,6 +131,15 @@ def enable_cache(
         **kwargs: (`dict`, *optional*, defaults to {}):
             The same as 'kwargs' param in cache_dit.enable_cache() interface.

+        parallelism_config (`ParallelismConfig`, *optional*, defaults to None):
+            Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
+            parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
+            for more details of ParallelismConfig.
+            ulysses_size: (`int`, *optional*, defaults to None):
+                The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+            ring_size: (`int`, *optional*, defaults to None):
+                The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+
     kwargs (`dict`, *optional*, defaults to {})
         Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
         for more details.
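Putting the docstring above into practice, a call might look like the sketch below. It must run under a distributed launcher such as torchrun so that a process group exists, and it needs a diffusers build that ships `ContextParallelConfig`; the pipeline and sizes are illustrative:

```python
# torchrun --nproc_per_node=4 run_flux_cp.py   (illustrative launch command)
import torch
import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# One call applies the cache hooks first, then parallelizes the transformer,
# matching the "cache before parallelism" ordering noted in the code below.
cache_dit.enable_cache(
    pipe,
    parallelism_config=cache_dit.ParallelismConfig(ulysses_size=4),
)
```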
@@ -214,7 +227,7 @@ def enable_cache(
         context_kwargs["params_modifiers"] = params_modifiers

     if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
-        return CachedAdapter.apply(
+        pipe_or_adapter = CachedAdapter.apply(
             pipe_or_adapter,
             **context_kwargs,
         )
@@ -225,6 +238,27 @@ def enable_cache(
             "for the 1's position param: pipe_or_adapter"
         )

+    # NOTE: Users should always enable parallelism after applying
+    # cache to avoid hooks conflict.
+    if parallelism_config is not None:
+        assert isinstance(
+            parallelism_config, ParallelismConfig
+        ), "parallelism_config should be of type ParallelismConfig."
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            transformer = pipe_or_adapter.transformer
+        else:
+            assert BlockAdapter.assert_normalized(pipe_or_adapter)
+            assert (
+                len(BlockAdapter.flatten(pipe_or_adapter.transformer)) == 1
+            ), (
+                "Only single transformer is supported to enable parallelism "
+                "currently for BlockAdapter."
+            )
+            transformer = BlockAdapter.flatten(pipe_or_adapter.transformer)[0]
+        # Enable parallelism for the transformer inplace
+        transformer = enable_parallelism(transformer, parallelism_config)
+    return pipe_or_adapter
+

 def disable_cache(
     pipe_or_adapter: Union[
cache_dit/parallelism/__init__.py ADDED
@@ -0,0 +1,3 @@
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.parallel_interface import enable_parallelism
cache_dit/parallelism/backends/parallel_difffusers.py ADDED
@@ -0,0 +1,56 @@
+import torch
+
+from typing import Optional
+
+try:
+    from diffusers import ContextParallelConfig
+
+    def native_diffusers_parallelism_available() -> bool:
+        return True
+
+except ImportError:
+    ContextParallelConfig = None
+
+    def native_diffusers_parallelism_available() -> bool:
+        return False
+
+
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+
+def maybe_enable_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, ModelMixin)
+    if parallelism_config is None:
+        return transformer
+
+    if (
+        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
+        and native_diffusers_parallelism_available()
+    ):
+        cp_config = None
+        if (
+            parallelism_config.ulysses_size is not None
+            or parallelism_config.ring_size is not None
+        ):
+            cp_config = ContextParallelConfig(
+                ulysses_degree=parallelism_config.ulysses_size,
+                ring_degree=parallelism_config.ring_size,
+            )
+        if cp_config is not None:
+            if hasattr(transformer, "enable_parallelism"):
+                if hasattr(transformer, "set_attention_backend"):  # type: ignore[attr-defined]
+                    # Now only _native_cudnn is supported for parallelism
+                    # issue: https://github.com/huggingface/diffusers/pull/12443
+                    transformer.set_attention_backend("_native_cudnn")  # type: ignore[attr-defined]
+                transformer.enable_parallelism(config=cp_config)
+            else:
+                raise ValueError(
+                    f"{transformer.__class__.__name__} does not support context parallelism."
+                )
+
+    return transformer
cache_dit/parallelism/parallel_backend.py ADDED
@@ -0,0 +1,18 @@
+from enum import Enum
+
+
+class ParallelismBackend(Enum):
+    NATIVE_DIFFUSER = "Native_Diffuser"
+    NATIVE_PYTORCH = "Native_PyTorch"
+    NONE = "None"
+
+    @classmethod
+    def is_supported(cls, backend: "ParallelismBackend") -> bool:
+        # Now, only Native_Diffuser backend is supported
+        if backend in [cls.NATIVE_DIFFUSER]:
+            try:
+                import diffusers  # noqa: F401
+            except ImportError:
+                return False
+            return True
+        return False
cache_dit/parallelism/parallel_config.py ADDED
@@ -0,0 +1,47 @@
+import dataclasses
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class ParallelismConfig:
+    # Parallelism backend, defaults to NATIVE_DIFFUSER
+    backend: ParallelismBackend = ParallelismBackend.NATIVE_DIFFUSER
+    # Context parallelism config
+    # ulysses_size (`int`, *optional*):
+    #     The degree of ulysses parallelism.
+    ulysses_size: int = None
+    # ring_size (`int`, *optional*):
+    #     The degree of ring parallelism.
+    ring_size: int = None
+    # Tensor parallelism config
+    # tp_size (`int`, *optional*):
+    #     The degree of tensor parallelism.
+    tp_size: int = None
+
+    def __post_init__(self):
+        assert ParallelismBackend.is_supported(self.backend), (
+            f"Parallel backend {self.backend} is not supported. "
+            f"Please make sure the required packages are installed."
+        )
+        assert self.tp_size is None, "Tensor parallelism is not supported yet."
+
+    def strify(self, details: bool = False) -> str:
+        if details:
+            return (
+                f"ParallelismConfig(backend={self.backend}, "
+                f"ulysses_size={self.ulysses_size}, "
+                f"ring_size={self.ring_size}, "
+                f"tp_size={self.tp_size})"
+            )
+        else:
+            parallel_str = ""
+            if self.ulysses_size is not None:
+                parallel_str += f"Ulysses{self.ulysses_size}"
+            if self.ring_size is not None:
+                parallel_str += f"Ring{self.ring_size}"
+            if self.tp_size is not None:
+                parallel_str += f"TP{self.tp_size}"
+            return parallel_str
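Given the definitions above, the compact `strify()` form concatenates one tag per configured dimension; this is the suffix that later shows up in cache tags via `cache_dit.utils.strify` (see the utils.py hunks below). For example:

```python
from cache_dit.parallelism import ParallelismConfig

# __post_init__ runs the backend check, so diffusers must be importable here.
config = ParallelismConfig(ulysses_size=4, ring_size=2)
assert config.strify() == "Ulysses4Ring2"
assert "ulysses_size=4" in config.strify(details=True)
```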
cache_dit/parallelism/parallel_interface.py ADDED
@@ -0,0 +1,62 @@
+import torch
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def enable_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: ParallelismConfig,
+) -> torch.nn.Module:
+    assert isinstance(transformer, torch.nn.Module), (
+        "transformer must be an instance of torch.nn.Module, "
+        f"but got {type(transformer)}"
+    )
+    if getattr(transformer, "_is_parallelized", False):
+        logger.warning(
+            "The transformer is already parallelized. "
+            "Skipping parallelism enabling."
+        )
+        return transformer
+
+    if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
+        from cache_dit.parallelism.backends.parallel_difffusers import (
+            maybe_enable_parallelism,
+            native_diffusers_parallelism_available,
+        )
+
+        assert (
+            native_diffusers_parallelism_available()
+        ), "Please install diffusers>=0.36.dev0 to use Native_Diffuser backend."
+        transformer = maybe_enable_parallelism(
+            transformer,
+            parallelism_config,
+        )
+    else:
+        raise ValueError(
+            f"Parallel backend {parallelism_config.backend} is not supported yet."
+        )
+
+    transformer._is_parallelized = True  # type: ignore[attr-defined]
+    transformer._parallelism_config = parallelism_config  # type: ignore[attr-defined]
+    logger.info(f"Enabled parallelism: {parallelism_config.strify(True)}")
+    return transformer
+
+
+def remove_parallelism_stats(
+    transformer: torch.nn.Module,
+) -> torch.nn.Module:
+    if not getattr(transformer, "_is_parallelized", False):
+        logger.warning(
+            "The transformer is not parallelized. "
+            "Skipping removing parallelism."
+        )
+        return transformer
+
+    if hasattr(transformer, "_is_parallelized"):
+        del transformer._is_parallelized  # type: ignore[attr-defined]
+    if hasattr(transformer, "_parallelism_config"):
+        del transformer._parallelism_config  # type: ignore[attr-defined]
+    return transformer
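A minimal sketch of driving this interface directly, assuming a multi-GPU torchrun launch and a diffusers build new enough to provide context parallelism (the model choice is illustrative):

```python
# torchrun --nproc_per_node=2 parallel_demo.py   (illustrative launch command)
import torch
from diffusers import FluxPipeline
from cache_dit.parallelism import ParallelismConfig, enable_parallelism

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.transformer = enable_parallelism(
    pipe.transformer, ParallelismConfig(ulysses_size=2)
)

# Calling it twice is safe: the _is_parallelized flag short-circuits with a warning.
pipe.transformer = enable_parallelism(
    pipe.transformer, ParallelismConfig(ulysses_size=2)
)
```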
cache_dit/quantize/quantize_ao.py CHANGED
@@ -80,11 +80,11 @@ def quantize_ao(

         return False

-    def _quantization_fn():
+    def _quant_config():
        try:
            if quant_type == "fp8_w8a8_dq":
                from torchao.quantization import (
-                    float8_dynamic_activation_float8_weight,
+                    Float8DynamicActivationFloat8WeightConfig,
                    PerTensor,
                    PerRow,
                )
@@ -92,7 +92,7 @@ def quantize_ao(
                 if per_row:  # Ensure bfloat16
                     module.to(torch.bfloat16)

-                quantization_fn = float8_dynamic_activation_float8_weight(
+                quant_config = Float8DynamicActivationFloat8WeightConfig(
                     weight_dtype=kwargs.get(
                         "weight_dtype",
                         torch.float8_e4m3fn,
@@ -109,9 +109,9 @@ def quantize_ao(
                 )

             elif quant_type == "fp8_w8a16_wo":
-                from torchao.quantization import float8_weight_only
+                from torchao.quantization import Float8WeightOnlyConfig

-                quantization_fn = float8_weight_only(
+                quant_config = Float8WeightOnlyConfig(
                     weight_dtype=kwargs.get(
                         "weight_dtype",
                         torch.float8_e4m3fn,
@@ -120,39 +120,43 @@ def quantize_ao(

             elif quant_type == "int8_w8a8_dq":
                 from torchao.quantization import (
-                    int8_dynamic_activation_int8_weight,
+                    Int8DynamicActivationInt8WeightConfig,
                 )

-                quantization_fn = int8_dynamic_activation_int8_weight()
+                quant_config = Int8DynamicActivationInt8WeightConfig()

             elif quant_type == "int8_w8a16_wo":
-                from torchao.quantization import int8_weight_only

-                quantization_fn = int8_weight_only(
+                from torchao.quantization import Int8WeightOnlyConfig
+
+                quant_config = Int8WeightOnlyConfig(
                     # group_size is None -> per_channel, else per group
                     group_size=kwargs.get("group_size", None),
                 )

             elif quant_type == "int4_w4a8_dq":
+
                 from torchao.quantization import (
-                    int8_dynamic_activation_int4_weight,
+                    Int8DynamicActivationInt4WeightConfig,
                 )

-                quantization_fn = int8_dynamic_activation_int4_weight(
+                quant_config = Int8DynamicActivationInt4WeightConfig(
                     group_size=kwargs.get("group_size", 32),
                 )

             elif quant_type == "int4_w4a4_dq":
+
                 from torchao.quantization import (
-                    int4_dynamic_activation_int4_weight,
+                    Int4DynamicActivationInt4WeightConfig,
                 )

-                quantization_fn = int4_dynamic_activation_int4_weight()
+                quant_config = Int4DynamicActivationInt4WeightConfig()

             elif quant_type == "int4_w4a16_wo":
-                from torchao.quantization import int4_weight_only

-                quantization_fn = int4_weight_only(
+                from torchao.quantization import Int4WeightOnlyConfig
+
+                quant_config = Int4WeightOnlyConfig(
                     group_size=kwargs.get("group_size", 32),
                 )

@@ -168,13 +172,13 @@ def quantize_ao(
                 )
             raise e

-        return quantization_fn
+        return quant_config

    from torchao.quantization import quantize_

    quantize_(
        module,
-        _quantization_fn(),
+        _quant_config(),
        filter_fn=_filter_fn if filter_fn is None else filter_fn,
        device=kwargs.get("device", None),
    )
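These renames track torchao's migration from factory functions (`float8_weight_only(...)`) to `Config` classes (`Float8WeightOnlyConfig(...)`); the `quantize_` call pattern itself is unchanged. A minimal sketch of the new-style API on a bare module, assuming a recent torchao and a CUDA device:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Per-row fp8 dynamic-activation / fp8-weight quantization of a single Linear.
linear = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```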
cache_dit/utils.py CHANGED
@@ -13,6 +13,7 @@ from cache_dit.cache_factory import CacheType
 from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import BasicCacheConfig
 from cache_dit.cache_factory import CalibratorConfig
+from cache_dit.parallelism import ParallelismConfig
 from cache_dit.logger import init_logger


@@ -55,6 +56,8 @@ class CacheStats:
     cfg_pruned_blocks: list[int] = dataclasses.field(default_factory=list)
     cfg_actual_blocks: list[int] = dataclasses.field(default_factory=list)
     cfg_pruned_ratio: float = None
+    # Parallelism Stats
+    parallelism_config: ParallelismConfig = None


 def summary(
@@ -213,7 +216,13 @@ def strify(
             return calibrator_config.strify()
         return "T0O0"

-    cache_type_str = f"{cache_str()}_{calibrator_str()}"
+    def parallelism_str():
+        parallelism_config: ParallelismConfig = stats.parallelism_config
+        if parallelism_config is not None:
+            return f"_{parallelism_config.strify()}"
+        return ""
+
+    cache_type_str = f"{cache_str()}_{calibrator_str()}{parallelism_str()}"

    if cached_steps:
        cache_type_str += f"_S{cached_steps}"
@@ -252,6 +261,17 @@ def _summary(
         if logging:
             logger.warning(f"Can't find Context Options for: {cls_name}")

+    if hasattr(module, "_parallelism_config"):
+        parallelism_config: ParallelismConfig = module._parallelism_config
+        cache_stats.parallelism_config = parallelism_config
+        if logging:
+            print(
+                f"\n🤖Parallelism Config: {cls_name}\n\n{parallelism_config.strify(True)}"
+            )
+    else:
+        if logging:
+            logger.warning(f"Can't find Parallelism Config for: {cls_name}")
+
    if hasattr(module, "_cached_steps"):
        cached_steps: list[int] = module._cached_steps
        residual_diffs: dict[str, list | float] = dict(module._residual_diffs)
{cache_dit-1.0.4.dist-info → cache_dit-1.0.6.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 1.0.4
+Version: 1.0.6
 Summary: A Unified, Flexible and Training-free Cache Acceleration Framework for 🤗Diffusers.
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -53,6 +53,10 @@ Dynamic: requires-python
 A <b>Unified</b>, Flexible and Training-free <b>Cache Acceleration</b> Framework for <b>🤗Diffusers</b> <br>
 ♥️ Cache Acceleration with <b>One-line</b> Code ~ ♥️
 </p>
+<p align="center">
+🔥<b><a href="./docs/User_Guide.md">DBCache</a> | <a href="./docs/User_Guide.md">DBPrune</a> | <a href="./docs/User_Guide.md">Hybrid TaylorSeer</a> | <a href="./docs/User_Guide.md">Hybrid Cache CFG</a></b>🔥 <br>
+🔥<b><a href="./docs/User_Guide.md">Hybrid Context Parallelism</a> | <a href="./docs/User_Guide.md">PyTorch Native</a> | <a href="./docs/User_Guide.md">SOTA</a></b>🔥
+</p>
 <div align='center'>
 <img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
 <img src=https://img.shields.io/badge/PRs-welcome-blue.svg >
@@ -194,7 +198,7 @@ You can install the stable release of cache-dit from PyPI, or the latest develop
 - **[🎉Easy New Model Integration](./docs/User_Guide.md#automatic-block-adapter)**: Features like **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, and **Patch Functor** make it highly functional and flexible. For example, we achieved 🎉 Day 1 support for [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with 1.7x speedup w/o precision loss—even before it was available in the Diffusers library.
 - **[🎉State-of-the-Art Performance](./bench/)**: Compared with algorithms including Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer and FoCa, cache-dit achieved the **SOTA** performance w/ **7.4x↑🎉** speedup on ClipScore!
 - **[🎉Support for 4/8-Steps Distilled Models](./bench/)**: Surprisingly, cache-dit's **DBCache** works for extremely few-step distilled models—something many other methods fail to do.
-- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile, model CPU offload, sequential CPU offload, group offloading, etc.
+- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile, Offloading, Quantization([torchao](./examples/quantize/), [🔥nunchaku](./examples/quantize/)), [🔥Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism), etc.
 - **[🎉Hybrid Cache Acceleration](./docs/User_Guide.md#taylorseer-calibrator)**: Now supports hybrid **Block-wise Cache + Calibrator** schemes (e.g., DBCache or DBPrune + TaylorSeerCalibrator). DBCache or DBPrune acts as the **Indicator** to decide *when* to cache, while the Calibrator decides *how* to cache. More mainstream cache acceleration algorithms (e.g., FoCa) will be supported in the future, along with additional benchmarks—stay tuned for updates!
 - **[🤗Diffusers Ecosystem Integration](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)**: 🔥**cache-dit** has joined the Diffusers community ecosystem as the **first** DiT-specific cache acceleration framework! Check out the documentation here: <a href="https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit"><img src=https://img.shields.io/badge/🤗Diffusers-ecosystem-yellow.svg ></a>

@@ -202,6 +206,9 @@ You can install the stable release of cache-dit from PyPI, or the latest develop

 ## 🔥Important News

+- 2025.10.20: 🔥cache-dit now supports the [Hybrid Cache + Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism) scheme!🔥
+- 2025.10.16: 🎉cache-dit + [**🔥nunchaku 4-bits**](https://github.com/nunchaku-tech/nunchaku) supported: [Qwen-Image-Lightning 4/8 steps](./examples/quantize/).
+- 2025.10.15: 🎉cache-dit now supports [**🔥nunchaku**](https://github.com/nunchaku-tech/nunchaku): Qwen-Image/FLUX.1 [4-bits examples](./examples/quantize/).
 - 2025.10.13: 🎉cache-dit achieved the **SOTA** performance w/ **7.4x↑🎉** speedup on ClipScore!
 - 2025.10.10: 🔥[**Qwen-Image-ControlNet-Inpainting**](https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting) **2.3x↑🎉** speedup! Check the [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_controlnet_inpaint.py).
 - 2025.09.26: 🔥[**Qwen-Image-Edit-Plus(2509)**](https://github.com/QwenLM/Qwen-Image) **2.1x↑🎉** speedup! Please check the [example](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image_edit_plus.py).
@@ -243,9 +250,10 @@ For more advanced features such as **Unified Cache APIs**, **Forward Pattern Mat
 - [🤖Cache Acceleration Stats](./docs/User_Guide.md#cache-acceleration-stats-summary)
 - [⚡️DBCache: Dual Block Cache](./docs/User_Guide.md#️dbcache-dual-block-cache)
 - [⚡️DBPrune: Dynamic Block Prune](./docs/User_Guide.md#️dbprune-dynamic-block-prune)
-- [🔥TaylorSeer Calibrator](./docs/User_Guide.md#taylorseer-calibrator)
+- [🔥Hybrid TaylorSeer](./docs/User_Guide.md#taylorseer-calibrator)
 - [⚡️Hybrid Cache CFG](./docs/User_Guide.md#️hybrid-cache-cfg)
-- [🛠Metrics CLI](./docs/User_Guide.md#metrics-cli)
+- [⚡️Hybrid Context Parallelism](./docs/User_Guide.md#context-paralleism)
+- [🛠Metrics Command Line](./docs/User_Guide.md#metrics-cli)
 - [⚙️Torch Compile](./docs/User_Guide.md#️torch-compile)
 - [📚API Documents](./docs/User_Guide.md#api-documentation)

{cache_dit-1.0.4.dist-info → cache_dit-1.0.6.dist-info}/RECORD RENAMED
@@ -1,19 +1,19 @@
-cache_dit/__init__.py,sha256=JQLxwr5aqoMFp-BNR58J0i6NutbRmNXKsaRJKCZQDCg,1638
-cache_dit/_version.py,sha256=jp1Oow7okdi1HqeKIp8SmyysmUf-oq2X9syICfrgATI,704
+cache_dit/__init__.py,sha256=HZb04M7AHCfk9DaEAGApGJ2lCM-rsP6pbsNQxsQudi0,1743
+cache_dit/_version.py,sha256=r9csd7YQr6ubaa9S-K5iWXSr4c-RpLuQWy5uJw8f4MU,704
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/utils.py,sha256=0YNFr84pxYoHOCZvnONKKXYN3PZY4kao9Tq2yEfHHR8,16986
+cache_dit/utils.py,sha256=sLJNMARd9a3dA9dmuD6pZZg5n5FslwUeAktfsL1eO4I,17781
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=5UjrpxLVlmjHttTL0O14fD5oU5uKI3FKYevL613ibFQ,1848
-cache_dit/cache_factory/cache_interface.py,sha256=spiE7pWF80G3Y06_TKVvrmKufbAvQmyvshZZVsmb-nM,12714
+cache_dit/cache_factory/cache_interface.py,sha256=244uTVx83hpCpbCDgEOydi5HqG7hKHHzEoz1ApJW6lI,14627
 cache_dit/cache_factory/cache_types.py,sha256=QnWfaS52UOXQtnoCUOwwz4ziY0dyBta6vQ6hvgtdV44,1404
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/params_modifier.py,sha256=2T98IbepAolWW6GwQsqUDsRzu0k65vo7BOrN3V8mKog,3606
 cache_dit/cache_factory/utils.py,sha256=S3SD6Zhexzhkqnmfo830v6oNLm8stZe32nF4VdxD_bA,2497
-cache_dit/cache_factory/block_adapters/__init__.py,sha256=vM3aDMzPY79Tw4L0hlV2PdA3MFYomnf0eo0BGBo9P78,18087
+cache_dit/cache_factory/block_adapters/__init__.py,sha256=zs-cYacRL_hWlhUXmKc0TZNDAKzzWuznvHeuDpAmuwc,18221
 cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=2TVK_KqiYXC7AKZ2s07fzdOzUoeUBc9P1SzQtLVzhf4,22249
 cache_dit/cache_factory/block_adapters/block_registers.py,sha256=2L7QeM4ygnaKQpC9PoJod0QRYyxidUKU2AYpysDCUwE,2572
 cache_dit/cache_factory/cache_adapters/__init__.py,sha256=py71WGD3JztQ1uk6qdLVbzYcQ1rvqFidNNaQYo7tqTo,79
-cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=Za9HixVkEKldYzyDA57xvF91fm9dao2S-Fz5QBIT02M,22123
+cache_dit/cache_factory/cache_adapters/cache_adapter.py,sha256=WYrgV3DKxOxttl-wEKymyKIB1Po0eW73Q2_vOlGEKdQ,24080
 cache_dit/cache_factory/cache_blocks/__init__.py,sha256=cpxzmDcUhbXcReHqaKSnWyEEbIg1H91Pz5hE3z9Xj3k,9984
 cache_dit/cache_factory/cache_blocks/offload_utils.py,sha256=wusgcqaCrwEjvv7Guy-6VXhNOgPPUrBV2sSVuRmGuvo,3513
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=j4bTafqU5DLQhzP_X5XwOk-QUVLWkGrX-Q6JZvBGHh0,666
@@ -21,14 +21,14 @@ cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=2qPnXVZwpQIm2oJ-Yrn
 cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=9H87qBRpa6UWRkUKXLVO0_9NJgxCVKkFSzaQxM9YPw8,25487
 cache_dit/cache_factory/cache_blocks/pattern_utils.py,sha256=qOxoVTlYPQzPMrR06-7_Ce_lwNg6n5pt1KQrvxzAJhE,3124
 cache_dit/cache_factory/cache_contexts/__init__.py,sha256=7uY8fX9uhpC71VNm1HH4aDIicYn-dD3kRpPQhvc9-EI,853
-cache_dit/cache_factory/cache_contexts/cache_config.py,sha256=WBHU2XVuYSFUSkrrJk8c4952LTeqvgetdkdtch_uSmg,5238
+cache_dit/cache_factory/cache_contexts/cache_config.py,sha256=G0PVWgckDqeyARc72Ne_0lRtO_LftsOeMERRhbh2gCA,5739
 cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=fjZMEHaT1DZvUKnzY41GP0Ep8tmPEZTOsCSvG-5it5k,11269
 cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=tKtP35GDwZDoxGrQ_Okg_enlh3L-t-iqpytx8TFO_fw,30519
 cache_dit/cache_factory/cache_contexts/context_manager.py,sha256=j5zP_kwZAKla3EXbfr6JKI1vIxZuUEbZVhAPrtC4COw,853
-cache_dit/cache_factory/cache_contexts/prune_config.py,sha256=efFO_tu6AFJxIDp0OxExWKPzOFj95-NSrLGXggimBMA,3407
+cache_dit/cache_factory/cache_contexts/prune_config.py,sha256=WMTh6zb480a0oJiYMlgI0cwCsDSVvs6UjyeJLiXbjP8,3216
 cache_dit/cache_factory/cache_contexts/prune_context.py,sha256=ywiT9P0w_GjIFLowzUDa6jhTohNsSGfTbanZcs9wMic,6359
 cache_dit/cache_factory/cache_contexts/prune_manager.py,sha256=rZG7HD9ATqgH4VZdMq1XtP_h2pokaotFOVx1svB3J7E,5478
-cache_dit/cache_factory/cache_contexts/calibrators/__init__.py,sha256=mzYXO8tbytGpJJ9rpPu20kMoj1Iu_7Ym9tjfzV8rA98,5574
+cache_dit/cache_factory/cache_contexts/calibrators/__init__.py,sha256=QTbyT8xcFEjfIp9xjbnsnlnVCNvMjUc20NjB0W-s95k,6269
 cache_dit/cache_factory/cache_contexts/calibrators/base.py,sha256=mn6ZBkChGpGwN5csrHTUGMoX6BBPvqHXSLbIExiW-EU,748
 cache_dit/cache_factory/cache_contexts/calibrators/foca.py,sha256=nhHGs_hxwW1M942BQDMJb9-9IuHdnOxp774Jrna1bJI,891
 cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py,sha256=l1QSNaBwtGtpZZFAgCE7Hu8Nf1oL4QAcYu7lShpFGyw,5850
@@ -52,12 +52,17 @@ cache_dit/metrics/image_reward.py,sha256=N8HalJo1T1js0dsNb2V1KRv4kIdcm3nhx7iOXJu
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=hrHrmdM-f2B4TKDs0xLqJO5JFaYcCjq2qNIR8oCrVkc,811
 cache_dit/metrics/metrics.py,sha256=AZbQyoavE-djvyRUZ_EfCIrWSQbiWQFo7n2dhn7XptE,40466
+cache_dit/parallelism/__init__.py,sha256=dheBG5_TZCuwctviMslpAEgB-B3N8F816bE51qsw_fU,210
+cache_dit/parallelism/parallel_backend.py,sha256=js1soTMenLeAyPMsBgdI3gWcdXoqjWgBD-PuFEywMr0,508
+cache_dit/parallelism/parallel_config.py,sha256=bu24sRSzJMmH7FZqzUPTcT6tAzQ20-FAqAEvGV3Q1Fw,1733
+cache_dit/parallelism/parallel_interface.py,sha256=tsiIdHosTmRbeRg0z9q0eMQlx-7vefmSIlc56OWnuMg,2205
+cache_dit/parallelism/backends/parallel_difffusers.py,sha256=i57yZzYc9kGPUjLXTzoAA4j7U8EtVIAJRK1exw30Voo,1939
 cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
-cache_dit/quantize/quantize_ao.py,sha256=Pr3u3Qr6qLvFkd8k-_rfcz4Mkjlg36U9BHG2t6Bl-6M,6301
+cache_dit/quantize/quantize_ao.py,sha256=bbEUwsrMp3bMuRw8qJZREIvCHaJRQoZyfMjlu4ImRMI,6315
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
-cache_dit-1.0.4.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
-cache_dit-1.0.4.dist-info/METADATA,sha256=f04uCgApjgfHTC7Ll9aPejXCFFXbFTpTy-rjd5I_iwM,28376
-cache_dit-1.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-cache_dit-1.0.4.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
-cache_dit-1.0.4.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
-cache_dit-1.0.4.dist-info/RECORD,,
+cache_dit-1.0.6.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-1.0.6.dist-info/METADATA,sha256=h_fb2Lf6XGsTofkmLJrsWh67H0LRvry1SDrMeQ9Uf5I,29476
+cache_dit-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-1.0.6.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-1.0.6.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-1.0.6.dist-info/RECORD,,