cache_dit-1.0.3-py3-none-any.whl → cache_dit-1.0.14-py3-none-any.whl

Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py (new file)
@@ -0,0 +1,49 @@
+ import torch
+
+ from typing import Optional
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ from diffusers.models.modeling_utils import ModelMixin
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from .context_parallelism import maybe_enable_context_parallelism
+
+
+ def maybe_enable_parallelism(
+     transformer: torch.nn.Module,
+     parallelism_config: Optional[ParallelismConfig],
+ ) -> torch.nn.Module:
+     assert isinstance(transformer, ModelMixin), (
+         "transformer must be an instance of diffusers' ModelMixin, "
+         f"but got {type(transformer)}"
+     )
+     if parallelism_config is None:
+         return transformer
+
+     assert isinstance(parallelism_config, ParallelismConfig), (
+         "parallelism_config must be an instance of ParallelismConfig"
+         f" but got {type(parallelism_config)}"
+     )
+
+     assert parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER, (
+         f"parallelism backend must be {ParallelismBackend.NATIVE_DIFFUSER}, "
+         f"but got {parallelism_config.backend}"
+     )
+
+     if (
+         parallelism_config.ulysses_size is not None
+         or parallelism_config.ring_size is not None
+     ):
+         transformer = maybe_enable_context_parallelism(
+             transformer,
+             parallelism_config,
+         )
+     else:
+         raise ValueError(
+             "The NATIVE_DIFFUSER backend only supports context parallelism for now. "
+             "Please set ulysses_size or ring_size in parallelism_config."
+         )
+     return transformer
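
A minimal usage sketch for this entry point follows. It is illustrative only: the ParallelismConfig constructor arguments are inferred from the attribute names read above (backend, ulysses_size, ring_size) and may not match the released signature, and the pipeline checkpoint is just an example.

    # Hedged sketch; ParallelismConfig fields are assumptions based on the attribute
    # accesses in parallel_difffusers.py. Launch under torchrun so the default
    # process group can initialize.
    import torch.distributed as dist
    from diffusers import FluxPipeline

    from cache_dit.parallelism.parallel_backend import ParallelismBackend
    from cache_dit.parallelism.parallel_config import ParallelismConfig
    from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
        maybe_enable_parallelism,
    )

    dist.init_process_group()  # context parallelism runs across an initialized process group
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

    config = ParallelismConfig(
        backend=ParallelismBackend.NATIVE_DIFFUSER,
        ulysses_size=dist.get_world_size(),  # or ring_size, or a mix of both
    )
    pipe.transformer = maybe_enable_parallelism(pipe.transformer, config)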
cache_dit/parallelism/backends/native_diffusers/utils.py (new file)
@@ -0,0 +1,11 @@
+ try:
+     from diffusers import ContextParallelConfig
+
+     def native_diffusers_parallelism_available() -> bool:
+         return True
+
+ except ImportError:
+     ContextParallelConfig = None
+
+     def native_diffusers_parallelism_available() -> bool:
+         return False
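
This helper gives callers a cheap feature probe for the installed diffusers build. A sketch of the intended guard pattern (the call site below is hypothetical, not something prescribed by this diff):

    from cache_dit.parallelism.backends.native_diffusers.utils import (
        native_diffusers_parallelism_available,
    )

    if native_diffusers_parallelism_available():
        # This diffusers build ships ContextParallelConfig; context parallelism can be enabled.
        ...
    else:
        # Older diffusers: fall back to a single-device path instead of failing at import time.
        ...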
cache_dit/parallelism/backends/native_pytorch/__init__.py (new file)
@@ -0,0 +1,6 @@
+ from cache_dit.parallelism.backends.native_pytorch.tensor_parallelism import (
+     TensorParallelismPlannerRegister,
+ )
+ from cache_dit.parallelism.backends.native_pytorch.parallel_torch import (
+     maybe_enable_parallelism,
+ )
cache_dit/parallelism/backends/native_pytorch/parallel_torch.py (new file)
@@ -0,0 +1,62 @@
+ from typing import Optional
+
+ import torch
+
+ from diffusers.models.modeling_utils import ModelMixin
+
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ def maybe_enable_parallelism(
+     transformer: torch.nn.Module | ModelMixin,
+     parallelism_config: Optional[ParallelismConfig],
+ ) -> torch.nn.Module:
+     assert isinstance(transformer, torch.nn.Module), (
+         "transformer must be an instance of torch.nn.Module, "
+         f"but got {type(transformer)}"
+     )
+     assert isinstance(transformer, ModelMixin), (
+         "transformer must be an instance of diffusers' ModelMixin, "
+         f"but got {type(transformer)}"
+     )
+     if parallelism_config is None:
+         return transformer
+
+     assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
+         "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
+         f"but got {parallelism_config.backend}"
+     )
+
+     assert isinstance(parallelism_config, ParallelismConfig), (
+         "parallelism_config must be an instance of ParallelismConfig"
+         f" but got {type(parallelism_config)}"
+     )
+     assert (
+         parallelism_config.ulysses_size is None
+         and parallelism_config.ring_size is None
+     ), (
+         "Ulysses/Ring parallelism is not supported by the NATIVE_PYTORCH backend. "
+         "Please set it to None in parallelism_config."
+     )
+
+     if (
+         parallelism_config.tp_size is not None
+         and parallelism_config.tp_size > 1
+     ):
+         from .tensor_parallelism import maybe_enable_tensor_parallelism
+
+         transformer = maybe_enable_tensor_parallelism(
+             transformer=transformer,
+             parallelism_config=parallelism_config,
+         )
+     else:
+         raise ValueError(
+             "The NATIVE_PYTORCH backend only supports tensor parallelism for now. "
+             "Please set tp_size > 1 for tensor parallelism."
+         )
+     return transformer
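
A companion sketch for the NATIVE_PYTORCH path, again with ParallelismConfig arguments inferred from the attributes used above (backend, tp_size) rather than taken from released docs; tensor parallelism is expected to run with one process per GPU.

    # Hedged sketch; launch with e.g.: torchrun --nproc-per-node=2 script.py
    import torch.distributed as dist
    from diffusers import FluxPipeline

    from cache_dit.parallelism.parallel_backend import ParallelismBackend
    from cache_dit.parallelism.parallel_config import ParallelismConfig
    from cache_dit.parallelism.backends.native_pytorch.parallel_torch import (
        maybe_enable_parallelism,
    )

    dist.init_process_group()
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

    config = ParallelismConfig(
        backend=ParallelismBackend.NATIVE_PYTORCH,
        tp_size=dist.get_world_size(),  # ulysses_size/ring_size must stay None for this backend
    )
    pipe.transformer = maybe_enable_parallelism(pipe.transformer, config)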
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py (new file)
@@ -0,0 +1,48 @@
+ try:
+     import einops
+ except ImportError:
+     raise ImportError(
+         "Tensor parallelism requires the 'parallelism' extra dependencies. "
+         "Install with:\npip install cache-dit[parallelism]"
+     )
+
+ import torch
+ from typing import Optional
+ from diffusers.models.modeling_utils import ModelMixin
+ from cache_dit.parallelism.parallel_backend import ParallelismBackend
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from cache_dit.logger import init_logger
+ from .tp_planners import *
+
+ logger = init_logger(__name__)
+
+
+ def maybe_enable_tensor_parallelism(
+     transformer: torch.nn.Module | ModelMixin,
+     parallelism_config: Optional[ParallelismConfig],
+ ) -> torch.nn.Module:
+     assert isinstance(transformer, torch.nn.Module), (
+         "transformer must be an instance of torch.nn.Module, "
+         f"but got {type(transformer)}"
+     )
+     assert isinstance(transformer, ModelMixin), (
+         "transformer must be an instance of diffusers' ModelMixin, "
+         f"but got {type(transformer)}"
+     )
+     if parallelism_config is None:
+         return transformer
+
+     assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
+         "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
+         f"but got {parallelism_config.backend}"
+     )
+
+     extra_parallel_kwargs = {}
+     if parallelism_config.parallel_kwargs is not None:
+         extra_parallel_kwargs = parallelism_config.parallel_kwargs
+
+     return TensorParallelismPlannerRegister.get_planner(transformer)().apply(
+         transformer=transformer,
+         parallelism_config=parallelism_config,
+         **extra_parallel_kwargs,
+     )
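
Planner selection in the call above is by class-name prefix (see tp_plan_registers.py further down), so a FluxTransformer2DModel instance resolves to the planner registered under "Flux". Any mapping stored in parallelism_config.parallel_kwargs is splatted into the selected planner's apply(**kwargs), which lets model-specific options reach a planner without changing this dispatcher. A quick check, assuming that importing this package registers the built-in planners via tp_planners:

    from cache_dit.parallelism.backends.native_pytorch.tensor_parallelism import (
        TensorParallelismPlannerRegister,
    )

    planner_cls = TensorParallelismPlannerRegister.get_planner("FluxTransformer2DModel")
    print(planner_cls.__name__)  # expected: FluxTensorParallelismPlanner (registered under "Flux")
    print(TensorParallelismPlannerRegister.supported_planners())  # (count, [registered prefixes])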
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py (new file)
@@ -0,0 +1,171 @@
+ import torch
+ from diffusers.models.transformers.transformer_flux import (
+     FluxSingleTransformerBlock,
+ )
+ from einops import rearrange
+ from torch import nn
+ from torch.distributed import DeviceMesh, init_device_mesh
+ from torch.distributed._tensor import Replicate
+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     RowwiseParallel,
+     parallelize_module,
+ )
+
+ from cache_dit.logger import init_logger
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+ from .tp_plan_registers import (
+     TensorParallelismPlanner,
+     TensorParallelismPlannerRegister,
+ )
+
+ logger = init_logger(__name__)
+
+
+ @TensorParallelismPlannerRegister.register("Chroma")
+ @TensorParallelismPlannerRegister.register("HunyuanImage")
+ @TensorParallelismPlannerRegister.register("HunyuanVideo")
+ @TensorParallelismPlannerRegister.register("Flux")
+ class FluxTensorParallelismPlanner(TensorParallelismPlanner):
+     def apply(
+         self,
+         transformer: torch.nn.Module,
+         parallelism_config: ParallelismConfig,
+         **kwargs,
+     ) -> torch.nn.Module:
+         assert (
+             parallelism_config.tp_size is not None
+             and parallelism_config.tp_size > 1
+         ), (
+             "parallel_config.tp_size must be set and greater than 1 for "
+             "tensor parallelism"
+         )
+
+         device_type = torch.accelerator.current_accelerator().type
+         tp_mesh: DeviceMesh = init_device_mesh(
+             device_type=device_type,
+             mesh_shape=[parallelism_config.tp_size],
+         )
+
+         transformer = self.parallelize_transformer(
+             transformer=transformer,
+             tp_mesh=tp_mesh,
+         )
+         # TODO: Parallelize t5 text encoder via `apply_extra`
+         # abstract method and `extra_parallel_kwargs` ?
+
+         return transformer
+
+     def parallelize_t5(
+         self,
+         text_encoder: nn.Module,
+         tp_mesh: DeviceMesh,
+     ):
+         for i, block in enumerate(text_encoder.encoder.block):
+             block.layer[0].SelfAttention.n_heads //= tp_mesh.size()
+             block.layer[0].SelfAttention.inner_dim //= tp_mesh.size()
+             layer_plan = {
+                 "layer.0.SelfAttention.q": ColwiseParallel(),
+                 "layer.0.SelfAttention.k": ColwiseParallel(),
+                 "layer.0.SelfAttention.v": ColwiseParallel(),
+                 "layer.0.SelfAttention.o": RowwiseParallel(),
+                 "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
+                 "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
+                 "layer.1.DenseReluDense.wo": RowwiseParallel(),
+             }
+             if i == 0:
+                 layer_plan["layer.0.SelfAttention.relative_attention_bias"] = (
+                     ColwiseParallel()
+                 )
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+
+         return text_encoder
+
+     def parallelize_transformer(
+         self,
+         transformer: nn.Module,
+         tp_mesh: DeviceMesh,
+     ):
+         for _, block in transformer.transformer_blocks.named_children():
+             block.attn.heads //= tp_mesh.size()
+             layer_plan = {
+                 "attn.to_q": ColwiseParallel(),
+                 "attn.to_k": ColwiseParallel(),
+                 "attn.to_v": ColwiseParallel(),
+                 "attn.to_out.0": RowwiseParallel(),
+                 "ff.net.0.proj": ColwiseParallel(),
+                 "ff.net.2": RowwiseParallel(),
+                 "attn.add_q_proj": ColwiseParallel(),
+                 "attn.add_k_proj": ColwiseParallel(),
+                 "attn.add_v_proj": ColwiseParallel(),
+                 "attn.to_add_out": RowwiseParallel(),
+                 "ff_context.net.0.proj": ColwiseParallel(),
+                 "ff_context.net.2": RowwiseParallel(),
+             }
+
+             if getattr(block.norm1, "linear", None) is not None:
+                 layer_plan["norm1.linear"] = ColwiseParallel(
+                     output_layouts=Replicate()
+                 )
+             if getattr(block.norm1_context, "linear", None) is not None:
+                 layer_plan["norm1_context.linear"] = ColwiseParallel(
+                     output_layouts=Replicate()
+                 )
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+
+         # NOTE: special handling for FluxSingleTransformerBlock, we have to
+         # rearrange the proj_out weight because it contains both out and down
+         # projection weights in a single matrix.
+         def rearrange_proj_out_weight(
+             single_block: FluxSingleTransformerBlock, tp_group_size
+         ):
+             # rowwise
+             hidden_dim = single_block.attn.to_q.weight.shape[0]
+             requires_grad = single_block.proj_out.weight.requires_grad
+             linear2_weight_data = (
+                 single_block.proj_out.weight.data.T.detach().clone()
+             )
+             out_weight = linear2_weight_data[:hidden_dim, ...]
+             out_weight = rearrange(
+                 out_weight, "(G D) C -> G D C", G=tp_group_size
+             )
+             down_weight = linear2_weight_data.data[hidden_dim:, ...]
+             down_weight = rearrange(
+                 down_weight, "(G D) C -> G D C", G=tp_group_size
+             )
+             new_linear2_weight = torch.cat([out_weight, down_weight], dim=1)
+             new_linear2_weight = rearrange(
+                 new_linear2_weight, "G D C -> (G D) C"
+             )
+             single_block.proj_out.weight.data.copy_(new_linear2_weight.T)
+             single_block.proj_out.weight.requires_grad_(requires_grad)
+
+         for _, block in transformer.single_transformer_blocks.named_children():
+             rearrange_proj_out_weight(block, tp_mesh.size())
+             block.attn.heads //= tp_mesh.size()
+             layer_plan = {
+                 "attn.to_q": ColwiseParallel(),
+                 "attn.to_k": ColwiseParallel(),
+                 "attn.to_v": ColwiseParallel(),
+                 "proj_mlp": ColwiseParallel(),
+                 "proj_out": RowwiseParallel(),
+             }
+             if getattr(block.norm, "linear", None) is not None:
+                 layer_plan["norm.linear"] = ColwiseParallel(
+                     output_layouts=Replicate()
+                 )
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+         return transformer
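
The rearrange_proj_out_weight step above exists because FluxSingleTransformerBlock fuses the attention-out and MLP-down projections into a single proj_out matrix; RowwiseParallel gives each rank one contiguous block of input rows, so the [out | down] halves must be re-interleaved so that every rank's shard covers its slice of both. A self-contained toy illustration of the same permutation (shapes are made up; this snippet is not part of the package):

    import torch
    from einops import rearrange

    tp = 2                                 # tensor-parallel group size
    hidden, mlp, cols = 8, 16, 4           # toy sizes for the two fused row blocks
    out_rows = torch.zeros(hidden, cols)   # rows fed by the attention output
    down_rows = torch.ones(mlp, cols)      # rows fed by the MLP branch

    fused = torch.cat([out_rows, down_rows], dim=0)  # checkpoint layout: [out | down]
    out_g = rearrange(fused[:hidden], "(g d) c -> g d c", g=tp)
    down_g = rearrange(fused[hidden:], "(g d) c -> g d c", g=tp)
    interleaved = rearrange(torch.cat([out_g, down_g], dim=1), "g d c -> (g d) c")

    # Rank 0's contiguous row shard now holds the first slice of BOTH halves,
    # which is what a row-wise sharded proj_out expects.
    rank0 = interleaved[: interleaved.shape[0] // tp]
    assert rank0[: hidden // tp].eq(0).all() and rank0[hidden // tp :].eq(1).all()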
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py (new file)
@@ -0,0 +1,79 @@
+ import torch
+ from torch import nn
+ from torch.distributed import DeviceMesh, init_device_mesh
+ from torch.distributed._tensor import Replicate
+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     RowwiseParallel,
+     parallelize_module,
+ )
+
+ from cache_dit.logger import init_logger
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+
+ from .tp_plan_registers import (
+     TensorParallelismPlanner,
+     TensorParallelismPlannerRegister,
+ )
+
+ logger = init_logger(__name__)
+
+
+ @TensorParallelismPlannerRegister.register("Kandinsky5")
+ class Kandinsky5TensorParallelismPlanner(TensorParallelismPlanner):
+     def apply(
+         self,
+         transformer: torch.nn.Module,
+         parallelism_config: ParallelismConfig,
+         **kwargs,
+     ) -> torch.nn.Module:
+         assert (
+             parallelism_config.tp_size is not None
+             and parallelism_config.tp_size > 1
+         ), (
+             "parallel_config.tp_size must be set and greater than 1 for "
+             "tensor parallelism"
+         )
+
+         device_type = torch.accelerator.current_accelerator().type
+         tp_mesh: DeviceMesh = init_device_mesh(
+             device_type=device_type,
+             mesh_shape=[parallelism_config.tp_size],
+         )
+
+         transformer = self.parallelize_transformer(
+             transformer=transformer,
+             tp_mesh=tp_mesh,
+         )
+
+         return transformer
+
+     def parallelize_transformer(
+         self,
+         transformer: nn.Module,
+         tp_mesh: DeviceMesh,
+     ):
+         for _, block in transformer.visual_transformer_blocks.named_children():
+             block.self_attention.num_heads //= tp_mesh.size()
+             block.cross_attention.num_heads //= tp_mesh.size()
+             layer_plan = {
+                 "self_attention.to_query": ColwiseParallel(),
+                 "self_attention.to_key": ColwiseParallel(),
+                 "self_attention.to_value": ColwiseParallel(),
+                 "self_attention.out_layer": RowwiseParallel(),
+                 "cross_attention.to_query": ColwiseParallel(),
+                 "cross_attention.to_key": ColwiseParallel(),
+                 "cross_attention.to_value": ColwiseParallel(),
+                 "cross_attention.out_layer": RowwiseParallel(),
+                 "visual_modulation.out_layer": ColwiseParallel(
+                     output_layouts=Replicate()
+                 ),
+                 "feed_forward.in_layer": ColwiseParallel(),
+                 "feed_forward.out_layer": RowwiseParallel(),
+             }
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+         return transformer
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py (new file)
@@ -0,0 +1,78 @@
+ import torch
+ from torch import nn
+ from torch.distributed import DeviceMesh, init_device_mesh
+ from torch.distributed._tensor import Replicate
+ from torch.distributed.tensor.parallel import (
+     ColwiseParallel,
+     RowwiseParallel,
+     parallelize_module,
+ )
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from .tp_plan_registers import (
+     TensorParallelismPlanner,
+     TensorParallelismPlannerRegister,
+ )
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @TensorParallelismPlannerRegister.register("QwenImage")
+ class QwenImageTensorParallelismPlanner(TensorParallelismPlanner):
+     def apply(
+         self,
+         transformer: torch.nn.Module,
+         parallelism_config: ParallelismConfig,
+         **kwargs,
+     ) -> torch.nn.Module:
+         assert (
+             parallelism_config.tp_size is not None
+             and parallelism_config.tp_size > 1
+         ), (
+             "parallel_config.tp_size must be set and greater than 1 for "
+             "tensor parallelism"
+         )
+
+         device_type = torch.accelerator.current_accelerator().type
+         tp_mesh: DeviceMesh = init_device_mesh(
+             device_type=device_type,
+             mesh_shape=[parallelism_config.tp_size],
+         )
+
+         transformer = self.parallelize_transformer(
+             transformer=transformer,
+             tp_mesh=tp_mesh,
+         )
+
+         return transformer
+
+     def parallelize_transformer(
+         self,
+         transformer: nn.Module,
+         tp_mesh: DeviceMesh,
+     ):
+         for _, block in transformer.transformer_blocks.named_children():
+             block.attn.heads //= tp_mesh.size()
+             layer_plan = {
+                 "attn.to_q": ColwiseParallel(),
+                 "attn.to_k": ColwiseParallel(),
+                 "attn.to_v": ColwiseParallel(),
+                 "attn.to_out.0": RowwiseParallel(),
+                 "img_mod.1": ColwiseParallel(output_layouts=Replicate()),
+                 "img_mlp.net.0.proj": ColwiseParallel(),
+                 "img_mlp.net.2": RowwiseParallel(),
+                 "attn.add_q_proj": ColwiseParallel(),
+                 "attn.add_k_proj": ColwiseParallel(),
+                 "attn.add_v_proj": ColwiseParallel(),
+                 "attn.to_add_out": RowwiseParallel(),
+                 "txt_mod.1": ColwiseParallel(output_layouts=Replicate()),
+                 "txt_mlp.net.0.proj": ColwiseParallel(),
+                 "txt_mlp.net.2": RowwiseParallel(),
+             }
+             parallelize_module(
+                 module=block,
+                 device_mesh=tp_mesh,
+                 parallelize_plan=layer_plan,
+             )
+         return transformer
cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py (new file)
@@ -0,0 +1,65 @@
+ import torch
+ import logging
+ from abc import abstractmethod
+ from typing import Dict
+ from cache_dit.parallelism.parallel_config import ParallelismConfig
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class TensorParallelismPlanner:
+     # TODO: add `apply_extra` abstract method for extra
+     # parallelism kwargs handling
+
+     @abstractmethod
+     def apply(
+         self,
+         transformer: torch.nn.Module,
+         parallelism_config: ParallelismConfig,
+         **kwargs,
+     ) -> torch.nn.Module:
+         raise NotImplementedError(
+             "apply method must be implemented by subclasses"
+         )
+
+
+ class TensorParallelismPlannerRegister:
+     _tp_planner_registry: Dict[str, TensorParallelismPlanner] = {}
+
+     @classmethod
+     def register(cls, name: str):
+         def decorator(planner_cls: type[TensorParallelismPlanner]):
+             assert (
+                 name not in cls._tp_planner_registry
+             ), f"TensorParallelismPlanner with name {name} is already registered."
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(f"Registering TensorParallelismPlanner: {name}")
+             cls._tp_planner_registry[name] = planner_cls
+             return planner_cls
+
+         return decorator
+
+     @classmethod
+     def get_planner(
+         cls, transformer: str | torch.nn.Module
+     ) -> type[TensorParallelismPlanner]:
+         if isinstance(transformer, torch.nn.Module):
+             name = transformer.__class__.__name__
+         else:
+             name = transformer
+         planner_cls = None
+         for planner_name in cls._tp_planner_registry:
+             if name.startswith(planner_name):
+                 planner_cls = cls._tp_planner_registry.get(planner_name)
+                 break
+         if planner_cls is None:
+             raise ValueError(f"No planner registered under name: {name}")
+         return planner_cls
+
+     @classmethod
+     def supported_planners(
+         cls,
+     ) -> tuple[int, list[str]]:
+         val_planners = cls._tp_planner_registry.keys()
+         return len(val_planners), [p for p in val_planners]
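
Because lookup is prefix-based on the transformer's class name, wiring a new architecture into the NATIVE_PYTORCH backend only requires registering a planner under a matching prefix. A hedged sketch with an invented class name (everything named "MyVideoDiT" below is hypothetical):

    import torch

    from cache_dit.parallelism.parallel_config import ParallelismConfig
    from cache_dit.parallelism.backends.native_pytorch.tensor_parallelism.tp_plan_registers import (
        TensorParallelismPlanner,
        TensorParallelismPlannerRegister,
    )


    # "MyVideoDiT" is a hypothetical prefix; it would match e.g. MyVideoDiTTransformer3DModel.
    @TensorParallelismPlannerRegister.register("MyVideoDiT")
    class MyVideoDiTTensorParallelismPlanner(TensorParallelismPlanner):
        def apply(
            self,
            transformer: torch.nn.Module,
            parallelism_config: ParallelismConfig,
            **kwargs,
        ) -> torch.nn.Module:
            # Build a device mesh and parallelize each block with parallelize_module,
            # following the Flux/QwenImage planners above; returned unchanged here.
            return transformer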