cache-dit 1.0.8__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of cache-dit has been flagged by the registry as potentially problematic.

Files changed (45)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/block_adapters/__init__.py +37 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +75 -4
  5. cache_dit/cache_factory/block_adapters/block_registers.py +44 -17
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +72 -30
  7. cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
  8. cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
  9. cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
  10. cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
  11. cache_dit/cache_factory/cache_interface.py +102 -28
  12. cache_dit/cache_factory/forward_pattern.py +14 -14
  13. cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
  14. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
  15. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
  16. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
  17. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -49
  18. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  19. cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
  20. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  21. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  22. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
  23. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  24. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
  25. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  26. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
  27. cache_dit/parallelism/parallel_backend.py +2 -0
  28. cache_dit/parallelism/parallel_config.py +10 -3
  29. cache_dit/parallelism/parallel_interface.py +14 -5
  30. cache_dit/quantize/backends/__init__.py +1 -0
  31. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  32. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  33. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
  34. cache_dit/quantize/quantize_backend.py +0 -0
  35. cache_dit/quantize/quantize_config.py +0 -0
  36. cache_dit/quantize/quantize_interface.py +3 -16
  37. cache_dit/utils.py +56 -20
  38. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/METADATA +24 -13
  39. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
  40. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  41. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  42. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
  43. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
  44. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
  45. {cache_dit-1.0.8.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
--- a/cache_dit/cache_factory/forward_pattern.py
+++ b/cache_dit/cache_factory/forward_pattern.py
@@ -20,33 +20,33 @@ class ForwardPattern(Enum):
 
     Pattern_0 = (
         True, # Return_H_First
-        False, # Return_H_Only
-        False, # Forward_H_only
+        False, # Return_H_Only
+        False, # Forward_H_only
         ("hidden_states", "encoder_hidden_states"), # In
         ("hidden_states", "encoder_hidden_states"), # Out
         True, # Supported
     )
 
     Pattern_1 = (
-        False, # Return_H_First
-        False, # Return_H_Only
-        False, # Forward_H_only
+        False, # Return_H_First
+        False, # Return_H_Only
+        False, # Forward_H_only
         ("hidden_states", "encoder_hidden_states"), # In
         ("encoder_hidden_states", "hidden_states"), # Out
         True, # Supported
     )
 
     Pattern_2 = (
-        False, # Return_H_First
+        False, # Return_H_First
         True, # Return_H_Only
-        False, # Forward_H_only
+        False, # Forward_H_only
         ("hidden_states", "encoder_hidden_states"), # In
-        ("hidden_states",), # Out
+        ("hidden_states",), # Out
         True, # Supported
     )
 
     Pattern_3 = (
-        False, # Return_H_First
+        False, # Return_H_First
         True, # Return_H_Only
         True, # Forward_H_only
         ("hidden_states",), # In
@@ -56,18 +56,18 @@ class ForwardPattern(Enum):
 
     Pattern_4 = (
         True, # Return_H_First
-        False, # Return_H_Only
+        False, # Return_H_Only
         True, # Forward_H_only
-        ("hidden_states",), # In
+        ("hidden_states",), # In
         ("hidden_states", "encoder_hidden_states"), # Out
         True, # Supported
     )
 
     Pattern_5 = (
-        False, # Return_H_First
-        False, # Return_H_Only
+        False, # Return_H_First
+        False, # Return_H_Only
         True, # Forward_H_only
-        ("hidden_states",), # In
+        ("hidden_states",), # In
         ("encoder_hidden_states", "hidden_states"), # Out
         True, # Supported
     )
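
The hunks above only touch comment formatting in ForwardPattern, but they document the layout of each pattern tuple. As a hedged sketch (not part of the package), this is how the six annotated fields line up when a pattern is unpacked, assuming the enum members are the plain 6-tuples shown in the diff:

```python
# Hedged sketch: unpack one ForwardPattern value, assuming the enum
# stores the plain 6-tuples shown in the hunk above.
from cache_dit.cache_factory.forward_pattern import ForwardPattern

(
    return_h_first,  # Return_H_First
    return_h_only,   # Return_H_Only
    forward_h_only,  # Forward_H_only
    in_names,        # In, e.g. ("hidden_states", "encoder_hidden_states")
    out_names,       # Out, e.g. ("hidden_states", "encoder_hidden_states")
    supported,       # Supported
) = ForwardPattern.Pattern_0.value

print(return_h_first, in_names, out_names, supported)
```
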
--- a/cache_dit/parallelism/backends/native_diffusers/__init__.py
+++ b/cache_dit/parallelism/backends/native_diffusers/__init__.py
@@ -1,6 +1,3 @@
 from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
     maybe_enable_parallelism,
 )
-from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
-    native_diffusers_parallelism_available,
-)
--- /dev/null
+++ b/cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py
@@ -0,0 +1,95 @@
+import torch
+from typing import Optional
+
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.logger import init_logger
+from ..utils import (
+    native_diffusers_parallelism_available,
+    ContextParallelConfig,
+)
+from .cp_planners import *
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_context_parallelism(
+    transformer: torch.nn.Module,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, ModelMixin), (
+        "transformer must be an instance of diffusers' ModelMixin, "
+        f"but got {type(transformer)}"
+    )
+    if parallelism_config is None:
+        return transformer
+
+    assert isinstance(parallelism_config, ParallelismConfig), (
+        "parallelism_config must be an instance of ParallelismConfig"
+        f" but got {type(parallelism_config)}"
+    )
+
+    if (
+        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
+        and native_diffusers_parallelism_available()
+    ):
+        cp_config = None
+        if (
+            parallelism_config.ulysses_size is not None
+            or parallelism_config.ring_size is not None
+        ):
+            cp_config = ContextParallelConfig(
+                ulysses_degree=parallelism_config.ulysses_size,
+                ring_degree=parallelism_config.ring_size,
+            )
+        if cp_config is not None:
+            attention_backend = parallelism_config.parallel_kwargs.get(
+                "attention_backend", None
+            )
+            if hasattr(transformer, "enable_parallelism"):
+                if hasattr(transformer, "set_attention_backend"):
+                    # _native_cudnn, flash, etc.
+                    if attention_backend is None:
+                        # Now only _native_cudnn is supported for parallelism
+                        # issue: https://github.com/huggingface/diffusers/pull/12443
+                        transformer.set_attention_backend("_native_cudnn")
+                        logger.warning(
+                            "attention_backend is None, set default attention backend "
+                            "to _native_cudnn for parallelism because of the issue: "
+                            "https://github.com/huggingface/diffusers/pull/12443"
+                        )
+                    else:
+                        transformer.set_attention_backend(attention_backend)
+                        logger.info(
+                            "Found attention_backend from config, set attention "
+                            f"backend to: {attention_backend}"
+                        )
+                # Prefer custom cp_plan if provided
+                cp_plan = parallelism_config.parallel_kwargs.get(
+                    "cp_plan", None
+                )
+                if cp_plan is not None:
+                    logger.info(
+                        f"Using custom context parallelism plan: {cp_plan}"
+                    )
+                else:
+                    # Try get context parallelism plan from register if not provided
+                    extra_parallel_kwargs = {}
+                    if parallelism_config.parallel_kwargs is not None:
+                        extra_parallel_kwargs = (
+                            parallelism_config.parallel_kwargs
+                        )
+                    cp_plan = ContextParallelismPlannerRegister.get_planner(
+                        transformer
+                    )().apply(transformer=transformer, **extra_parallel_kwargs)
+
+                transformer.enable_parallelism(
+                    config=cp_config, cp_plan=cp_plan
+                )
+            else:
+                raise ValueError(
+                    f"{transformer.__class__.__name__} does not support context parallelism."
+                )
+
+    return transformer
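
As a usage sketch of the new entry point: a hedged example of wiring it up for a Flux transformer launched under torchrun. Only `maybe_enable_context_parallelism` itself comes from the diff; the `ParallelismConfig` constructor keywords are an assumption inferred from the attributes this function reads (`backend`, `ulysses_size`, `ring_size`, `parallel_kwargs`).

```python
# Hypothetical usage sketch; the ParallelismConfig keyword arguments are
# assumptions based on the attributes read by the function above.
import torch.distributed as dist
from diffusers import FluxPipeline

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_diffusers.context_parallelism import (
    maybe_enable_context_parallelism,
)

dist.init_process_group()  # e.g. launched via: torchrun --nproc_per_node=2 run.py
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    ulysses_size=dist.get_world_size(),  # Ulysses-style sequence parallelism
    parallel_kwargs={"attention_backend": "_native_cudnn"},
)

# Returns the same module with diffusers' context-parallel hooks attached;
# the cp_plan is looked up via ContextParallelismPlannerRegister.
pipe.transformer = maybe_enable_context_parallelism(pipe.transformer, config)
```
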
--- /dev/null
+++ b/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py
@@ -0,0 +1,74 @@
+import torch
+import logging
+from abc import abstractmethod
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+    "ContextParallelismPlanner",
+    "ContextParallelismPlannerRegister",
+]
+
+
+class ContextParallelismPlanner:
+    @abstractmethod
+    def apply(
+        self,
+        # NOTE: Keep this kwarg for future extensions
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        # NOTE: This method should only return the CP plan dictionary.
+        raise NotImplementedError(
+            "apply method must be implemented by subclasses"
+        )
+
+
+class ContextParallelismPlannerRegister:
+    _cp_planner_registry: dict[str, ContextParallelismPlanner] = {}
+
+    @classmethod
+    def register(cls, name: str):
+        def decorator(planner_cls: type[ContextParallelismPlanner]):
+            assert (
+                name not in cls._cp_planner_registry
+            ), f"ContextParallelismPlanner with name {name} is already registered."
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(f"Registering ContextParallelismPlanner: {name}")
+            cls._cp_planner_registry[name] = planner_cls
+            return planner_cls
+
+        return decorator
+
+    @classmethod
+    def get_planner(
+        cls, transformer: str | torch.nn.Module | ModelMixin
+    ) -> type[ContextParallelismPlanner]:
+        if isinstance(transformer, (torch.nn.Module, ModelMixin)):
+            name = transformer.__class__.__name__
+        else:
+            name = transformer
+        planner_cls = None
+        for planner_name in cls._cp_planner_registry:
+            if name.startswith(planner_name):
+                planner_cls = cls._cp_planner_registry.get(planner_name)
+                break
+        if planner_cls is None:
+            raise ValueError(f"No planner registered under name: {name}")
+        return planner_cls
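
Since this file only defines the registry plumbing, here is a short hedged sketch of how a planner for a hypothetical `MyDiTTransformer2DModel` might be registered and resolved. The model name and plan contents are made up; the `ContextParallelInput`/`ContextParallelOutput` fields mirror the planners added in cp_planners.py below.

```python
# Hypothetical registration sketch; "MyDiT" and its plan are illustrative only.
from diffusers.models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelOutput,
)

from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)


@ContextParallelismPlannerRegister.register("MyDiT")
class MyDiTContextParallelismPlanner(ContextParallelismPlanner):
    def apply(self, transformer=None, **kwargs):
        # Shard the sequence dim of hidden_states on entry, gather after proj_out.
        return {
            "": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }


# Prefix matching: any class whose name starts with "MyDiT" resolves here.
planner_cls = ContextParallelismPlannerRegister.get_planner("MyDiTTransformer2DModel")
cp_plan = planner_cls().apply()
```
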
--- /dev/null
+++ b/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py
@@ -0,0 +1,254 @@
+# Docstring references: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/_modeling_parallel.py#L185
+# A dictionary where keys denote the input to be split across context parallel region, and the
+# value denotes the sharding configuration.
+# If the key is a string, it denotes the name of the parameter in the forward function.
+# If the key is an integer, split_output must be set to True, and it denotes the index of the output
+# to be split across context parallel region.
+# ContextParallelInputType = Dict[
+#     Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+# ]
+
+# A dictionary where keys denote the output to be gathered across context parallel region, and the
+# value denotes the gathering configuration.
+# ContextParallelOutputType = Union[
+#     ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+# ]
+
+# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+# the module should be split/gathered across context parallel region.
+# ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+#     "": {
+#         "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+#         "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+#     },
+#     "pos_embed": {
+#         0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#         1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+#     },
+#     "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+# split/gathered according to this at the respective module level. Here, the following happens:
+# - "":
+#     we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+#     the actual forward logic of the QwenImageTransformer2DModel is run, we will splitthe inputs)
+# - "pos_embed":
+#     we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs),
+#     we can individually specify how they should be split
+# - "proj_out":
+#     before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+#     layer forward has run).
+#
+# ContextParallelInput:
+#     specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+#
+# ContextParallelOutput:
+#     specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
+
+import torch
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+    from diffusers.models._modeling_parallel import (
+        ContextParallelInput,
+        ContextParallelOutput,
+        ContextParallelModelPlan,
+    )
+except ImportError:
+    raise ImportError(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from .cp_plan_registers import (
+    ContextParallelismPlanner,
+    ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+    "ContextParallelismPlanner",
+    "ContextParallelismPlannerRegister",
+    "FluxContextParallelismPlanner",
+    "QwenImageContextParallelismPlanner",
+    "WanContextParallelismPlanner",
+    "LTXVideoContextParallelismPlanner",
+]
+
+
+# Register context parallelism planner for models
+@ContextParallelismPlannerRegister.register("Flux")
+class FluxContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if transformer is not None:
+            from diffusers import FluxTransformer2DModel
+
+            assert isinstance(
+                transformer, FluxTransformer2DModel
+            ), "Transformer must be an instance of FluxTransformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                return transformer._cp_plan
+
+        _cp_plan = {
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "img_ids": ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=False
+                ),
+                "txt_ids": ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=False
+                ),
+            },
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@ContextParallelismPlannerRegister.register("QwenImage")
+class QwenImageContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if transformer is not None:
+            from diffusers import QwenImageTransformer2DModel
+
+            assert isinstance(
+                transformer, QwenImageTransformer2DModel
+            ), "Transformer must be an instance of QwenImageTransformer2DModel"
+            if hasattr(transformer, "_cp_plan"):
+                return transformer._cp_plan
+
+        _cp_plan = _cp_plan = {
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states_mask": ContextParallelInput(
+                    split_dim=1, expected_dims=2, split_output=False
+                ),
+            },
+            "pos_embed": {
+                0: ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=0, expected_dims=2, split_output=True
+                ),
+            },
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+# TODO: Add WanVACETransformer3DModel context parallelism planner.
+# NOTE: We choice to use full name to avoid name conflict between
+# WanTransformer3DModel and WanVACETransformer3DModel.
+@ContextParallelismPlannerRegister.register("WanTransformer3D")
+class WanContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if transformer is not None:
+            from diffusers import WanTransformer3DModel
+
+            assert isinstance(
+                transformer, WanTransformer3DModel
+            ), "Transformer must be an instance of WanTransformer3DModel"
+            if hasattr(transformer, "_cp_plan"):
+                return transformer._cp_plan
+
+        _cp_plan = {
+            "rope": {
+                0: ContextParallelInput(
+                    split_dim=1, expected_dims=4, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=1, expected_dims=4, split_output=True
+                ),
+            },
+            "blocks.0": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            "blocks.*": {
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+            },
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
+
+
+@ContextParallelismPlannerRegister.register("LTXVideo")
+class LTXVideoContextParallelismPlanner(ContextParallelismPlanner):
+    def apply(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+        **kwargs,
+    ) -> ContextParallelModelPlan:
+        if transformer is not None:
+            from diffusers import LTXVideoTransformer3DModel
+
+            assert isinstance(
+                transformer, LTXVideoTransformer3DModel
+            ), "Transformer must be an instance of LTXVideoTransformer3DModel"
+            if hasattr(transformer, "_cp_plan"):
+                return transformer._cp_plan
+
+        _cp_plan = {
+            "": {
+                "hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_hidden_states": ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=False
+                ),
+                "encoder_attention_mask": ContextParallelInput(
+                    split_dim=1, expected_dims=2, split_output=False
+                ),
+            },
+            "rope": {
+                0: ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=True
+                ),
+                1: ContextParallelInput(
+                    split_dim=1, expected_dims=3, split_output=True
+                ),
+            },
+            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+        }
+        return _cp_plan
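
These planners register under class-name prefixes, and `get_planner` (in cp_plan_registers.py above) falls back to `startswith` matching, which is why Wan uses the full `WanTransformer3D` prefix. A small illustration of how lookups resolve, assuming the module has been imported so the decorators have run:

```python
# Resolution sketch: importing cp_planners runs the @register decorators,
# after which lookups match by class-name prefix.
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_planners import (
    ContextParallelismPlannerRegister,
)

# "FluxTransformer2DModel" starts with the registered prefix "Flux".
flux_planner = ContextParallelismPlannerRegister.get_planner("FluxTransformer2DModel")
cp_plan = flux_planner().apply()  # plan dict with "" and "proj_out" entries

# "WanVACETransformer3DModel" does not start with "WanTransformer3D", so this
# raises ValueError until a dedicated planner is added (see the TODO above).
ContextParallelismPlannerRegister.get_planner("WanVACETransformer3DModel")
```
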
--- a/cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py
+++ b/cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py
@@ -6,22 +6,10 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-try:
-    from diffusers import ContextParallelConfig
-
-    def native_diffusers_parallelism_available() -> bool:
-        return True
-
-except ImportError:
-    ContextParallelConfig = None
-
-    def native_diffusers_parallelism_available() -> bool:
-        return False
-
-
 from diffusers.models.modeling_utils import ModelMixin
 from cache_dit.parallelism.parallel_backend import ParallelismBackend
 from cache_dit.parallelism.parallel_config import ParallelismConfig
+from .context_parallelism import maybe_enable_context_parallelism
 
 
 def maybe_enable_parallelism(
@@ -40,42 +28,22 @@ def maybe_enable_parallelism(
         f" but got {type(parallelism_config)}"
     )
 
+    assert parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER, (
+        f"parallelism backend must be {ParallelismBackend.NATIVE_DIFFUSER}, "
+        f"but got {parallelism_config.backend}"
+    )
+
     if (
-        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
-        and native_diffusers_parallelism_available()
+        parallelism_config.ulysses_size is not None
+        or parallelism_config.ring_size is not None
     ):
-        cp_config = None
-        if (
-            parallelism_config.ulysses_size is not None
-            or parallelism_config.ring_size is not None
-        ):
-            cp_config = ContextParallelConfig(
-                ulysses_degree=parallelism_config.ulysses_size,
-                ring_degree=parallelism_config.ring_size,
-            )
-        if cp_config is not None:
-            if hasattr(transformer, "enable_parallelism"):
-                if hasattr(transformer, "set_attention_backend"):
-                    # Now only _native_cudnn is supported for parallelism
-                    # issue: https://github.com/huggingface/diffusers/pull/12443
-                    transformer.set_attention_backend("_native_cudnn")
-                    logger.warning(
-                        "Set attention backend to _native_cudnn for parallelism because of "
-                        "the issue: https://github.com/huggingface/diffusers/pull/12443"
-                    )
-                cp_plan = parallelism_config.parallel_kwargs.get(
-                    "cp_plan", None
-                )
-                if cp_plan is not None:
-                    logger.info(
-                        f"Using custom context parallelism plan: {cp_plan}"
-                    )
-                transformer.enable_parallelism(
-                    config=cp_config, cp_plan=cp_plan
-                )
-            else:
-                raise ValueError(
-                    f"{transformer.__class__.__name__} does not support context parallelism."
-                )
-
+        transformer = maybe_enable_context_parallelism(
+            transformer,
+            parallelism_config,
+        )
+    else:
+        raise ValueError(
+            "NATIVE_DIFFUSER backend only support context parallelism now. "
+            "Please set ulysses_size or ring_size in parallelism_config."
+        )
     return transformer
--- /dev/null
+++ b/cache_dit/parallelism/backends/native_diffusers/utils.py
@@ -0,0 +1,11 @@
+try:
+    from diffusers import ContextParallelConfig
+
+    def native_diffusers_parallelism_available() -> bool:
+        return True
+
+except ImportError:
+    ContextParallelConfig = None
+
+    def native_diffusers_parallelism_available() -> bool:
+        return False
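
The availability probe moves out of parallel_difffusers.py into this utils module so the context-parallel code can import it directly. A minimal guard sketch using the two names it exports (the `ulysses_degree` keyword is the one passed by the caller shown earlier; treat the error message as illustrative):

```python
# Minimal guard sketch: probe for diffusers' ContextParallelConfig before use.
from cache_dit.parallelism.backends.native_diffusers.utils import (
    ContextParallelConfig,
    native_diffusers_parallelism_available,
)

if native_diffusers_parallelism_available():
    cp_config = ContextParallelConfig(ulysses_degree=2)
else:
    raise RuntimeError(
        "Context parallelism needs a diffusers build that ships "
        "ContextParallelConfig (>= 0.36.dev0); install diffusers from source."
    )
```
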
--- /dev/null
+++ b/cache_dit/parallelism/backends/native_pytorch/__init__.py
@@ -0,0 +1,3 @@
+from cache_dit.parallelism.backends.native_pytorch.parallel_torch import (
+    maybe_enable_parallelism,
+)
--- /dev/null
+++ b/cache_dit/parallelism/backends/native_pytorch/parallel_torch.py
@@ -0,0 +1,62 @@
+from typing import Optional
+
+import torch
+
+from diffusers.models.modeling_utils import ModelMixin
+
+from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_parallelism(
+    transformer: torch.nn.Module | ModelMixin,
+    parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+    assert isinstance(transformer, torch.nn.Module), (
+        "transformer must be an instance of torch.nn.Module, "
+        f"but got {type(transformer)}"
+    )
+    assert isinstance(transformer, ModelMixin), (
+        "transformer must be an instance of diffusers' ModelMixin, "
+        f"but got {type(transformer)}"
+    )
+    if parallelism_config is None:
+        return transformer
+
+    assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
+        "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
+        f"but got {parallelism_config.backend}"
+    )
+
+    assert isinstance(parallelism_config, ParallelismConfig), (
+        "parallelism_config must be an instance of ParallelismConfig"
+        f" but got {type(parallelism_config)}"
+    )
+    assert (
+        parallelism_config.ulysses_size is None
+        and parallelism_config.ring_size is None
+    ), (
+        "Ulysses/Ring parallelism is not supported in Native_PyTorch backend. "
+        "Please set it to None in parallelism_config."
+    )
+
+    if (
+        parallelism_config.tp_size is not None
+        and parallelism_config.tp_size > 1
+    ):
+        from .tensor_parallelism import maybe_enable_tensor_parallelism
+
+        transformer = maybe_enable_tensor_parallelism(
+            transformer=transformer,
+            parallelism_config=parallelism_config,
+        )
+    else:
+        raise ValueError(
+            "NATIVE_PYTORCH only supported tensor parallelism now. "
+            "Please set tp_size > 1 for tensor parallelism."
+        )
+    return transformer
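
To close, a hedged sketch of the new NATIVE_PYTORCH path: tensor parallelism for a Flux transformer (tp_plan_flux.py is one of the added TP plans in this release). The `ParallelismConfig` keyword arguments are assumptions based on the attributes the function checks (`backend`, `tp_size`, `ulysses_size`, `ring_size`); only `maybe_enable_parallelism` itself is taken from the diff.

```python
# Hypothetical NATIVE_PYTORCH tensor-parallel sketch; ParallelismConfig
# keywords are assumptions inferred from the assertions above.
import torch.distributed as dist
from diffusers import FluxTransformer2DModel

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_pytorch import maybe_enable_parallelism

dist.init_process_group()  # e.g. torchrun --nproc_per_node=2 run.py
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)

config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_PYTORCH,
    tp_size=dist.get_world_size(),  # shard weights across ranks via the TP plans
)

# Dispatches to maybe_enable_tensor_parallelism for tp_size > 1; setting
# ulysses_size or ring_size would trip the assertions above.
transformer = maybe_enable_parallelism(transformer, config)
```
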