cache-dit 1.0.9__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.


Files changed (45)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/block_adapters/__init__.py +37 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +51 -3
  5. cache_dit/cache_factory/block_adapters/block_registers.py +41 -14
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +68 -30
  7. cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
  8. cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
  9. cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
  10. cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
  11. cache_dit/cache_factory/cache_interface.py +29 -3
  12. cache_dit/cache_factory/forward_pattern.py +14 -14
  13. cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
  14. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
  15. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
  16. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
  17. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -61
  18. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  19. cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
  20. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  21. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  22. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
  23. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  24. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
  25. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  26. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
  27. cache_dit/parallelism/parallel_backend.py +2 -0
  28. cache_dit/parallelism/parallel_config.py +8 -1
  29. cache_dit/parallelism/parallel_interface.py +9 -4
  30. cache_dit/quantize/backends/__init__.py +1 -0
  31. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  32. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  33. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
  34. cache_dit/quantize/quantize_backend.py +0 -0
  35. cache_dit/quantize/quantize_config.py +0 -0
  36. cache_dit/quantize/quantize_interface.py +3 -16
  37. cache_dit/utils.py +22 -2
  38. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/METADATA +22 -13
  39. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
  40. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  41. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  42. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
  43. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
  44. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
  45. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
@@ -56,6 +56,11 @@ class BasicCacheConfig:
     # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
     # use the computed diff from current non-CFG transformer step for current CFG step.
     cfg_diff_compute_separate: bool = True
+    # num_inference_steps (`int`, *optional*, defaults to None):
+    # num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+    # for better caching performance. For example, we will refresh the cache once the
+    # executed steps exceed num_inference_steps if num_inference_steps is provided.
+    num_inference_steps: Optional[int] = None
 
     def update(self, **kwargs) -> "BasicCacheConfig":
         for key, value in kwargs.items():
@@ -108,9 +113,6 @@ class ExtraCacheConfig:
     # downsample_factor (`int`, *optional*, defaults to 1):
     # Downsample factor for Fn buffer, in order the save GPU memory.
     downsample_factor: int = 1
-    # num_inference_steps (`int`, *optional*, defaults to -1):
-    # num_inference_steps for DiffusionPipeline, for future use.
-    num_inference_steps: int = -1
 
 
 @dataclasses.dataclass
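
Together, these two hunks move num_inference_steps out of ExtraCacheConfig (where it was an unused placeholder defaulting to -1) and into BasicCacheConfig as an Optional[int] that drives cache refreshing. A minimal sketch of how the new field might be set; the import path follows the cache_manager.py hunk below, and whether the config class is also re-exported at the package top level is not shown in this diff:

from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig

# Tell cache-dit how many scheduler steps one pipeline call runs, so the
# cached context can be refreshed when the next call starts.
cache_config = BasicCacheConfig(
    cfg_diff_compute_separate=True,
    num_inference_steps=50,  # should match pipe(..., num_inference_steps=50)
)
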
@@ -7,6 +7,7 @@ import torch.distributed as dist
 
 from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
 from cache_dit.cache_factory.cache_contexts.cache_context import (
+    BasicCacheConfig,
     CachedContext,
 )
 from cache_dit.logger import init_logger
@@ -21,23 +22,143 @@ class ContextNotExistError(Exception):
 class CachedContextManager:
     # Each Pipeline should have it's own context manager instance.
 
-    def __init__(self, name: str = None):
+    def __init__(self, name: str = None, persistent_context: bool = False):
         self.name = name
         self._current_context: CachedContext = None
         self._cached_context_manager: Dict[str, CachedContext] = {}
+        # Whether to create new context automatically when setting
+        # a non-exist context name. Persistent context is useful when
+        # the pipeline class is not provided and users want to use
+        # cache-dit in a transformer-only way.
+        self._persistent_context = persistent_context
+        self._current_step_refreshed: bool = False
+
+    @property
+    def persistent_context(self) -> bool:
+        return self._persistent_context
+
+    @property
+    def current_context(self) -> CachedContext:
+        return self._current_context
+
+    @property
+    @torch.compiler.disable
+    def current_step_refreshed(self) -> bool:
+        return self._current_step_refreshed
+
+    @torch.compiler.disable
+    def is_pre_refreshed(self) -> bool:
+        _context = self._current_context
+        if _context is None:
+            return False
+
+        num_inference_steps = _context.cache_config.num_inference_steps
+        if num_inference_steps is not None:
+            current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+            return current_step == num_inference_steps - 1
+        return False
 
+    @torch.compiler.disable
     def new_context(self, *args, **kwargs) -> CachedContext:
+        if self._persistent_context:
+            cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+            assert (
+                cache_config is not None
+                and cache_config.num_inference_steps is not None
+            ), (
+                "When persistent_context is True, num_inference_steps "
+                "must be set in cache_config for proper cache refreshing."
+                f"\nkwargs: {kwargs}"
+            )
         _context = CachedContext(*args, **kwargs)
+        # NOTE: Patch args and kwargs for implicit refresh.
+        _context._init_args = args  # maybe empty tuple: ()
+        _context._init_kwargs = kwargs  # maybe empty dict: {}
         self._cached_context_manager[_context.name] = _context
         return _context
 
-    def set_context(self, cached_context: CachedContext | str) -> CachedContext:
+    @torch.compiler.disable
+    def maybe_refresh(
+        self,
+        cached_context: Optional[CachedContext | str] = None,
+    ) -> bool:
+        if cached_context is None:
+            _context = self._current_context
+            assert _context is not None, "Current context is not set!"
+
         if isinstance(cached_context, CachedContext):
-            self._current_context = cached_context
+            _context = cached_context
         else:
             if cached_context not in self._cached_context_manager:
                 raise ContextNotExistError("Context not exist!")
-            self._current_context = self._cached_context_manager[cached_context]
+            _context = self._cached_context_manager[cached_context]
+
+        if self._persistent_context:
+            assert _context.cache_config.num_inference_steps is not None, (
+                "When persistent_context is True, num_inference_steps must be set "
+                "in cache_config for proper cache refreshing."
+            )
+
+        num_inference_steps = _context.cache_config.num_inference_steps
+        if num_inference_steps is not None:
+            current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+            # Another round of inference, need to refresh cache context.
+            if current_step >= num_inference_steps:
+                if logger.isEnabledFor(logging.DEBUG):
+                    logger.debug(
+                        f"Refreshing cache context '{_context.name}' "
+                        f"as current step: {current_step} >= "
+                        f"num_inference_steps: {num_inference_steps}."
+                    )
+                return True
+        return False
+
+    @torch.compiler.disable
+    def set_context(
+        self,
+        cached_context: CachedContext | str,
+        *args,
+        **kwargs,
+    ) -> CachedContext:
+        if isinstance(cached_context, CachedContext):
+            self._current_context = cached_context
+        else:
+            if cached_context not in self._cached_context_manager:
+                if not self._persistent_context:
+                    raise ContextNotExistError(
+                        "Context not exist and persistent_context is False. Please "
+                        "create new context first or set persistent_context=True."
+                    )
+                else:
+                    # Create new context if not exist
+                    if any((bool(args), bool(kwargs))):
+                        kwargs["name"] = cached_context
+                        self._current_context = self.new_context(
+                            *args, **kwargs
+                        )
+                    else:
+                        raise ValueError(
+                            "To create new context, please provide args and kwargs."
+                        )
+            else:
+                self._current_context = self._cached_context_manager[
+                    cached_context
+                ]
+
+        if self.maybe_refresh(self._current_context):
+            if not any((bool(args), bool(kwargs))):
+                assert hasattr(self._current_context, "_init_args")
+                assert hasattr(self._current_context, "_init_kwargs")
+                args = self._current_context._init_args
+                kwargs = self._current_context._init_kwargs
+
+            self._current_context = self.reset_context(
+                self._current_context, *args, **kwargs
+            )
+            self._current_step_refreshed = True
+        else:
+            self._current_step_refreshed = False
+
         return self._current_context
 
     def get_context(self, name: str = None) -> CachedContext:
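
The persistent_context mode added here is what enables the transformer-only workflow: set_context() can now lazily create a context and, once get_current_step() reaches num_inference_steps, rebuild it from the stashed _init_args/_init_kwargs. A rough usage sketch, assuming CachedContext accepts name and cache_config keyword arguments as implied by new_context():

from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig
from cache_dit.cache_factory.cache_contexts.cache_manager import CachedContextManager

manager = CachedContextManager(name="transformer_only", persistent_context=True)

# The first call creates the context; later calls reuse it and refresh it
# automatically once the executed steps exceed num_inference_steps.
context = manager.set_context(
    "flux_ctx",
    cache_config=BasicCacheConfig(num_inference_steps=28),
)
if manager.current_step_refreshed:
    print("cache context was refreshed for a new round of inference")
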
@@ -20,10 +20,17 @@ class ContextManager:
         cls,
         cache_type: CacheType,
         name: str = "default",
+        persistent_context: bool = False,
     ) -> CachedContextManager | PrunedContextManager:
         if cache_type == CacheType.DBCache:
-            return CachedContextManager(name)
+            return CachedContextManager(
+                name=name,
+                persistent_context=persistent_context,
+            )
         elif cache_type == CacheType.DBPrune:
-            return PrunedContextManager(name)
+            return PrunedContextManager(
+                name=name,
+                persistent_context=persistent_context,
+            )
         else:
             raise ValueError(f"Unsupported cache_type: {cache_type}.")
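
Both manager types now accept the same persistent_context flag through this factory, so DBCache and DBPrune behave identically in transformer-only mode. For reference, a sketch of what the factory returns for each CacheType (constructed directly here, since the factory method's name lies outside this hunk):

from cache_dit.cache_factory.cache_contexts.cache_manager import CachedContextManager
from cache_dit.cache_factory.cache_contexts.prune_manager import PrunedContextManager

# CacheType.DBCache -> CachedContextManager, CacheType.DBPrune -> PrunedContextManager
dbcache_manager = CachedContextManager(name="default", persistent_context=True)
dbprune_manager = PrunedContextManager(name="default", persistent_context=True)
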
@@ -3,6 +3,7 @@ import functools
 from typing import Dict, List, Tuple, Union
 
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    BasicCacheConfig,
     CachedContextManager,
 )
 from cache_dit.cache_factory.cache_contexts.prune_context import (
@@ -16,15 +17,27 @@ logger = init_logger(__name__)
 class PrunedContextManager(CachedContextManager):
     # Reuse CachedContextManager for Dynamic Block Prune
 
-    def __init__(self, name: str = None):
-        super().__init__(name)
+    def __init__(self, name: str = None, **kwargs):
+        super().__init__(name, **kwargs)
         # Overwrite for Dynamic Block Prune
         self._current_context: PrunedContext = None
         self._cached_context_manager: Dict[str, PrunedContext] = {}
 
     # Overwrite for Dynamic Block Prune
     def new_context(self, *args, **kwargs) -> PrunedContext:
+        if self._persistent_context:
+            cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+            assert (
+                cache_config is not None
+                and cache_config.num_inference_steps is not None
+            ), (
+                "When persistent_context is True, num_inference_steps "
+                "must be set in cache_config for proper cache refreshing."
+            )
         _context = PrunedContext(*args, **kwargs)
+        # NOTE: Patch args and kwargs for implicit refresh.
+        _context._init_args = args  # maybe empty tuple: ()
+        _context._init_kwargs = kwargs  # maybe empty dict: {}
         self._cached_context_manager[_context.name] = _context
         return _context
 
@@ -1,5 +1,6 @@
+import torch
 from typing import Any, Tuple, List, Union, Optional
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, ModelMixin
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
@@ -22,6 +23,9 @@ def enable_cache(
     pipe_or_adapter: Union[
         DiffusionPipeline,
         BlockAdapter,
+        # Transformer-only
+        torch.nn.Module,
+        ModelMixin,
     ],
     # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
     cache_config: Optional[
@@ -47,6 +51,9 @@ def enable_cache(
     **kwargs,
 ) -> Union[
     DiffusionPipeline,
+    # Transformer-only
+    torch.nn.Module,
+    ModelMixin,
     BlockAdapter,
 ]:
     r"""
@@ -76,7 +83,7 @@ def enable_cache(
     with minimal code changes.
 
     Args:
-        pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
+        pipe_or_adapter (`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
@@ -119,6 +126,10 @@ def enable_cache(
             Whether to compute separate difference values for CFG and non-CFG steps, default is True.
             If False, we will use the computed difference from the current non-CFG transformer step
             for the current CFG step.
+        num_inference_steps (`int`, *optional*, defaults to None):
+            num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+            for better caching performance. For example, we will refresh the cache once the
+            executed steps exceed num_inference_steps if num_inference_steps is provided.
 
         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
             Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
@@ -137,10 +148,22 @@ def enable_cache(
             Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
             parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
             for more details of ParallelismConfig.
+            backend: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+                Parallelism backend, currently only NATIVE_DIFFUSER and NVTIVE_PYTORCH are supported.
+                For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
+                only NATIVE_PYTORCH backend is supported.
             ulysses_size: (`int`, *optional*, defaults to None):
                 The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
             ring_size: (`int`, *optional*, defaults to None):
                 The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
+            tp_size: (`int`, *optional*, defaults to None):
+                The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+                This setting is only valid when backend is NATIVE_PYTORCH.
+            parallel_kwargs: (`dict`, *optional*, defaults to {}):
+                Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
+                it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.
 
         kwargs (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
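
The new docstring text mirrors the two parallelism backends wired up in this release: context parallelism (Ulysses/ring) runs through the NATIVE_DIFFUSER backend, tensor parallelism through NATIVE_PYTORCH. A hedged sketch of both configurations, assuming ParallelismConfig accepts exactly the documented field names as constructor arguments:

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig

# Context parallelism: Ulysses attention across 4 ranks via diffusers.
cp_config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    ulysses_size=4,
    parallel_kwargs={"attention_backend": "_native_cudnn"},
)

# Tensor parallelism: weights sharded across 2 ranks via native PyTorch.
tp_config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_PYTORCH,
    tp_size=2,
)
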
@@ -243,7 +266,10 @@ def enable_cache(
         context_kwargs["params_modifiers"] = params_modifiers
 
     if cache_config is not None:
-        if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
+        if isinstance(
+            pipe_or_adapter,
+            (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin),
+        ):
             pipe_or_adapter = CachedAdapter.apply(
                 pipe_or_adapter,
                 **context_kwargs,
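
With the widened isinstance check, the adapter path accepts a bare transformer as well as a full pipeline. A rough sketch of both call styles; FluxPipeline is only the docstring's own example, and for the transformer-only call num_inference_steps is assumed to be required so the persistent context can refresh itself (see the cache_manager.py hunk above):

import cache_dit
from diffusers import FluxPipeline
from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# Option A: pipeline-level caching, as in previous releases.
cache_dit.enable_cache(pipe)

# Option B (new in 1.0.10): transformer-only caching. The transformer is a
# diffusers ModelMixin (and a torch.nn.Module), so it now passes the check.
cache_dit.enable_cache(
    pipe.transformer,
    cache_config=BasicCacheConfig(num_inference_steps=28),
)
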
@@ -20,33 +20,33 @@ class ForwardPattern(Enum):
 
     Pattern_0 = (
         True,  # Return_H_First
-        False, # Return_H_Only
-        False, # Forward_H_only
+        False,  # Return_H_Only
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
         ("hidden_states", "encoder_hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_1 = (
-        False, # Return_H_First
-        False, # Return_H_Only
-        False, # Forward_H_only
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
         ("encoder_hidden_states", "hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_2 = (
-        False, # Return_H_First
+        False,  # Return_H_First
         True,  # Return_H_Only
-        False, # Forward_H_only
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
-        ("hidden_states",), # Out
+        ("hidden_states",),  # Out
         True,  # Supported
     )
 
     Pattern_3 = (
-        False, # Return_H_First
+        False,  # Return_H_First
         True,  # Return_H_Only
         True,  # Forward_H_only
         ("hidden_states",),  # In
@@ -56,18 +56,18 @@ class ForwardPattern(Enum):
 
     Pattern_4 = (
         True,  # Return_H_First
-        False, # Return_H_Only
+        False,  # Return_H_Only
         True,  # Forward_H_only
-        ("hidden_states",), # In
+        ("hidden_states",),  # In
         ("hidden_states", "encoder_hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_5 = (
-        False, # Return_H_First
-        False, # Return_H_Only
+        False,  # Return_H_First
+        False,  # Return_H_Only
         True,  # Forward_H_only
-        ("hidden_states",), # In
+        ("hidden_states",),  # In
         ("encoder_hidden_states", "hidden_states"),  # Out
         True,  # Supported
     )
@@ -1,6 +1,3 @@
 from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
     maybe_enable_parallelism,
 )
-from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
-    native_diffusers_parallelism_available,
-)
@@ -0,0 +1,95 @@
+import torch
+from typing import Optional
+
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.logger import init_logger
+from ..utils import (
+    native_diffusers_parallelism_available,
+    ContextParallelConfig,
+)
+from .cp_planners import *
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_context_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, ModelMixin), (
+        "transformer must be an instance of diffusers' ModelMixin, "
+        f"but got {type(transformer)}"
+    )
+    if parallelism_config is None:
+        return transformer
+
+    assert isinstance(parallelism_config, ParallelismConfig), (
+        "parallelism_config must be an instance of ParallelismConfig"
+        f" but got {type(parallelism_config)}"
+    )
+
+    if (
+        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
+        and native_diffusers_parallelism_available()
+    ):
+        cp_config = None
+        if (
+            parallelism_config.ulysses_size is not None
+            or parallelism_config.ring_size is not None
+        ):
+            cp_config = ContextParallelConfig(
+                ulysses_degree=parallelism_config.ulysses_size,
+                ring_degree=parallelism_config.ring_size,
+            )
+        if cp_config is not None:
+            attention_backend = parallelism_config.parallel_kwargs.get(
+                "attention_backend", None
+            )
+            if hasattr(transformer, "enable_parallelism"):
+                if hasattr(transformer, "set_attention_backend"):
+                    # _native_cudnn, flash, etc.
+                    if attention_backend is None:
+                        # Now only _native_cudnn is supported for parallelism
+                        # issue: https://github.com/huggingface/diffusers/pull/12443
+                        transformer.set_attention_backend("_native_cudnn")
+                        logger.warning(
+                            "attention_backend is None, set default attention backend "
+                            "to _native_cudnn for parallelism because of the issue: "
+                            "https://github.com/huggingface/diffusers/pull/12443"
+                        )
+                    else:
+                        transformer.set_attention_backend(attention_backend)
+                        logger.info(
+                            "Found attention_backend from config, set attention "
+                            f"backend to: {attention_backend}"
+                        )
+                # Prefer custom cp_plan if provided
+                cp_plan = parallelism_config.parallel_kwargs.get(
+                    "cp_plan", None
+                )
+                if cp_plan is not None:
+                    logger.info(
+                        f"Using custom context parallelism plan: {cp_plan}"
+                    )
+                else:
+                    # Try get context parallelism plan from register if not provided
+                    extra_parallel_kwargs = {}
+                    if parallelism_config.parallel_kwargs is not None:
+                        extra_parallel_kwargs = (
+                            parallelism_config.parallel_kwargs
+                        )
+                    cp_plan = ContextParallelismPlannerRegister.get_planner(
+                        transformer
+                    )().apply(transformer=transformer, **extra_parallel_kwargs)
+
+                transformer.enable_parallelism(
+                    config=cp_config, cp_plan=cp_plan
+                )
+            else:
+                raise ValueError(
+                    f"{transformer.__class__.__name__} does not support context parallelism."
+                )
+
+    return transformer
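
The helper resolves an attention backend first and then a cp_plan: an explicit plan from parallel_kwargs wins, otherwise a registered planner is looked up by the transformer's class name. A minimal invocation sketch, assuming torch.distributed is already initialised (e.g. under torchrun), the model implements enable_parallelism, and ParallelismConfig keeps its documented defaults for unset fields:

from diffusers import FluxTransformer2DModel
from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_diffusers.context_parallelism import (
    maybe_enable_context_parallelism,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
# No cp_plan is passed, so a planner registered in cp_planners.py is looked up
# by class name; the attention backend falls back to _native_cudnn.
transformer = maybe_enable_context_parallelism(
    transformer,
    ParallelismConfig(
        backend=ParallelismBackend.NATIVE_DIFFUSER,
        ulysses_size=2,
    ),
)
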
@@ -0,0 +1,74 @@
+import torch
+import logging
+from abc import abstractmethod
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+    "ContextParallelismPlanner",
+    "ContextParallelismPlannerRegister",
+]
+
+
+class ContextParallelismPlanner:
+    @abstractmethod
+    def apply(
+        self,
+        # NOTE: Keep this kwarg for future extensions
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        # NOTE: This method should only return the CP plan dictionary.
+        raise NotImplementedError(
+            "apply method must be implemented by subclasses"
+        )
+
+
+class ContextParallelismPlannerRegister:
+    _cp_planner_registry: dict[str, ContextParallelismPlanner] = {}
+
+    @classmethod
+    def register(cls, name: str):
+        def decorator(planner_cls: type[ContextParallelismPlanner]):
+            assert (
+                name not in cls._cp_planner_registry
+            ), f"ContextParallelismPlanner with name {name} is already registered."
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(f"Registering ContextParallelismPlanner: {name}")
+            cls._cp_planner_registry[name] = planner_cls
+            return planner_cls
+
+        return decorator
+
+    @classmethod
+    def get_planner(
+        cls, transformer: str | torch.nn.Module | ModelMixin
+    ) -> type[ContextParallelismPlanner]:
+        if isinstance(transformer, (torch.nn.Module, ModelMixin)):
+            name = transformer.__class__.__name__
+        else:
+            name = transformer
+        planner_cls = None
+        for planner_name in cls._cp_planner_registry:
+            if name.startswith(planner_name):
+                planner_cls = cls._cp_planner_registry.get(planner_name)
+                break
+        if planner_cls is None:
+            raise ValueError(f"No planner registered under name: {name}")
+        return planner_cls
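
Planners are matched by class-name prefix, so a planner registered as "MyDiT" serves "MyDiTTransformer2DModel" and any sibling variants. A hypothetical registration sketch; MyDiTContextParallelismPlanner and its empty plan are placeholders, and a real planner would return a populated ContextParallelModelPlan for its architecture:

from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)


@ContextParallelismPlannerRegister.register("MyDiT")
class MyDiTContextParallelismPlanner(ContextParallelismPlanner):
    def apply(self, transformer=None, **kwargs):
        # A real implementation returns the module-name -> sharding spec
        # mapping (ContextParallelModelPlan) for this architecture.
        return {}


# get_planner() resolves by prefix of the transformer's class name:
planner_cls = ContextParallelismPlannerRegister.get_planner("MyDiTTransformer2DModel")
plan = planner_cls().apply()
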