cache-dit 1.0.9__py3-none-any.whl → 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.
Files changed (45)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/block_adapters/__init__.py +37 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +51 -3
  5. cache_dit/cache_factory/block_adapters/block_registers.py +41 -14
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +68 -30
  7. cache_dit/cache_factory/cache_contexts/cache_config.py +5 -3
  8. cache_dit/cache_factory/cache_contexts/cache_manager.py +125 -4
  9. cache_dit/cache_factory/cache_contexts/context_manager.py +9 -2
  10. cache_dit/cache_factory/cache_contexts/prune_manager.py +15 -2
  11. cache_dit/cache_factory/cache_interface.py +29 -3
  12. cache_dit/cache_factory/forward_pattern.py +14 -14
  13. cache_dit/parallelism/backends/native_diffusers/__init__.py +0 -3
  14. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +95 -0
  15. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +74 -0
  16. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +254 -0
  17. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +17 -61
  18. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  19. cache_dit/parallelism/backends/native_pytorch/__init__.py +3 -0
  20. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  21. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  22. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +159 -0
  23. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  24. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +58 -0
  25. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  26. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +12 -0
  27. cache_dit/parallelism/parallel_backend.py +2 -0
  28. cache_dit/parallelism/parallel_config.py +8 -1
  29. cache_dit/parallelism/parallel_interface.py +9 -4
  30. cache_dit/quantize/backends/__init__.py +1 -0
  31. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  32. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  33. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +28 -9
  34. cache_dit/quantize/quantize_backend.py +0 -0
  35. cache_dit/quantize/quantize_config.py +0 -0
  36. cache_dit/quantize/quantize_interface.py +3 -16
  37. cache_dit/utils.py +22 -2
  38. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/METADATA +22 -13
  39. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/RECORD +45 -29
  40. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  41. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  42. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/WHEEL +0 -0
  43. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/entry_points.txt +0 -0
  44. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/licenses/LICENSE +0 -0
  45. {cache_dit-1.0.9.dist-info → cache_dit-1.0.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,254 @@
+ # Docstring references: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/_modeling_parallel.py#L185
+ # A dictionary where keys denote the input to be split across context parallel region, and the
+ # value denotes the sharding configuration.
+ # If the key is a string, it denotes the name of the parameter in the forward function.
+ # If the key is an integer, split_output must be set to True, and it denotes the index of the output
+ # to be split across context parallel region.
+ # ContextParallelInputType = Dict[
+ #     Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+ # ]
+
+ # A dictionary where keys denote the output to be gathered across context parallel region, and the
+ # value denotes the gathering configuration.
+ # ContextParallelOutputType = Union[
+ #     ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+ # ]
+
+ # A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+ # the module should be split/gathered across context parallel region.
+ # ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+ # Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+ #
+ # Each model should define a _cp_plan attribute that contains information on how to shard/gather
+ # tensors at different stages of the forward:
+ #
+ # ```python
+ # _cp_plan = {
+ #     "": {
+ #         "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ #         "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ #         "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ #     },
+ #     "pos_embed": {
+ #         0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ #         1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ #     },
+ #     "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ # }
+ # ```
+ #
+ # The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+ # split/gathered according to this at the respective module level. Here, the following happens:
+ # - "":
+ #     we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+ #     the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
+ # - "pos_embed":
+ #     we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
+ #     we can individually specify how they should be split
+ # - "proj_out":
+ #     before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+ #     layer forward has run).
+ #
+ # ContextParallelInput:
+ #     specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+ #
+ # ContextParallelOutput:
+ #     specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
+
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+
+ try:
+     from diffusers.models._modeling_parallel import (
+         ContextParallelInput,
+         ContextParallelOutput,
+         ContextParallelModelPlan,
+     )
+ except ImportError:
+     raise ImportError(
+         "Context parallelism requires 'diffusers>=0.36.dev0'. "
+         "Please install the latest version of diffusers from source: \n"
+         "pip3 install git+https://github.com/huggingface/diffusers.git"
+     )
+ from .cp_plan_registers import (
+     ContextParallelismPlanner,
+     ContextParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ __all__ = [
+     "ContextParallelismPlanner",
+     "ContextParallelismPlannerRegister",
+     "FluxContextParallelismPlanner",
+     "QwenImageContextParallelismPlanner",
+     "WanContextParallelismPlanner",
+     "LTXVideoContextParallelismPlanner",
+ ]
+
+
+ # Register context parallelism planners for supported models
+ @ContextParallelismPlannerRegister.register("Flux")
+ class FluxContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         if transformer is not None:
+             from diffusers import FluxTransformer2DModel
+
+             assert isinstance(
+                 transformer, FluxTransformer2DModel
+             ), "Transformer must be an instance of FluxTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 return transformer._cp_plan
+
+         _cp_plan = {
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "img_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+                 "txt_ids": ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=False
+                 ),
+             },
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @ContextParallelismPlannerRegister.register("QwenImage")
+ class QwenImageContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         if transformer is not None:
+             from diffusers import QwenImageTransformer2DModel
+
+             assert isinstance(
+                 transformer, QwenImageTransformer2DModel
+             ), "Transformer must be an instance of QwenImageTransformer2DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 return transformer._cp_plan
+
+         _cp_plan = {
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states_mask": ContextParallelInput(
+                     split_dim=1, expected_dims=2, split_output=False
+                 ),
+             },
+             "pos_embed": {
+                 0: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=0, expected_dims=2, split_output=True
+                 ),
+             },
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ # TODO: Add WanVACETransformer3DModel context parallelism planner.
+ # NOTE: We choose to use the full name to avoid a name conflict between
+ # WanTransformer3DModel and WanVACETransformer3DModel.
+ @ContextParallelismPlannerRegister.register("WanTransformer3D")
+ class WanContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         if transformer is not None:
+             from diffusers import WanTransformer3DModel
+
+             assert isinstance(
+                 transformer, WanTransformer3DModel
+             ), "Transformer must be an instance of WanTransformer3DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 return transformer._cp_plan
+
+         _cp_plan = {
+             "rope": {
+                 0: ContextParallelInput(
+                     split_dim=1, expected_dims=4, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=1, expected_dims=4, split_output=True
+                 ),
+             },
+             "blocks.0": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             "blocks.*": {
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+             },
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
+
+
+ @ContextParallelismPlannerRegister.register("LTXVideo")
+ class LTXVideoContextParallelismPlanner(ContextParallelismPlanner):
+     def apply(
+         self,
+         transformer: Optional[torch.nn.Module | ModelMixin] = None,
+         **kwargs,
+     ) -> ContextParallelModelPlan:
+         if transformer is not None:
+             from diffusers import LTXVideoTransformer3DModel
+
+             assert isinstance(
+                 transformer, LTXVideoTransformer3DModel
+             ), "Transformer must be an instance of LTXVideoTransformer3DModel"
+             if hasattr(transformer, "_cp_plan"):
+                 return transformer._cp_plan
+
+         _cp_plan = {
+             "": {
+                 "hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_hidden_states": ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=False
+                 ),
+                 "encoder_attention_mask": ContextParallelInput(
+                     split_dim=1, expected_dims=2, split_output=False
+                 ),
+             },
+             "rope": {
+                 0: ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=True
+                 ),
+                 1: ContextParallelInput(
+                     split_dim=1, expected_dims=3, split_output=True
+                 ),
+             },
+             "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+         }
+         return _cp_plan
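
For models not covered by the planners above, the same registration pattern applies: subclass `ContextParallelismPlanner`, decorate it with `ContextParallelismPlannerRegister.register(<name>)` (the name appears to be matched against the transformer class name, as the WanTransformer3D note above suggests), and return a `ContextParallelModelPlan` from `apply`. A minimal sketch for a hypothetical `MyDiTTransformer2DModel` (the model name is a placeholder, not part of this release; the registry import path follows the file layout listed above):

```python
# Hypothetical example: "MyDiT" is a placeholder, not a model shipped in cache-dit.
import torch
from typing import Optional
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelOutput,
    ContextParallelModelPlan,
)
from cache_dit.parallelism.backends.native_diffusers.context_parallelism.cp_plan_registers import (
    ContextParallelismPlanner,
    ContextParallelismPlannerRegister,
)


@ContextParallelismPlannerRegister.register("MyDiT")
class MyDiTContextParallelismPlanner(ContextParallelismPlanner):
    def apply(
        self,
        transformer: Optional[torch.nn.Module | ModelMixin] = None,
        **kwargs,
    ) -> ContextParallelModelPlan:
        # Prefer the model's own plan if diffusers already defines one.
        if transformer is not None and hasattr(transformer, "_cp_plan"):
            return transformer._cp_plan
        # Split the sequence dim on entry, gather it after the final projection.
        return {
            "": {
                "hidden_states": ContextParallelInput(
                    split_dim=1, expected_dims=3, split_output=False
                ),
            },
            "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
        }
```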
@@ -6,22 +6,10 @@ from cache_dit.logger import init_logger
  logger = init_logger(__name__)
 
 
- try:
-     from diffusers import ContextParallelConfig
-
-     def native_diffusers_parallelism_available() -> bool:
-         return True
-
- except ImportError:
-     ContextParallelConfig = None
-
-     def native_diffusers_parallelism_available() -> bool:
-         return False
-
-
  from diffusers.models.modeling_utils import ModelMixin
  from cache_dit.parallelism.parallel_backend import ParallelismBackend
  from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from .context_parallelism import maybe_enable_context_parallelism
 
 
  def maybe_enable_parallelism(
@@ -40,54 +28,22 @@ def maybe_enable_parallelism(
          f" but got {type(parallelism_config)}"
      )
 
+     assert parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER, (
+         f"parallelism backend must be {ParallelismBackend.NATIVE_DIFFUSER}, "
+         f"but got {parallelism_config.backend}"
+     )
+
      if (
-         parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
-         and native_diffusers_parallelism_available()
+         parallelism_config.ulysses_size is not None
+         or parallelism_config.ring_size is not None
      ):
-         cp_config = None
-         if (
-             parallelism_config.ulysses_size is not None
-             or parallelism_config.ring_size is not None
-         ):
-             cp_config = ContextParallelConfig(
-                 ulysses_degree=parallelism_config.ulysses_size,
-                 ring_degree=parallelism_config.ring_size,
-             )
-         if cp_config is not None:
-             attention_backend = parallelism_config.parallel_kwargs.get(
-                 "attention_backend", None
-             )
-             if hasattr(transformer, "enable_parallelism"):
-                 if hasattr(transformer, "set_attention_backend"):
-                     # _native_cudnn, flash, etc.
-                     if attention_backend is None:
-                         # Now only _native_cudnn is supported for parallelism
-                         # issue: https://github.com/huggingface/diffusers/pull/12443
-                         transformer.set_attention_backend("_native_cudnn")
-                         logger.warning(
-                             "attention_backend is None, set default attention backend "
-                             "to _native_cudnn for parallelism because of the issue: "
-                             "https://github.com/huggingface/diffusers/pull/12443"
-                         )
-                     else:
-                         transformer.set_attention_backend(attention_backend)
-                         logger.info(
-                             "Found attention_backend from config, set attention "
-                             f"backend to: {attention_backend}"
-                         )
-                 cp_plan = parallelism_config.parallel_kwargs.get(
-                     "cp_plan", None
-                 )
-                 if cp_plan is not None:
-                     logger.info(
-                         f"Using custom context parallelism plan: {cp_plan}"
-                     )
-                 transformer.enable_parallelism(
-                     config=cp_config, cp_plan=cp_plan
-                 )
-             else:
-                 raise ValueError(
-                     f"{transformer.__class__.__name__} does not support context parallelism."
-                 )
-
+         transformer = maybe_enable_context_parallelism(
+             transformer,
+             parallelism_config,
+         )
+     else:
+         raise ValueError(
+             "NATIVE_DIFFUSER backend only supports context parallelism now. "
+             "Please set ulysses_size or ring_size in parallelism_config."
+         )
      return transformer
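
After this refactor, the NATIVE_DIFFUSER entry point asserts the backend, checks that `ulysses_size` or `ring_size` is set, and delegates everything else to `maybe_enable_context_parallelism`. A minimal usage sketch, assuming `ParallelismConfig` accepts the fields read above as keyword arguments and the script runs under `torchrun`:

```python
# Sketch only: ParallelismConfig keyword names are assumed from the attributes
# the code above reads (backend, ulysses_size, ring_size, parallel_kwargs).
import torch.distributed as dist
from diffusers import FluxTransformer2DModel

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
    maybe_enable_parallelism,
)

# Launched with e.g. `torchrun --nproc_per_node=2 demo.py`.
dist.init_process_group()

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_DIFFUSER,
    ulysses_size=dist.get_world_size(),  # Ulysses-style sequence parallelism
)
transformer = maybe_enable_parallelism(transformer, config)
```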
@@ -0,0 +1,11 @@
+ try:
+     from diffusers import ContextParallelConfig
+
+     def native_diffusers_parallelism_available() -> bool:
+         return True
+
+ except ImportError:
+     ContextParallelConfig = None
+
+     def native_diffusers_parallelism_available() -> bool:
+         return False
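
The availability probe that previously lived in `parallel_difffusers.py` now has its own `utils.py`, so callers can test for a new enough diffusers build before touching any context-parallel code. A small illustrative check (the surrounding wrapper logic is not part of the package):

```python
from cache_dit.parallelism.backends.native_diffusers.utils import (
    ContextParallelConfig,
    native_diffusers_parallelism_available,
)

if native_diffusers_parallelism_available():
    # diffusers >= 0.36.dev0 ships ContextParallelConfig.
    cp_config = ContextParallelConfig(ulysses_degree=2)
else:
    raise RuntimeError(
        "Context parallelism needs diffusers installed from source: "
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    )
```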
@@ -0,0 +1,3 @@
+ from cache_dit.parallelism.backends.native_pytorch.parallel_torch import (
+     maybe_enable_parallelism,
+ )
@@ -0,0 +1,62 @@
+ from typing import Optional
+
+ import torch
+
+ from diffusers.models.modeling_utils import ModelMixin
+
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ def maybe_enable_parallelism(
+     transformer: torch.nn.Module | ModelMixin,
+     parallelism_config: Optional[ParallelismConfig],
+ ) -> torch.nn.Module:
+     assert isinstance(transformer, torch.nn.Module), (
+         "transformer must be an instance of torch.nn.Module, "
+         f"but got {type(transformer)}"
+     )
+     assert isinstance(transformer, ModelMixin), (
+         "transformer must be an instance of diffusers' ModelMixin, "
+         f"but got {type(transformer)}"
+     )
+     if parallelism_config is None:
+         return transformer
+
+     assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
+         "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
+         f"but got {parallelism_config.backend}"
+     )
+
+     assert isinstance(parallelism_config, ParallelismConfig), (
+         "parallelism_config must be an instance of ParallelismConfig"
+         f" but got {type(parallelism_config)}"
+     )
+     assert (
+         parallelism_config.ulysses_size is None
+         and parallelism_config.ring_size is None
+     ), (
+         "Ulysses/Ring parallelism is not supported in Native_PyTorch backend. "
+         "Please set it to None in parallelism_config."
+     )
+
+     if (
+         parallelism_config.tp_size is not None
+         and parallelism_config.tp_size > 1
+     ):
+         from .tensor_parallelism import maybe_enable_tensor_parallelism
+
+         transformer = maybe_enable_tensor_parallelism(
+             transformer=transformer,
+             parallelism_config=parallelism_config,
+         )
+     else:
+         raise ValueError(
+             "NATIVE_PYTORCH only supports tensor parallelism now. "
+             "Please set tp_size > 1 for tensor parallelism."
+         )
+     return transformer
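
Unlike the NATIVE_DIFFUSER path, this backend rejects `ulysses_size`/`ring_size` and requires `tp_size > 1`. A minimal sketch under those constraints, again assuming the `ParallelismConfig` keyword names match the attributes read in `parallel_torch.py`:

```python
# Sketch only: ParallelismConfig keyword names (backend, tp_size) are assumed.
import torch.distributed as dist
from diffusers import FluxTransformer2DModel

from cache_dit.parallelism.parallel_backend import ParallelismBackend
from cache_dit.parallelism.parallel_config import ParallelismConfig
from cache_dit.parallelism.backends.native_pytorch import maybe_enable_parallelism

dist.init_process_group()  # e.g. `torchrun --nproc_per_node=2 demo.py`

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)
config = ParallelismConfig(
    backend=ParallelismBackend.NATIVE_PYTORCH,
    tp_size=dist.get_world_size(),  # dispatched to the registered Flux TP planner
)
transformer = maybe_enable_parallelism(transformer, config)
```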
@@ -0,0 +1,48 @@
+ try:
+     import einops
+ except ImportError:
+     raise ImportError(
+         "Parallelism functionality requires the 'parallelism' extra dependencies. "
+         "Install with:\npip install cache-dit[parallelism]"
+     )
+
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from cache_dit.logger import init_logger
+ from .tp_planners import *
+
+ logger = init_logger(__name__)
+
+
+ def maybe_enable_tensor_parallelism(
+     transformer: torch.nn.Module | ModelMixin,
+     parallelism_config: Optional[ParallelismConfig],
+ ) -> torch.nn.Module:
+     assert isinstance(transformer, torch.nn.Module), (
+         "transformer must be an instance of torch.nn.Module, "
+         f"but got {type(transformer)}"
+     )
+     assert isinstance(transformer, ModelMixin), (
+         "transformer must be an instance of diffusers' ModelMixin, "
+         f"but got {type(transformer)}"
+     )
+     if parallelism_config is None:
+         return transformer
+
+     assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
+         "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
+         f"but got {parallelism_config.backend}"
+     )
+
+     extra_parallel_kwargs = {}
+     if parallelism_config.parallel_kwargs is not None:
+         extra_parallel_kwargs = parallelism_config.parallel_kwargs
+
+     return TensorParallelismPlannerRegister.get_planner(transformer)().apply(
+         transformer=transformer,
+         parallelism_config=parallelism_config,
+         **extra_parallel_kwargs,
+     )
@@ -0,0 +1,159 @@
+ import torch
+ from diffusers.models.transformers.transformer_flux import (
+     FluxSingleTransformerBlock,
+ )
+ from einops import rearrange
+ from torch import nn
+ from torch.distributed import DeviceMesh, init_device_mesh
+ from torch.distributed._tensor import Replicate
+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     RowwiseParallel,
+     parallelize_module,
+ )
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from .tp_plan_registers import (
+     TensorParallelismPlanner,
+     TensorParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @TensorParallelismPlannerRegister.register("Flux")
+ class FluxTensorParallelismPlanner(TensorParallelismPlanner):
+     def apply(
+         self,
+         transformer: torch.nn.Module,
+         parallelism_config: ParallelismConfig,
+         **kwargs,
+     ) -> torch.nn.Module:
+         assert (
+             parallelism_config.tp_size is not None
+             and parallelism_config.tp_size > 1
+         ), (
+             "parallel_config.tp_size must be set and greater than 1 for "
+             "tensor parallelism"
+         )
+
+         device_type = torch.accelerator.current_accelerator().type
+         tp_mesh: DeviceMesh = init_device_mesh(
+             device_type=device_type,
+             mesh_shape=[parallelism_config.tp_size],
+         )
+
+         transformer = self.parallelize_transformer(
+             transformer=transformer,
+             tp_mesh=tp_mesh,
+         )
+         # TODO: Parallelize t5 text encoder via `apply_extra`
+         # abstract method and `extra_parallel_kwargs` ?
+
+         return transformer
+
+     def parallelize_t5(
+         self,
+         text_encoder: nn.Module,
+         tp_mesh: DeviceMesh,
+     ):
+         for i, block in enumerate(text_encoder.encoder.block):
+             block.layer[0].SelfAttention.n_heads //= tp_mesh.size()
+             block.layer[0].SelfAttention.inner_dim //= tp_mesh.size()
+             layer_plan = {
+                 "layer.0.SelfAttention.q": ColwiseParallel(),
+                 "layer.0.SelfAttention.k": ColwiseParallel(),
+                 "layer.0.SelfAttention.v": ColwiseParallel(),
+                 "layer.0.SelfAttention.o": RowwiseParallel(),
+                 "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
+                 "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
+                 "layer.1.DenseReluDense.wo": RowwiseParallel(),
+             }
+             if i == 0:
+                 layer_plan["layer.0.SelfAttention.relative_attention_bias"] = (
+                     ColwiseParallel()
+                 )
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+
+         return text_encoder
+
+     def parallelize_transformer(
+         self,
+         transformer: nn.Module,
+         tp_mesh: DeviceMesh,
+     ):
+         for _, block in transformer.transformer_blocks.named_children():
+             block.attn.heads //= tp_mesh.size()
+             layer_plan = {
+                 "attn.to_q": ColwiseParallel(),
+                 "attn.to_k": ColwiseParallel(),
+                 "attn.to_v": ColwiseParallel(),
+                 "attn.to_out.0": RowwiseParallel(),
+                 "norm1.linear": ColwiseParallel(output_layouts=Replicate()),
+                 "ff.net.0.proj": ColwiseParallel(),
+                 "ff.net.2": RowwiseParallel(),
+                 "attn.add_q_proj": ColwiseParallel(),
+                 "attn.add_k_proj": ColwiseParallel(),
+                 "attn.add_v_proj": ColwiseParallel(),
+                 "attn.to_add_out": RowwiseParallel(),
+                 "norm1_context.linear": ColwiseParallel(
+                     output_layouts=Replicate()
+                 ),
+                 "ff_context.net.0.proj": ColwiseParallel(),
+                 "ff_context.net.2": RowwiseParallel(),
+             }
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+
+         # NOTE: special handling for FluxSingleTransformerBlock, we have to
+         # rearrange the proj_out weight because it contains both out and down
+         # projection weights in a single matrix.
+         def rearrange_proj_out_weight(
+             single_block: FluxSingleTransformerBlock, tp_group_size
+         ):
+             # rowwise
+             hidden_dim = 3072
+             requires_grad = single_block.proj_out.weight.requires_grad
+             linear2_weight_data = (
+                 single_block.proj_out.weight.data.T.detach().clone()
+             )
+             out_weight = linear2_weight_data[:hidden_dim, ...]
+             out_weight = rearrange(
+                 out_weight, "(G D) C -> G D C", G=tp_group_size
+             )
+             down_weight = linear2_weight_data.data[hidden_dim:, ...]
+             down_weight = rearrange(
+                 down_weight, "(G D) C -> G D C", G=tp_group_size
+             )
+             new_linear2_weight = torch.cat([out_weight, down_weight], dim=1)
+             new_linear2_weight = rearrange(
+                 new_linear2_weight, "G D C -> (G D) C"
+             )
+             single_block.proj_out.weight.data.copy_(new_linear2_weight.T)
+             single_block.proj_out.weight.requires_grad_(requires_grad)
+
+         for _, block in transformer.single_transformer_blocks.named_children():
+             rearrange_proj_out_weight(block, tp_mesh.size())
+             block.attn.heads //= tp_mesh.size()
+             layer_plan = {
+                 "attn.to_q": ColwiseParallel(),
+                 "attn.to_k": ColwiseParallel(),
+                 "attn.to_v": ColwiseParallel(),
+                 "proj_mlp": ColwiseParallel(),
+                 "proj_out": RowwiseParallel(),
+                 "norm.linear": ColwiseParallel(output_layouts=Replicate()),
+             }
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+         return transformer
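
The `rearrange_proj_out_weight` helper above is needed because `FluxSingleTransformerBlock.proj_out` consumes the concatenation of the attention output and the MLP hidden states through one fused linear layer: after column-wise sharding, each rank holds its own slice of both tensors, so the row-wise sharded `proj_out` weight must be regrouped so that every rank's contiguous row block pairs its attention slice with its MLP slice. A standalone toy illustration of that regrouping (dimensions are made up; Flux itself uses hidden_dim=3072):

```python
# Toy illustration of the "(G D) C -> G D C" regrouping used above.
import torch
from einops import rearrange

G = 2          # tensor-parallel group size
attn_dim = 4   # stands in for hidden_dim (3072 in Flux)
mlp_dim = 8    # stands in for the proj_mlp output width
out_dim = 3

# Fused weight transposed to (in_features, out_features); rows = [attn | mlp].
w_t = torch.arange((attn_dim + mlp_dim) * out_dim, dtype=torch.float32)
w_t = w_t.reshape(attn_dim + mlp_dim, out_dim)

attn_part = rearrange(w_t[:attn_dim], "(G D) C -> G D C", G=G)
mlp_part = rearrange(w_t[attn_dim:], "(G D) C -> G D C", G=G)

# Interleave per-rank blocks: rank g now owns [attn rows of g | mlp rows of g],
# matching the locally concatenated activations produced by the colwise shards.
regrouped = rearrange(torch.cat([attn_part, mlp_part], dim=1), "G D C -> (G D) C")

rank0_rows = regrouped[: (attn_dim + mlp_dim) // G]
print(rank0_rows.shape)  # torch.Size([6, 3]): 2 attn rows + 4 mlp rows for rank 0
```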