cache-dit 1.0.8 → 1.0.10 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit was flagged by the registry scanner.

Files changed (45)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/block_adapters/__init__.py +37 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +75 -4
  5. cache_dit/cache_factory/block_adapters/block_registers.py +44 -17
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +72 -30
  7. cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
  8. cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
  9. cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
  10. cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
  11. cache_dit/cache_factory/cache_interface.py +102 -28
  12. cache_dit/cache_factory/forward_pattern.py +14 -14
  13. cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
  14. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
  15. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
  16. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
  17. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -49
  18. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  19. cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
  20. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  21. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  22. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
  23. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  24. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
  25. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  26. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
  27. cache_dit/parallelism/parallel_backend.py +2 -0
  28. cache_dit/parallelism/parallel_config.py +10 -3
  29. cache_dit/parallelism/parallel_interface.py +14 -5
  30. cache_dit/quantize/backends/__init__.py +1 -0
  31. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  32. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  33. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
  34. cache_dit/quantize/quantize_backend.py +0 -0
  35. cache_dit/quantize/quantize_config.py +0 -0
  36. cache_dit/quantize/quantize_interface.py +3 -16
  37. cache_dit/utils.py +56 -20
  38. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/METADATA +24 -13
  39. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
  40. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  41. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  42. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
  43. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
  44. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
  45. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_adapters/cache_adapter.py

@@ -5,10 +5,11 @@ import functools
  from contextlib import ExitStack
  from typing import Dict, List, Tuple, Any, Union, Callable, Optional

- from diffusers import DiffusionPipeline
+ from diffusers import DiffusionPipeline, ModelMixin

  from cache_dit.cache_factory.cache_types import CacheType
  from cache_dit.cache_factory.block_adapters import BlockAdapter
+ from cache_dit.cache_factory.block_adapters import FakeDiffusionPipeline
  from cache_dit.cache_factory.block_adapters import ParamsModifier
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
  from cache_dit.cache_factory.cache_contexts import ContextManager
@@ -32,6 +33,9 @@ class CachedAdapter:
  pipe_or_adapter: Union[
  DiffusionPipeline,
  BlockAdapter,
+ # Transformer-only
+ torch.nn.Module,
+ ModelMixin,
  ],
  **context_kwargs,
  ) -> Union[
@@ -42,7 +46,9 @@ class CachedAdapter:
  pipe_or_adapter is not None
  ), "pipe or block_adapter can not both None!"

- if isinstance(pipe_or_adapter, DiffusionPipeline):
+ if isinstance(
+ pipe_or_adapter, (DiffusionPipeline, torch.nn.Module, ModelMixin)
+ ):
  if BlockAdapterRegistry.is_supported(pipe_or_adapter):
  logger.info(
  f"{pipe_or_adapter.__class__.__name__} is officially "
@@ -52,16 +58,22 @@ class CachedAdapter:
  block_adapter = BlockAdapterRegistry.get_adapter(
  pipe_or_adapter
  )
+ assert block_adapter is not None, (
+ f"BlockAdapter for {pipe_or_adapter.__class__.__name__} "
+ "should not be None!"
+ )
  if params_modifiers := context_kwargs.pop(
  "params_modifiers",
  None,
  ):
  block_adapter.params_modifiers = params_modifiers

- return cls.cachify(
- block_adapter,
- **context_kwargs,
- ).pipe
+ block_adapter = cls.cachify(block_adapter, **context_kwargs)
+ if isinstance(pipe_or_adapter, DiffusionPipeline):
+ return block_adapter.pipe
+
+ return block_adapter.transformer
+
  else:
  raise ValueError(
  f"{pipe_or_adapter.__class__.__name__} is not officially supported "
@@ -178,8 +190,6 @@ class CachedAdapter:
  context_kwargs = cls.check_context_kwargs(
  block_adapter, **context_kwargs
  )
- # Apply cache on pipeline: wrap cache context
- pipe_cls_name = block_adapter.pipe.__class__.__name__

  # Each Pipeline should have it's own context manager instance.
  # Different transformers (Wan2.2, etc) should shared the same
@@ -189,38 +199,58 @@ class CachedAdapter:
  "cache_config", None
  )
  assert cache_config is not None, "cache_config can not be None."
+ # Apply cache on pipeline: wrap cache context
+ pipe_cls_name = block_adapter.pipe.__class__.__name__
  context_manager = ContextManager(
  name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
  cache_type=cache_config.cache_type,
+ # Force use persistent_context for FakeDiffusionPipeline
+ persistent_context=isinstance(
+ block_adapter.pipe, FakeDiffusionPipeline
+ ),
  )
- block_adapter.pipe._context_manager = context_manager # instance level
-
  flatten_contexts, contexts_kwargs = cls.modify_context_params(
  block_adapter, **context_kwargs
  )
- original_call = block_adapter.pipe.__class__.__call__

- @functools.wraps(original_call)
- def new_call(self, *args, **kwargs):
- with ExitStack() as stack:
- # cache context will be reset for each pipe inference
- for context_name, context_kwargs in zip(
- flatten_contexts, contexts_kwargs
- ):
- stack.enter_context(
- context_manager.enter_context(
- context_manager.reset_context(
- context_name,
- **context_kwargs,
- ),
+ block_adapter.pipe._context_manager = context_manager # instance level
+
+ if not context_manager.persistent_context:
+
+ original_call = block_adapter.pipe.__class__.__call__
+
+ @functools.wraps(original_call)
+ def new_call(self, *args, **kwargs):
+ with ExitStack() as stack:
+ # cache context will be reset for each pipe inference
+ for context_name, context_kwargs in zip(
+ flatten_contexts, contexts_kwargs
+ ):
+ stack.enter_context(
+ context_manager.enter_context(
+ context_manager.reset_context(
+ context_name,
+ **context_kwargs,
+ ),
+ )
  )
- )
- outputs = original_call(self, *args, **kwargs)
- cls.apply_stats_hooks(block_adapter)
- return outputs
+ outputs = original_call(self, *args, **kwargs)
+ cls.apply_stats_hooks(block_adapter)
+ return outputs
+
+ block_adapter.pipe.__class__.__call__ = new_call
+ block_adapter.pipe.__class__._original_call = original_call
+
+ else:
+ # Init persistent cache context for transformer
+ for context_name, context_kwargs in zip(
+ flatten_contexts, contexts_kwargs
+ ):
+ context_manager.reset_context(
+ context_name,
+ **context_kwargs,
+ )

- block_adapter.pipe.__class__.__call__ = new_call
- block_adapter.pipe.__class__._original_call = original_call
  block_adapter.pipe.__class__._is_cached = True

  cls.apply_params_hooks(block_adapter, contexts_kwargs)
@@ -349,6 +379,7 @@ class CachedAdapter:
  blocks_name,
  unique_blocks_name,
  dummy_blocks_names,
+ block_adapter,
  )

  return block_adapter.transformer
@@ -361,6 +392,7 @@ class CachedAdapter:
  blocks_name: List[str],
  unique_blocks_name: List[str],
  dummy_blocks_names: List[str],
+ block_adapter: BlockAdapter,
  ) -> torch.nn.Module:
  dummy_blocks = torch.nn.ModuleList()

@@ -387,6 +419,8 @@ class CachedAdapter:
  # re-apply hooks to transformer after cache applied.
  # from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
  # from diffusers.hooks.group_offloading import apply_group_offloading
+ context_manager: ContextManager = block_adapter.pipe._context_manager
+ assert isinstance(context_manager, ContextManager._supported_managers)

  def new_forward(self, *args, **kwargs):
  with ExitStack() as stack:
@@ -406,6 +440,13 @@ class CachedAdapter:
  )
  )
  outputs = original_forward(*args, **kwargs)
+
+ if (
+ context_manager.persistent_context
+ and context_manager.is_pre_refreshed()
+ ):
+ cls.apply_stats_hooks(block_adapter)
+
  return outputs

  def new_forward_with_hf_hook(self, *args, **kwargs):
@@ -509,6 +550,7 @@ class CachedAdapter:
  params_shift += len(blocks)

  @classmethod
+ @torch.compiler.disable
  def apply_stats_hooks(
  cls,
  block_adapter: BlockAdapter,
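
Note on the transformer-only path added above: when a bare transformer (torch.nn.Module / ModelMixin) is passed instead of a DiffusionPipeline, CachedAdapter.apply now returns the cachified transformer and forces a persistent cache context. A minimal usage sketch under stated assumptions: that the registry supports the Flux transformer class, that DBCacheConfig inherits the new num_inference_steps field from BasicCacheConfig, and that DBCacheConfig is re-exported at the top level of cache_dit (adjust the import otherwise):

    import torch
    import cache_dit
    from cache_dit import DBCacheConfig  # assumed top-level re-export
    from diffusers import FluxTransformer2DModel

    # Load only the transformer, without a DiffusionPipeline wrapper.
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )

    # Transformer-only caching: num_inference_steps is required so the persistent
    # context knows when one denoising round ends and the cache must be refreshed
    # (see the cache_manager.py hunks below).
    transformer = cache_dit.enable_cache(
        transformer,
        cache_config=DBCacheConfig(num_inference_steps=28),
    )
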
cache_dit/cache_factory/cache_contexts/cache_config.py

@@ -56,6 +56,11 @@ class BasicCacheConfig:
  # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
  # use the computed diff from current non-CFG transformer step for current CFG step.
  cfg_diff_compute_separate: bool = True
+ # num_inference_steps (`int`, *optional*, defaults to None):
+ # num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+ # for better caching performance. For example, we will refresh the cache once the
+ # executed steps exceed num_inference_steps if num_inference_steps is provided.
+ num_inference_steps: Optional[int] = None

  def update(self, **kwargs) -> "BasicCacheConfig":
  for key, value in kwargs.items():
@@ -108,9 +113,6 @@ class ExtraCacheConfig:
  # downsample_factor (`int`, *optional*, defaults to 1):
  # Downsample factor for Fn buffer, in order the save GPU memory.
  downsample_factor: int = 1
- # num_inference_steps (`int`, *optional*, defaults to -1):
- # num_inference_steps for DiffusionPipeline, for future use.
- num_inference_steps: int = -1


  @dataclasses.dataclass
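
With this change, num_inference_steps moves from ExtraCacheConfig (where it was an unused int defaulting to -1) onto BasicCacheConfig as an Optional[int]. A minimal configuration sketch, using the import path shown in the cache_manager.py hunk below (BasicCacheConfig is imported there from cache_contexts.cache_context) and assuming the dataclass is constructed directly with keyword arguments:

    from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig

    cfg = BasicCacheConfig(
        cfg_diff_compute_separate=True,  # existing field, shown in the hunk above
        num_inference_steps=50,          # new in 1.0.10: enables per-round cache refresh
    )
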
cache_dit/cache_factory/cache_contexts/cache_manager.py

@@ -7,6 +7,7 @@ import torch.distributed as dist

  from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
  from cache_dit.cache_factory.cache_contexts.cache_context import (
+ BasicCacheConfig,
  CachedContext,
  )
  from cache_dit.logger import init_logger
@@ -21,23 +22,143 @@ class ContextNotExistError(Exception):
  class CachedContextManager:
  # Each Pipeline should have it's own context manager instance.

- def __init__(self, name: str = None):
+ def __init__(self, name: str = None, persistent_context: bool = False):
  self.name = name
  self._current_context: CachedContext = None
  self._cached_context_manager: Dict[str, CachedContext] = {}
+ # Whether to create new context automatically when setting
+ # a non-exist context name. Persistent context is useful when
+ # the pipeline class is not provided and users want to use
+ # cache-dit in a transformer-only way.
+ self._persistent_context = persistent_context
+ self._current_step_refreshed: bool = False
+
+ @property
+ def persistent_context(self) -> bool:
+ return self._persistent_context
+
+ @property
+ def current_context(self) -> CachedContext:
+ return self._current_context
+
+ @property
+ @torch.compiler.disable
+ def current_step_refreshed(self) -> bool:
+ return self._current_step_refreshed
+
+ @torch.compiler.disable
+ def is_pre_refreshed(self) -> bool:
+ _context = self._current_context
+ if _context is None:
+ return False
+
+ num_inference_steps = _context.cache_config.num_inference_steps
+ if num_inference_steps is not None:
+ current_step = _context.get_current_step() # e.g, 0~49,50~99,...
+ return current_step == num_inference_steps - 1
+ return False

+ @torch.compiler.disable
  def new_context(self, *args, **kwargs) -> CachedContext:
+ if self._persistent_context:
+ cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+ assert (
+ cache_config is not None
+ and cache_config.num_inference_steps is not None
+ ), (
+ "When persistent_context is True, num_inference_steps "
+ "must be set in cache_config for proper cache refreshing."
+ f"\nkwargs: {kwargs}"
+ )
  _context = CachedContext(*args, **kwargs)
+ # NOTE: Patch args and kwargs for implicit refresh.
+ _context._init_args = args # maybe empty tuple: ()
+ _context._init_kwargs = kwargs # maybe empty dict: {}
  self._cached_context_manager[_context.name] = _context
  return _context

- def set_context(self, cached_context: CachedContext | str) -> CachedContext:
+ @torch.compiler.disable
+ def maybe_refresh(
+ self,
+ cached_context: Optional[CachedContext | str] = None,
+ ) -> bool:
+ if cached_context is None:
+ _context = self._current_context
+ assert _context is not None, "Current context is not set!"
+
  if isinstance(cached_context, CachedContext):
- self._current_context = cached_context
+ _context = cached_context
  else:
  if cached_context not in self._cached_context_manager:
  raise ContextNotExistError("Context not exist!")
- self._current_context = self._cached_context_manager[cached_context]
+ _context = self._cached_context_manager[cached_context]
+
+ if self._persistent_context:
+ assert _context.cache_config.num_inference_steps is not None, (
+ "When persistent_context is True, num_inference_steps must be set "
+ "in cache_config for proper cache refreshing."
+ )
+
+ num_inference_steps = _context.cache_config.num_inference_steps
+ if num_inference_steps is not None:
+ current_step = _context.get_current_step() # e.g, 0~49,50~99,...
+ # Another round of inference, need to refresh cache context.
+ if current_step >= num_inference_steps:
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ f"Refreshing cache context '{_context.name}' "
+ f"as current step: {current_step} >= "
+ f"num_inference_steps: {num_inference_steps}."
+ )
+ return True
+ return False
+
+ @torch.compiler.disable
+ def set_context(
+ self,
+ cached_context: CachedContext | str,
+ *args,
+ **kwargs,
+ ) -> CachedContext:
+ if isinstance(cached_context, CachedContext):
+ self._current_context = cached_context
+ else:
+ if cached_context not in self._cached_context_manager:
+ if not self._persistent_context:
+ raise ContextNotExistError(
+ "Context not exist and persistent_context is False. Please "
+ "create new context first or set persistent_context=True."
+ )
+ else:
+ # Create new context if not exist
+ if any((bool(args), bool(kwargs))):
+ kwargs["name"] = cached_context
+ self._current_context = self.new_context(
+ *args, **kwargs
+ )
+ else:
+ raise ValueError(
+ "To create new context, please provide args and kwargs."
+ )
+ else:
+ self._current_context = self._cached_context_manager[
+ cached_context
+ ]
+
+ if self.maybe_refresh(self._current_context):
+ if not any((bool(args), bool(kwargs))):
+ assert hasattr(self._current_context, "_init_args")
+ assert hasattr(self._current_context, "_init_kwargs")
+ args = self._current_context._init_args
+ kwargs = self._current_context._init_kwargs
+
+ self._current_context = self.reset_context(
+ self._current_context, *args, **kwargs
+ )
+ self._current_step_refreshed = True
+ else:
+ self._current_step_refreshed = False
+
  return self._current_context

  def get_context(self, name: str = None) -> CachedContext:
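
The maybe_refresh()/set_context() logic above boils down to simple step bookkeeping: a persistent context is reused across calls and reset (with its original init args) once get_current_step() reaches num_inference_steps, i.e. once a full denoising round has completed. A self-contained toy sketch of that bookkeeping, with hypothetical names that are not part of the cache-dit API:

    class ToyPersistentContext:
        """Toy illustration of the refresh-on-rollover bookkeeping above."""

        def __init__(self, num_inference_steps: int):
            self.num_inference_steps = num_inference_steps
            self.current_step = -1  # incremented once per transformer forward

        def on_forward(self) -> bool:
            """Return True when the context was refreshed for a new round."""
            self.current_step += 1
            if self.current_step >= self.num_inference_steps:
                # Another round of inference: reset per-round cache state.
                self.current_step = 0
                return True
            return False

    ctx = ToyPersistentContext(num_inference_steps=3)
    print([ctx.on_forward() for _ in range(7)])
    # [False, False, False, True, False, False, True]
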
cache_dit/cache_factory/cache_contexts/context_manager.py

@@ -20,10 +20,17 @@ class ContextManager:
  cls,
  cache_type: CacheType,
  name: str = "default",
+ persistent_context: bool = False,
  ) -> CachedContextManager | PrunedContextManager:
  if cache_type == CacheType.DBCache:
- return CachedContextManager(name)
+ return CachedContextManager(
+ name=name,
+ persistent_context=persistent_context,
+ )
  elif cache_type == CacheType.DBPrune:
- return PrunedContextManager(name)
+ return PrunedContextManager(
+ name=name,
+ persistent_context=persistent_context,
+ )
  else:
  raise ValueError(f"Unsupported cache_type: {cache_type}.")
cache_dit/cache_factory/cache_contexts/prune_manager.py

@@ -3,6 +3,7 @@ import functools
  from typing import Dict, List, Tuple, Union

  from cache_dit.cache_factory.cache_contexts.cache_manager import (
+ BasicCacheConfig,
  CachedContextManager,
  )
  from cache_dit.cache_factory.cache_contexts.prune_context import (
@@ -16,15 +17,27 @@ logger = init_logger(__name__)
  class PrunedContextManager(CachedContextManager):
  # Reuse CachedContextManager for Dynamic Block Prune

- def __init__(self, name: str = None):
- super().__init__(name)
+ def __init__(self, name: str = None, **kwargs):
+ super().__init__(name, **kwargs)
  # Overwrite for Dynamic Block Prune
  self._current_context: PrunedContext = None
  self._cached_context_manager: Dict[str, PrunedContext] = {}

  # Overwrite for Dynamic Block Prune
  def new_context(self, *args, **kwargs) -> PrunedContext:
+ if self._persistent_context:
+ cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+ assert (
+ cache_config is not None
+ and cache_config.num_inference_steps is not None
+ ), (
+ "When persistent_context is True, num_inference_steps "
+ "must be set in cache_config for proper cache refreshing."
+ )
  _context = PrunedContext(*args, **kwargs)
+ # NOTE: Patch args and kwargs for implicit refresh.
+ _context._init_args = args # maybe empty tuple: ()
+ _context._init_kwargs = kwargs # maybe empty dict: {}
  self._cached_context_manager[_context.name] = _context
  return _context

cache_dit/cache_factory/cache_interface.py

@@ -1,5 +1,6 @@
+ import torch
  from typing import Any, Tuple, List, Union, Optional
- from diffusers import DiffusionPipeline
+ from diffusers import DiffusionPipeline, ModelMixin
  from cache_dit.cache_factory.cache_types import CacheType
  from cache_dit.cache_factory.block_adapters import BlockAdapter
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
@@ -22,13 +23,18 @@ def enable_cache(
  pipe_or_adapter: Union[
  DiffusionPipeline,
  BlockAdapter,
+ # Transformer-only
+ torch.nn.Module,
+ ModelMixin,
  ],
  # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
- cache_config: Union[
- BasicCacheConfig,
- DBCacheConfig,
- DBPruneConfig,
- ] = DBCacheConfig(),
+ cache_config: Optional[
+ Union[
+ BasicCacheConfig,
+ DBCacheConfig,
+ DBPruneConfig,
+ ]
+ ] = None,
  # Calibrator config: TaylorSeerCalibratorConfig, etc.
  calibrator_config: Optional[CalibratorConfig] = None,
  # Modify cache context params for specific blocks.
@@ -45,6 +51,9 @@ def enable_cache(
  **kwargs,
  ) -> Union[
  DiffusionPipeline,
+ # Transformer-only
+ torch.nn.Module,
+ ModelMixin,
  BlockAdapter,
  ]:
  r"""
@@ -74,7 +83,7 @@ def enable_cache(
  with minimal code changes.

  Args:
- pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
+ pipe_or_adapter (`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
  The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
  For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
  for the usgae of BlockAdapter.
@@ -117,6 +126,10 @@ def enable_cache(
  Whether to compute separate difference values for CFG and non-CFG steps, default is True.
  If False, we will use the computed difference from the current non-CFG transformer step
  for the current CFG step.
+ num_inference_steps (`int`, *optional*, defaults to None):
+ num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+ for better caching performance. For example, we will refresh the cache once the
+ executed steps exceed num_inference_steps if num_inference_steps is provided.

  calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
  Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
@@ -135,10 +148,22 @@ def enable_cache(
  Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
  parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
  for more details of ParallelismConfig.
+ backend: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+ Parallelism backend, currently only NATIVE_DIFFUSER and NVTIVE_PYTORCH are supported.
+ For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
+ only NATIVE_PYTORCH backend is supported.
  ulysses_size: (`int`, *optional*, defaults to None):
  The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+ This setting is only valid when backend is NATIVE_DIFFUSER.
  ring_size: (`int`, *optional*, defaults to None):
  The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+ This setting is only valid when backend is NATIVE_DIFFUSER.
+ tp_size: (`int`, *optional*, defaults to None):
+ The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+ This setting is only valid when backend is NATIVE_PYTORCH.
+ parallel_kwargs: (`dict`, *optional*, defaults to {}):
+ Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
+ it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.

  kwargs (`dict`, *optional*, defaults to {})
  Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
@@ -154,13 +179,27 @@ def enable_cache(
  >>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
  >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
  """
+ # Precheck for compatibility of different configurations
+ if cache_config is None:
+ if parallelism_config is None:
+ # Set default cache config only when parallelism is not enabled
+ logger.info("cache_config is None, using default DBCacheConfig")
+ cache_config = DBCacheConfig()
+ else:
+ # Allow empty cache_config when parallelism is enabled
+ logger.warning(
+ "Parallelism is enabled and cache_config is None. Please manually "
+ "set cache_config to avoid potential compatibility issues. "
+ "Otherwise, cache will not be enabled."
+ )
+
  # Collect cache context kwargs
  context_kwargs = {}
  if (cache_type := context_kwargs.get("cache_type", None)) is not None:
  if cache_type == CacheType.NONE:
  return pipe_or_adapter

- # WARNING: Deprecated cache config params. These parameters are now retained
+ # NOTE: Deprecated cache config params. These parameters are now retained
  # for backward compatibility but will be removed in the future.
  deprecated_kwargs = {
  "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
@@ -196,9 +235,9 @@ def enable_cache(
  if cache_config is not None:
  context_kwargs["cache_config"] = cache_config

- # WARNING: Deprecated taylorseer params. These parameters are now retained
+ # NOTE: Deprecated taylorseer params. These parameters are now retained
  # for backward compatibility but will be removed in the future.
- if (
+ if cache_config is not None and (
  kwargs.get("enable_taylorseer", None) is not None
  or kwargs.get("enable_encoder_taylorseer", None) is not None
  ):
@@ -226,16 +265,25 @@ def enable_cache(
  if params_modifiers is not None:
  context_kwargs["params_modifiers"] = params_modifiers

- if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
- pipe_or_adapter = CachedAdapter.apply(
+ if cache_config is not None:
+ if isinstance(
  pipe_or_adapter,
- **context_kwargs,
- )
+ (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin),
+ ):
+ pipe_or_adapter = CachedAdapter.apply(
+ pipe_or_adapter,
+ **context_kwargs,
+ )
+ else:
+ raise ValueError(
+ f"type: {type(pipe_or_adapter)} is not valid, "
+ "Please pass DiffusionPipeline or BlockAdapter"
+ "for the 1's position param: pipe_or_adapter"
+ )
  else:
- raise ValueError(
- f"type: {type(pipe_or_adapter)} is not valid, "
- "Please pass DiffusionPipeline or BlockAdapter"
- "for the 1's position param: pipe_or_adapter"
+ logger.warning(
+ "cache_config is None, skip enabling cache for "
+ f"{pipe_or_adapter.__class__.__name__}."
  )

  # NOTE: Users should always enable parallelism after applying
@@ -244,19 +292,45 @@ def enable_cache(
  assert isinstance(
  parallelism_config, ParallelismConfig
  ), "parallelism_config should be of type ParallelismConfig."
+
+ transformers = []
  if isinstance(pipe_or_adapter, DiffusionPipeline):
- transformer = pipe_or_adapter.transformer
+ adapter = BlockAdapterRegistry.get_adapter(pipe_or_adapter)
+ if adapter is None:
+ assert hasattr(pipe_or_adapter, "transformer"), (
+ "The given DiffusionPipeline does not have "
+ "a 'transformer' attribute, cannot enable "
+ "parallelism."
+ )
+ transformers = [pipe_or_adapter.transformer]
+ else:
+ adapter = BlockAdapter.normalize(adapter, unique=False)
+ transformers = BlockAdapter.flatten(adapter.transformer)
  else:
- assert BlockAdapter.assert_normalized(pipe_or_adapter)
- assert (
- len(BlockAdapter.flatten(pipe_or_adapter.transformer)) == 1
- ), (
- "Only single transformer is supported to enable parallelism "
- "currently for BlockAdapter."
+ if not BlockAdapter.is_normalized(pipe_or_adapter):
+ pipe_or_adapter = BlockAdapter.normalize(
+ pipe_or_adapter, unique=False
+ )
+ transformers = BlockAdapter.flatten(pipe_or_adapter.transformer)
+
+ if len(transformers) == 0:
+ logger.warning(
+ "No transformer is detected in the "
+ "BlockAdapter, skip enabling parallelism."
+ )
+ return pipe_or_adapter
+
+ if len(transformers) > 1:
+ logger.warning(
+ "Multiple transformers are detected in the "
+ "BlockAdapter, all transfomers will be "
+ "enabled for parallelism."
+ )
+ for i, transformer in enumerate(transformers):
+ # Enable parallelism for the transformer inplace
+ transformers[i] = enable_parallelism(
+ transformer, parallelism_config
  )
- transformer = BlockAdapter.flatten(pipe_or_adapter.transformer)[0]
- # Enable parallelism for the transformer inplace
- transformer = enable_parallelism(transformer, parallelism_config)
  return pipe_or_adapter

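
Putting the cache_interface.py changes together: cache_config may now be None (caching is skipped, with a warning, when only parallelism is configured), bare transformers are accepted, and parallelism is applied to every transformer found after normalizing the adapter. A hedged end-to-end sketch, assuming DBCacheConfig and ParallelismConfig are re-exported at the top level of cache_dit (they may instead live under cache_dit.parallelism) and that the script runs under a distributed launcher such as torchrun so two ranks are available for ulysses_size=2:

    import torch
    import cache_dit
    from cache_dit import DBCacheConfig, ParallelismConfig  # assumed top-level re-exports
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Enable caching first, then parallelism (the hunks above apply parallelism
    # only after the cache has been applied).
    pipe = cache_dit.enable_cache(
        pipe,
        cache_config=DBCacheConfig(num_inference_steps=28),
        # Ulysses-style context parallelism via the NATIVE_DIFFUSER backend;
        # use tp_size with the NATIVE_PYTORCH backend for tensor parallelism.
        parallelism_config=ParallelismConfig(ulysses_size=2),
    )

    image = pipe("a photo of a cat", num_inference_steps=28).images[0]
    stats = cache_dit.summary(pipe)  # summary() is shown in the docstring example above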