cache-dit 1.0.9__py3-none-any.whl → 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +37 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +51 -3
- cache_dit/cache_factory/block_adapters/block_registers.py +41 -14
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +68 -30
- cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
- cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
- cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
- cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
- cache_dit/cache_factory/cache_interface.py +29 -3
- cache_dit/cache_factory/forward_pattern.py +14 -14
- cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -61
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
- cache_dit/parallelism/parallel_backend.py +2 -0
- cache_dit/parallelism/parallel_config.py +8 -1
- cache_dit/parallelism/parallel_interface.py +9 -4
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/utils.py +22 -2
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/METADATA +22 -13
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/cache_contexts/cache_config.py

@@ -56,6 +56,11 @@ class BasicCacheConfig:
     # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
     # use the computed diff from current non-CFG transformer step for current CFG step.
     cfg_diff_compute_separate: bool = True
+    # num_inference_steps (`int`, *optional*, defaults to None):
+    # num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+    # for better caching performance. For example, we will refresh the cache once the
+    # executed steps exceed num_inference_steps if num_inference_steps is provided.
+    num_inference_steps: Optional[int] = None
 
     def update(self, **kwargs) -> "BasicCacheConfig":
         for key, value in kwargs.items():
@@ -108,9 +113,6 @@ class ExtraCacheConfig:
     # downsample_factor (`int`, *optional*, defaults to 1):
     # Downsample factor for Fn buffer, in order the save GPU memory.
     downsample_factor: int = 1
-    # num_inference_steps (`int`, *optional*, defaults to -1):
-    # num_inference_steps for DiffusionPipeline, for future use.
-    num_inference_steps: int = -1
 
 
 @dataclasses.dataclass
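
In practice the new field is simply set on the cache config; a minimal sketch (the import path is taken from the cache_manager hunk below, other fields are left at their defaults):

from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig

# Tell cache-dit how many steps one pipeline call will run, so the cached
# context can be refreshed automatically when a new round of inference starts.
cache_config = BasicCacheConfig(num_inference_steps=50)

# update() accepts the same field as a keyword, e.g. when the step count changes:
cache_config.update(num_inference_steps=28)
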
cache_dit/cache_factory/cache_contexts/cache_manager.py

@@ -7,6 +7,7 @@ import torch.distributed as dist
 
 from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
 from cache_dit.cache_factory.cache_contexts.cache_context import (
+    BasicCacheConfig,
     CachedContext,
 )
 from cache_dit.logger import init_logger
@@ -21,23 +22,143 @@ class ContextNotExistError(Exception):
 class CachedContextManager:
     # Each Pipeline should have it's own context manager instance.
 
-    def __init__(self, name: str = None):
+    def __init__(self, name: str = None, persistent_context: bool = False):
         self.name = name
         self._current_context: CachedContext = None
         self._cached_context_manager: Dict[str, CachedContext] = {}
+        # Whether to create new context automatically when setting
+        # a non-exist context name. Persistent context is useful when
+        # the pipeline class is not provided and users want to use
+        # cache-dit in a transformer-only way.
+        self._persistent_context = persistent_context
+        self._current_step_refreshed: bool = False
+
+    @property
+    def persistent_context(self) -> bool:
+        return self._persistent_context
+
+    @property
+    def current_context(self) -> CachedContext:
+        return self._current_context
+
+    @property
+    @torch.compiler.disable
+    def current_step_refreshed(self) -> bool:
+        return self._current_step_refreshed
+
+    @torch.compiler.disable
+    def is_pre_refreshed(self) -> bool:
+        _context = self._current_context
+        if _context is None:
+            return False
+
+        num_inference_steps = _context.cache_config.num_inference_steps
+        if num_inference_steps is not None:
+            current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+            return current_step == num_inference_steps - 1
+        return False
 
+    @torch.compiler.disable
     def new_context(self, *args, **kwargs) -> CachedContext:
+        if self._persistent_context:
+            cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+            assert (
+                cache_config is not None
+                and cache_config.num_inference_steps is not None
+            ), (
+                "When persistent_context is True, num_inference_steps "
+                "must be set in cache_config for proper cache refreshing."
+                f"\nkwargs: {kwargs}"
+            )
         _context = CachedContext(*args, **kwargs)
+        # NOTE: Patch args and kwargs for implicit refresh.
+        _context._init_args = args  # maybe empty tuple: ()
+        _context._init_kwargs = kwargs  # maybe empty dict: {}
         self._cached_context_manager[_context.name] = _context
         return _context
 
-    def set_context(self, cached_context: CachedContext | str) -> CachedContext:
+    @torch.compiler.disable
+    def maybe_refresh(
+        self,
+        cached_context: Optional[CachedContext | str] = None,
+    ) -> bool:
+        if cached_context is None:
+            _context = self._current_context
+            assert _context is not None, "Current context is not set!"
+
         if isinstance(cached_context, CachedContext):
-            self._current_context = cached_context
+            _context = cached_context
         else:
             if cached_context not in self._cached_context_manager:
                 raise ContextNotExistError("Context not exist!")
-            self._current_context = self._cached_context_manager[cached_context]
+            _context = self._cached_context_manager[cached_context]
+
+        if self._persistent_context:
+            assert _context.cache_config.num_inference_steps is not None, (
+                "When persistent_context is True, num_inference_steps must be set "
+                "in cache_config for proper cache refreshing."
+            )
+
+        num_inference_steps = _context.cache_config.num_inference_steps
+        if num_inference_steps is not None:
+            current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+            # Another round of inference, need to refresh cache context.
+            if current_step >= num_inference_steps:
+                if logger.isEnabledFor(logging.DEBUG):
+                    logger.debug(
+                        f"Refreshing cache context '{_context.name}' "
+                        f"as current step: {current_step} >= "
+                        f"num_inference_steps: {num_inference_steps}."
+                    )
+                return True
+        return False
+
+    @torch.compiler.disable
+    def set_context(
+        self,
+        cached_context: CachedContext | str,
+        *args,
+        **kwargs,
+    ) -> CachedContext:
+        if isinstance(cached_context, CachedContext):
+            self._current_context = cached_context
+        else:
+            if cached_context not in self._cached_context_manager:
+                if not self._persistent_context:
+                    raise ContextNotExistError(
+                        "Context not exist and persistent_context is False. Please "
+                        "create new context first or set persistent_context=True."
+                    )
+                else:
+                    # Create new context if not exist
+                    if any((bool(args), bool(kwargs))):
+                        kwargs["name"] = cached_context
+                        self._current_context = self.new_context(
+                            *args, **kwargs
+                        )
+                    else:
+                        raise ValueError(
+                            "To create new context, please provide args and kwargs."
+                        )
+            else:
+                self._current_context = self._cached_context_manager[
+                    cached_context
+                ]
+
+        if self.maybe_refresh(self._current_context):
+            if not any((bool(args), bool(kwargs))):
+                assert hasattr(self._current_context, "_init_args")
+                assert hasattr(self._current_context, "_init_kwargs")
+                args = self._current_context._init_args
+                kwargs = self._current_context._init_kwargs
+
+            self._current_context = self.reset_context(
+                self._current_context, *args, **kwargs
+            )
+            self._current_step_refreshed = True
+        else:
+            self._current_step_refreshed = False
+
         return self._current_context
 
     def get_context(self, name: str = None) -> CachedContext:
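
The persistent-context path targets transformer-only usage where no pipeline object is available. A rough sketch of the flow, assuming direct use of the manager (the call pattern below is inferred from the signatures in this hunk, not a documented public API):

from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig
from cache_dit.cache_factory.cache_contexts.cache_manager import CachedContextManager

manager = CachedContextManager(name="transformer_only", persistent_context=True)

# num_inference_steps is mandatory when persistent_context=True (asserted in new_context).
cfg = BasicCacheConfig(num_inference_steps=50)

# First call: the named context does not exist yet, so it is created implicitly.
ctx = manager.set_context("flux_ctx", cache_config=cfg)

# ... run one full round of 50 denoising steps through the cached transformer ...

# Next round: maybe_refresh() sees current_step >= num_inference_steps and the
# context is rebuilt from the patched _init_args/_init_kwargs.
ctx = manager.set_context("flux_ctx")
# manager.current_step_refreshed is now True
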
cache_dit/cache_factory/cache_contexts/context_manager.py

@@ -20,10 +20,17 @@ class ContextManager:
         cls,
         cache_type: CacheType,
         name: str = "default",
+        persistent_context: bool = False,
     ) -> CachedContextManager | PrunedContextManager:
         if cache_type == CacheType.DBCache:
-            return CachedContextManager(
+            return CachedContextManager(
+                name=name,
+                persistent_context=persistent_context,
+            )
         elif cache_type == CacheType.DBPrune:
-            return PrunedContextManager(
+            return PrunedContextManager(
+                name=name,
+                persistent_context=persistent_context,
+            )
         else:
             raise ValueError(f"Unsupported cache_type: {cache_type}.")
cache_dit/cache_factory/cache_contexts/prune_manager.py

@@ -3,6 +3,7 @@ import functools
 from typing import Dict, List, Tuple, Union
 
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    BasicCacheConfig,
     CachedContextManager,
 )
 from cache_dit.cache_factory.cache_contexts.prune_context import (
@@ -16,15 +17,27 @@ logger = init_logger(__name__)
 class PrunedContextManager(CachedContextManager):
     # Reuse CachedContextManager for Dynamic Block Prune
 
-    def __init__(self, name: str = None):
-        super().__init__(name)
+    def __init__(self, name: str = None, **kwargs):
+        super().__init__(name, **kwargs)
         # Overwrite for Dynamic Block Prune
         self._current_context: PrunedContext = None
         self._cached_context_manager: Dict[str, PrunedContext] = {}
 
     # Overwrite for Dynamic Block Prune
     def new_context(self, *args, **kwargs) -> PrunedContext:
+        if self._persistent_context:
+            cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+            assert (
+                cache_config is not None
+                and cache_config.num_inference_steps is not None
+            ), (
+                "When persistent_context is True, num_inference_steps "
+                "must be set in cache_config for proper cache refreshing."
+            )
         _context = PrunedContext(*args, **kwargs)
+        # NOTE: Patch args and kwargs for implicit refresh.
+        _context._init_args = args  # maybe empty tuple: ()
+        _context._init_kwargs = kwargs  # maybe empty dict: {}
         self._cached_context_manager[_context.name] = _context
         return _context
 
cache_dit/cache_factory/cache_interface.py

@@ -1,5 +1,6 @@
+import torch
 from typing import Any, Tuple, List, Union, Optional
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, ModelMixin
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
@@ -22,6 +23,9 @@ def enable_cache(
     pipe_or_adapter: Union[
         DiffusionPipeline,
         BlockAdapter,
+        # Transformer-only
+        torch.nn.Module,
+        ModelMixin,
     ],
     # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
     cache_config: Optional[
@@ -47,6 +51,9 @@ def enable_cache(
     **kwargs,
 ) -> Union[
     DiffusionPipeline,
+    # Transformer-only
+    torch.nn.Module,
+    ModelMixin,
     BlockAdapter,
 ]:
     r"""
@@ -76,7 +83,7 @@ def enable_cache(
     with minimal code changes.
 
     Args:
-        pipe_or_adapter (`DiffusionPipeline` or `
+        pipe_or_adapter (`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
@@ -119,6 +126,10 @@ def enable_cache(
             Whether to compute separate difference values for CFG and non-CFG steps, default is True.
             If False, we will use the computed difference from the current non-CFG transformer step
             for the current CFG step.
+        num_inference_steps (`int`, *optional*, defaults to None):
+            num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+            for better caching performance. For example, we will refresh the cache once the
+            executed steps exceed num_inference_steps if num_inference_steps is provided.
 
         calibrator_config (`CalibratorConfig`, *optional*, defaults to None):
             Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache
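
Combined with the transformer-only types added to the signature above, a call might look like the following (a sketch; the FLUX model id and the internal BasicCacheConfig import path are used for illustration, and a shorter top-level alias may exist):

import torch
from diffusers import FluxPipeline
import cache_dit
from cache_dit.cache_factory.cache_contexts.cache_context import BasicCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Previously only DiffusionPipeline / BlockAdapter were accepted, e.g.
# cache_dit.enable_cache(pipe). With this release the bare transformer
# (torch.nn.Module / ModelMixin) can be passed directly; num_inference_steps
# lets the cache refresh itself between inference rounds.
cache_dit.enable_cache(
    pipe.transformer,
    cache_config=BasicCacheConfig(num_inference_steps=28),
)

image = pipe("a cat", num_inference_steps=28).images[0]
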
@@ -137,10 +148,22 @@ def enable_cache(
             Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
             parallelism for cache-dit. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/parallel_config.py
             for more details of ParallelismConfig.
+            backend: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+                Parallelism backend, currently only NATIVE_DIFFUSER and NVTIVE_PYTORCH are supported.
+                For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
+                only NATIVE_PYTORCH backend is supported.
             ulysses_size: (`int`, *optional*, defaults to None):
                 The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
             ring_size: (`int`, *optional*, defaults to None):
                 The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+                This setting is only valid when backend is NATIVE_DIFFUSER.
+            tp_size: (`int`, *optional*, defaults to None):
+                The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+                This setting is only valid when backend is NATIVE_PYTORCH.
+            parallel_kwargs: (`dict`, *optional*, defaults to {}):
+                Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
+                it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.
 
         kwargs (`dict`, *optional*, defaults to {})
             Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
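
Mapped onto code, the new docstring entries correspond to a config roughly like the following (assuming ParallelismConfig accepts these fields as constructor keywords, which the diff implies but does not show):

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig

# Context parallelism: NATIVE_DIFFUSER backend, Ulysses and/or ring sizes,
# extra knobs (cp_plan, attention_backend) routed through parallel_kwargs.
cp_config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    ulysses_size=2,
    ring_size=2,
    parallel_kwargs={"attention_backend": "_native_cudnn"},
)

# Tensor parallelism: NATIVE_PYTORCH backend with tp_size.
tp_config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_PYTORCH,
    tp_size=4,
)
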
@@ -243,7 +266,10 @@ def enable_cache(
         context_kwargs["params_modifiers"] = params_modifiers
 
     if cache_config is not None:
-        if isinstance(
+        if isinstance(
+            pipe_or_adapter,
+            (DiffusionPipeline, BlockAdapter, torch.nn.Module, ModelMixin),
+        ):
             pipe_or_adapter = CachedAdapter.apply(
                 pipe_or_adapter,
                 **context_kwargs,
cache_dit/cache_factory/forward_pattern.py

@@ -20,33 +20,33 @@ class ForwardPattern(Enum):
 
     Pattern_0 = (
         True,  # Return_H_First
-        False,
-        False,
+        False,  # Return_H_Only
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
         ("hidden_states", "encoder_hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_1 = (
-        False,
-        False,
-        False,
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
         ("encoder_hidden_states", "hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_2 = (
-        False,
+        False,  # Return_H_First
         True,  # Return_H_Only
-        False,
+        False,  # Forward_H_only
         ("hidden_states", "encoder_hidden_states"),  # In
-        ("hidden_states",),
+        ("hidden_states",),  # Out
         True,  # Supported
     )
 
     Pattern_3 = (
-        False,
+        False,  # Return_H_First
         True,  # Return_H_Only
         True,  # Forward_H_only
         ("hidden_states",),  # In
@@ -56,18 +56,18 @@ class ForwardPattern(Enum):
 
     Pattern_4 = (
         True,  # Return_H_First
-        False,
+        False,  # Return_H_Only
         True,  # Forward_H_only
-        ("hidden_states",),
+        ("hidden_states",),  # In
         ("hidden_states", "encoder_hidden_states"),  # Out
         True,  # Supported
     )
 
     Pattern_5 = (
-        False,
-        False,
+        False,  # Return_H_First
+        False,  # Return_H_Only
         True,  # Forward_H_only
-        ("hidden_states",),
+        ("hidden_states",),  # In
         ("encoder_hidden_states", "hidden_states"),  # Out
         True,  # Supported
     )
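
For reference, the comments added above name the tuple fields (Return_H_First, Return_H_Only, Forward_H_only, In, Out, Supported). A block whose forward matches Pattern_1, for instance, would look roughly like this (an illustrative sketch, not code from the package):

import torch

class Pattern1Block(torch.nn.Module):
    # In:  ("hidden_states", "encoder_hidden_states")
    # Out: ("encoder_hidden_states", "hidden_states")
    # Return_H_First=False, Return_H_Only=False, Forward_H_only=False
    def forward(self, hidden_states, encoder_hidden_states, **kwargs):
        # ... attention / MLP work would happen here ...
        return encoder_hidden_states, hidden_states
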
cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py

@@ -0,0 +1,95 @@
+import torch
+from typing import Optional
+
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.logger import init_logger
+from ..utils import (
+    native_diffusers_parallelism_available,
+    ContextParallelConfig,
+)
+from .cp_planners import *
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_context_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, ModelMixin), (
+        "transformer must be an instance of diffusers' ModelMixin, "
+        f"but got {type(transformer)}"
+    )
+    if parallelism_config is None:
+        return transformer
+
+    assert isinstance(parallelism_config, ParallelismConfig), (
+        "parallelism_config must be an instance of ParallelismConfig"
+        f" but got {type(parallelism_config)}"
+    )
+
+    if (
+        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
+        and native_diffusers_parallelism_available()
+    ):
+        cp_config = None
+        if (
+            parallelism_config.ulysses_size is not None
+            or parallelism_config.ring_size is not None
+        ):
+            cp_config = ContextParallelConfig(
+                ulysses_degree=parallelism_config.ulysses_size,
+                ring_degree=parallelism_config.ring_size,
+            )
+        if cp_config is not None:
+            attention_backend = parallelism_config.parallel_kwargs.get(
+                "attention_backend", None
+            )
+            if hasattr(transformer, "enable_parallelism"):
+                if hasattr(transformer, "set_attention_backend"):
+                    # _native_cudnn, flash, etc.
+                    if attention_backend is None:
+                        # Now only _native_cudnn is supported for parallelism
+                        # issue: https://github.com/huggingface/diffusers/pull/12443
+                        transformer.set_attention_backend("_native_cudnn")
+                        logger.warning(
+                            "attention_backend is None, set default attention backend "
+                            "to _native_cudnn for parallelism because of the issue: "
+                            "https://github.com/huggingface/diffusers/pull/12443"
+                        )
+                    else:
+                        transformer.set_attention_backend(attention_backend)
+                        logger.info(
+                            "Found attention_backend from config, set attention "
+                            f"backend to: {attention_backend}"
+                        )
+                # Prefer custom cp_plan if provided
+                cp_plan = parallelism_config.parallel_kwargs.get(
+                    "cp_plan", None
+                )
+                if cp_plan is not None:
+                    logger.info(
+                        f"Using custom context parallelism plan: {cp_plan}"
+                    )
+                else:
+                    # Try get context parallelism plan from register if not provided
+                    extra_parallel_kwargs = {}
+                    if parallelism_config.parallel_kwargs is not None:
+                        extra_parallel_kwargs = (
+                            parallelism_config.parallel_kwargs
+                        )
+                    cp_plan = ContextParallelismPlannerRegister.get_planner(
+                        transformer
+                    )().apply(transformer=transformer, **extra_parallel_kwargs)
+
+                transformer.enable_parallelism(
+                    config=cp_config, cp_plan=cp_plan
+                )
+            else:
+                raise ValueError(
+                    f"{transformer.__class__.__name__} does not support context parallelism."
+                )
+
+    return transformer
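
A hedged sketch of what this function does for a context-parallel run when called directly (enable_cache normally drives it via parallelism_config, and the ParallelismConfig constructor keywords below are assumptions):

# Launch with e.g.: torchrun --nproc_per_node=2 this_script.py
import torch
from diffusers import FluxPipeline
from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_diffusers.context_parallelism import (
    maybe_enable_context_parallelism,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.transformer = maybe_enable_context_parallelism(
    pipe.transformer,
    ParallelismConfig(
        backend=ParallelismBackend.NATIVE_DIFFUSER,
        ulysses_size=2,
        # With attention_backend unset, the code falls back to "_native_cudnn";
        # with cp_plan unset, a plan is looked up from the planner registry below.
        parallel_kwargs={},
    ),
)
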
cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py

@@ -0,0 +1,74 @@
+import torch
+import logging
+from abc import abstractmethod
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+    "ContextParallelismPlanner",
+    "ContextParallelismPlannerRegister",
+]
+
+
+class ContextParallelismPlanner:
+    @abstractmethod
+    def apply(
+        self,
+        # NOTE: Keep this kwarg for future extensions
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        # NOTE: This method should only return the CP plan dictionary.
+        raise NotImplementedError(
+            "apply method must be implemented by subclasses"
+        )
+
+
+class ContextParallelismPlannerRegister:
+    _cp_planner_registry: dict[str, ContextParallelismPlanner] = {}
+
+    @classmethod
+    def register(cls, name: str):
+        def decorator(planner_cls: type[ContextParallelismPlanner]):
+            assert (
+                name not in cls._cp_planner_registry
+            ), f"ContextParallelismPlanner with name {name} is already registered."
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(f"Registering ContextParallelismPlanner: {name}")
+            cls._cp_planner_registry[name] = planner_cls
+            return planner_cls
+
+        return decorator
+
+    @classmethod
+    def get_planner(
+        cls, transformer: str | torch.nn.Module | ModelMixin
+    ) -> type[ContextParallelismPlanner]:
+        if isinstance(transformer, (torch.nn.Module, ModelMixin)):
+            name = transformer.__class__.__name__
+        else:
+            name = transformer
+        planner_cls = None
+        for planner_name in cls._cp_planner_registry:
+            if name.startswith(planner_name):
+                planner_cls = cls._cp_planner_registry.get(planner_name)
+                break
+        if planner_cls is None:
+            raise ValueError(f"No planner registered under name: {name}")
+        return planner_cls