d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/__init__.py ADDED
File without changes
d9d/core/__init__.py ADDED
File without changes
d9d/core/autograd/__init__.py ADDED
@@ -0,0 +1,7 @@
+from .grad_context import GLOBAL_GRAD_CONTEXT, GlobalGradContext, GradDirection
+
+__all__ = [
+    "GLOBAL_GRAD_CONTEXT",
+    "GlobalGradContext",
+    "GradDirection"
+]
d9d/core/autograd/grad_context.py ADDED
@@ -0,0 +1,85 @@
+from contextlib import contextmanager
+from enum import StrEnum
+
+
+class GradDirection(StrEnum):
+    """
+    Enum representing the specific gradient edges to compute.
+
+    This is used to manually control gradient flow in custom autograd functions
+    during split backward passes.
+
+    Attributes:
+        inputs: Mark gradient edge as pointing to the module's inputs (activations).
+        weight: Mark gradient edge as pointing to the module's parameters (weights).
+    """
+
+    inputs = "inputs"
+    weight = "weights"
+
+
+class GlobalGradContext:
+    """
+    Global state manager for controlling gradient computation in custom autograd functions.
+
+    This context addresses a limitation in PyTorch where custom `torch.autograd.Function`
+    implementations set `ctx.needs_input_grad` to True for all edges requiring grad,
+    even during partial backward passes (e.g., `torch.autograd.backward(inputs=...)`).
+
+    For additional information on this limitation, please refer to a
+    [related issue](https://github.com/pytorch/pytorch/issues/174017).
+
+    This class allows:
+
+    1. For the training code - to explicitly signal which gradient edges (inputs vs weights)
+       should currently be computed, allowing custom ops to skip unnecessary computations.
+    2. For module code - to check whether it's required to compute a gradient edge.
+    """
+
+    def __init__(self):
+        """Constructs a GlobalGradContext object with all directions enabled by default."""
+
+        # both directions by default
+        self._enabled_directions: set[GradDirection] = {GradDirection.inputs, GradDirection.weight}
+
+    def check_direction(self, direction: GradDirection | None) -> bool:
+        """
+        Checks if the gradient calculation for the given direction is currently enabled.
+
+        Args:
+            direction: The direction to check (inputs or weights). If None,
+                returns True.
+
+        Returns:
+            True if the direction is enabled or None is passed, False otherwise.
+        """
+
+        if direction is None:
+            return True
+
+        return direction in self._enabled_directions
+
+    @contextmanager
+    def with_directions(self, *directions: GradDirection):
+        """
+        Context manager that sets the enabled gradient directions.
+
+        This overrides the current state for the duration of the context
+        and restores the previous state afterwards.
+
+        Args:
+            *directions: The gradient directions to enable.
+        """
+        prev_directions = self._enabled_directions
+        self._enabled_directions = set(directions)
+        yield
+        self._enabled_directions = prev_directions
+
+
+GLOBAL_GRAD_CONTEXT = GlobalGradContext()
+"""
+The singleton instance of GlobalGradContext.
+
+This should be used by custom autograd functions to check `GLOBAL_GRAD_CONTEXT.check_direction()`
+during their backward pass.
+"""
d9d/core/dist_context/__init__.py ADDED
@@ -0,0 +1,19 @@
+"""
+This package configures the distributed environment and device meshes.
+"""
+
+from .configured import DistributedContext
+from .device_mesh_domains import BATCH_DOMAIN, DENSE_DOMAIN, EXPERT_DOMAIN, FLAT_DOMAIN, REGULAR_DOMAIN
+from .log import build_dist_logger
+from .params import DeviceMeshParameters
+
+__all__ = [
+    "BATCH_DOMAIN",
+    "DENSE_DOMAIN",
+    "EXPERT_DOMAIN",
+    "FLAT_DOMAIN",
+    "REGULAR_DOMAIN",
+    "DeviceMeshParameters",
+    "DistributedContext",
+    "build_dist_logger"
+]
d9d/core/dist_context/configured.py ADDED
@@ -0,0 +1,215 @@
+import datetime
+import logging
+import os
+import socket
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+import torch
+from torch.distributed import DeviceMesh
+
+from .device_mesh_domains import ALL_DOMAIN_PROVIDERS, REGULAR_DOMAIN
+from .log import build_dist_logger
+
+if TYPE_CHECKING:
+    from .params import DeviceMeshParameters
+
+
+def _resolve_master_addr() -> str:
+    if "MASTER_ADDR" not in os.environ:
+        return "127.0.0.1"
+
+    master_addr = os.environ["MASTER_ADDR"]
+
+    try:
+        return socket.gethostbyname(master_addr)
+    except OSError:
+        return master_addr
+
+
+def _build_mesh_domains(params: "DeviceMeshParameters") -> dict[str, DeviceMesh]:
+    return {
+        provider.name: provider.build_mesh(params)
+        for provider in ALL_DOMAIN_PROVIDERS
+    }
+
+
+class DistributedContext:
+    """
+    Acts as the single source of truth for the distributed execution environment.
+
+    It acts as the central repository for the distributed configuration, managing the creation
+    and synchronization of PyTorch DeviceMeshes for different domains (Regular domain, Expert Parallel domain, ...).
+
+    All assertions regarding rank placement, group memberships, and parallel topology
+    must be derived from this context to ensure consistency.
+    """
+
+    def __init__(self, params: "DeviceMeshParameters", log_level: int):
+        self._params = params
+
+        if params.is_distributed:
+            meshes = _build_mesh_domains(params)
+            regular_mesh = meshes[REGULAR_DOMAIN]
+
+            self._meshes = meshes
+            self._num_nodes = regular_mesh.size() // torch.cuda.device_count()
+            self._logger = build_dist_logger(
+                f'pp:{regular_mesh.get_local_rank("pp")}-'
+                f'dpr:{regular_mesh.get_local_rank("dp_replicate")}-'
+                f'dps:{regular_mesh.get_local_rank("dp_shard")}-'
+                f'cps:{regular_mesh.get_local_rank("cp_shard")}-'
+                f'cpr:{regular_mesh.get_local_rank("cp_replicate")}-'
+                f'tp:{regular_mesh.get_local_rank("tp")}',
+                level=log_level
+            )
+        else:
+            self._meshes = {}
+            self._num_nodes = 1
+            self._logger = build_dist_logger("local", level=log_level)
+
+        self._local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        self._global_rank = int(os.environ.get("RANK", "0"))
+
+        self._node_rank = self._global_rank // torch.cuda.device_count()
+
+        self._master_addr = _resolve_master_addr()
+        self._current_device = torch.device("cuda")
+
+        torch.cuda.set_device(self._local_rank)
+
+    @property
+    def logger(self) -> logging.Logger:
+        """Returns the logger instance configured for distributed logging."""
+
+        return self._logger
+
+    def mesh_for(self, domain: str) -> DeviceMesh:
+        """
+        Returns the device mesh view associated with a specific logical domain.
+
+        Available Domains and Dimensions:
+        * `regular` (`REGULAR_DOMAIN`): The most granular mesh for fully decomposed parallelism.
+          Dimensions: ``('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp')``
+        * `expert` (`EXPERT_DOMAIN`): Mesh optimized for distributing MoE (Mixture of Experts) layers.
+          Dimensions: ``('pp', 'replicate', 'ep')``
+        * `dense` (`DENSE_DOMAIN`): Mesh optimized for distributing dense layers.
+          Dimensions: ``('pp', 'dp_replicate', 'dp_cp_shard', 'cp_replicate', 'tp')``
+        * `batch` (`BATCH_DOMAIN`): Mesh optimized for distributing input data.
+          Dimensions: ``('pp', 'dp', 'cp', 'tp')``
+        * `flat` (`FLAT_DOMAIN`): Mesh containing a single dimension with all the processes.
+          Dimensions: ``('world')``
+
+        Args:
+            domain: The name of the domain to retrieve.
+
+        Returns:
+            The PyTorch DeviceMesh configured for the requested domain.
+
+        Raises:
+            ValueError: If the specified domain does not exist.
+        """
+
+        if domain not in self._meshes:
+            raise ValueError(f"Domain {domain} does not exist")
+        return self._meshes[domain]
+
+    @property
+    def is_main_process(self) -> bool:
+        """Checks if the current process is the global rank 0."""
+
+        return self._global_rank == 0
+
+    @property
+    def is_local_main_process(self) -> bool:
+        """Checks if the current process is the rank 0 on the specific node."""
+
+        return self._local_rank == 0
+
+    def wait_world(self):
+        """Blocks process execution until all ranks reach this point."""
+
+        torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
+        torch.cuda.synchronize()
+
+    def set_timeout(self, timeout_seconds: float):
+        """
+        Updates the NCCL/process group timeout for all underlying meshes.
+
+        Args:
+            timeout_seconds: New timeout duration in seconds.
+        """
+
+        self.logger.info(f"Setting global timeout to {timeout_seconds} seconds")
+        self.wait_world()
+
+        groups: list[torch.distributed.ProcessGroup | None] = [None]
+        for mesh in self._meshes.values():
+            for dim in range(mesh.ndim):
+                groups.append(mesh.get_group(dim))
+
+        for group in groups:
+            torch.distributed.distributed_c10d._set_pg_timeout(datetime.timedelta(seconds=timeout_seconds), group)  # noqa: SLF001
+
+    @contextmanager
+    def local_main_process_first(self):
+        """
+        Context manager that executes the block on the local main process first.
+
+        Other local ranks wait at the entrance. The local main process waits at the
+        exit to synchronize before continuing.
+        """
+        if not self.is_local_main_process:
+            self.wait_world()
+
+        yield
+
+        if self.is_local_main_process:
+            self.wait_world()
+
+    @contextmanager
+    def main_process_first(self):
+        """
+        Context manager that executes the block on the global main process first.
+
+        All other ranks wait at the entrance. The global main process waits at the
+        exit to synchronize before continuing.
+        """
+
+        if not self.is_main_process:
+            self.wait_world()
+
+        yield
+
+        if self.is_main_process:
+            self.wait_world()
+
+    @property
+    def current_device(self) -> torch.device:
+        """Returns the CUDA device associated with this rank."""
+
+        return self._current_device
+
+    @property
+    def mesh_params(self) -> "DeviceMeshParameters":
+        """Returns the parameters used to initialize this context."""
+
+        return self._params
+
+    @property
+    def master_addr(self) -> str:
+        """Returns the IP address or domain name of the master node."""
+
+        return self._master_addr
+
+    @property
+    def node_rank(self) -> int:
+        """Returns the index of the node this process is running on."""
+
+        return self._node_rank
+
+    @property
+    def num_nodes(self) -> int:
+        """Returns the total number of nodes in the cluster."""
+
+        return self._num_nodes
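As an orientation aid (not code shipped in the wheel), the sketch below shows how a DistributedContext built from DeviceMeshParameters (defined in params.py further down) might be used. It assumes a CUDA cluster launched via torchrun and placeholder parallelism degrees.

from d9d.core.dist_context import REGULAR_DOMAIN, DeviceMeshParameters

# Placeholder degrees; build() expects a properly launched distributed job.
ctx = DeviceMeshParameters(data_parallel_shard=4, tensor_parallel=2).build()

with ctx.main_process_first():
    # e.g. download or preprocess a dataset once before the other ranks read it
    pass

regular_mesh = ctx.mesh_for(REGULAR_DOMAIN)
ctx.logger.info(f"tp local rank: {regular_mesh.get_local_rank('tp')}")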
d9d/core/dist_context/device_mesh_domains.py ADDED
@@ -0,0 +1,185 @@
+import abc
+from typing import TYPE_CHECKING
+
+from torch.distributed import DeviceMesh, init_device_mesh
+
+if TYPE_CHECKING:
+    from .params import DeviceMeshParameters
+
+
+class DeviceMeshDomain(abc.ABC):
+    """
+    Abstract base class for a Device Mesh provider.
+
+    A Domain defines a specific strategy for organizing available GPUs into a
+    multidimensional grid (Mesh) to support specific parallelism techniques.
+    """
+
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        """Returns the unique identifier for this mesh domain."""
+
+        ...
+
+    @abc.abstractmethod
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        """
+        Constructs the device mesh configuration.
+
+        Args:
+            params: Global configuration parameters for the distributed environment.
+
+        Returns:
+            The initialized PyTorch DeviceMesh for this specific domain.
+        """
+
+        ...
+
+
+REGULAR_DOMAIN = "regular"
+
+
+class RegularDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return "regular"
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate,
+                params.data_parallel_shard,
+                params.context_parallel_shard,
+                params.context_parallel_replicate,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp_replicate",
+                "dp_shard",
+                "cp_shard",
+                "cp_replicate",
+                "tp"
+            )
+        )
+
+
+EXPERT_DOMAIN = "expert"
+
+
+class ExpertDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return EXPERT_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        replicate_degree = (
+            params.data_parallel_replicate *
+            params.context_parallel_replicate *
+            params.data_parallel_shard *
+            params.context_parallel_shard
+        )
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                replicate_degree // params.expert_parallel,
+                params.expert_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "ep_replicate",
+                "ep_shard"
+            )
+        )
+
+
+DENSE_DOMAIN = "dense"
+
+
+class DenseDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return DENSE_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate,
+                params.data_parallel_shard * params.context_parallel_shard,
+                params.context_parallel_replicate,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp_replicate",
+                "dp_cp_shard",
+                "cp_replicate",
+                "tp"
+            )
+        )
+
+
+BATCH_DOMAIN = "batch"
+
+
+class BatchDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return BATCH_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate * params.data_parallel_shard,
+                params.context_parallel_replicate * params.context_parallel_shard,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp",
+                "cp",
+                "tp"
+            )
+        )
+
+
+FLAT_DOMAIN = "flat"
+
+
+class FlatDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return FLAT_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        mesh_shape = (
+            params.pipeline_parallel *
+            params.data_parallel_replicate *
+            params.data_parallel_shard *
+            params.context_parallel_replicate *
+            params.context_parallel_shard *
+            params.tensor_parallel
+        )
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                mesh_shape,
+            ),
+            mesh_dim_names=(
+                "world",
+            )
+        )
+
+
+ALL_DOMAIN_PROVIDERS: list[DeviceMeshDomain] = [
+    RegularDomain(), DenseDomain(), ExpertDomain(), BatchDomain(),
+    FlatDomain()
+]
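To make the factorizations above concrete, here is a small worked example (illustrative arithmetic only, not part of the package): with pipeline_parallel=2, data_parallel_shard=4, context_parallel_shard=2, expert_parallel=4 and all remaining degrees at 1, the five domains reshape the same 16 ranks as follows.

import math

regular = (2, 1, 4, 2, 1, 1)            # ('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp')
dense = (2, 1, 4 * 2, 1, 1)             # dp_shard and cp_shard fold into 'dp_cp_shard'
expert = (2, (1 * 1 * 4 * 2) // 4, 4)   # replicate degree 8 split into ('ep_replicate', 'ep_shard')
batch = (2, 1 * 4, 1 * 2, 1)            # ('pp', 'dp', 'cp', 'tp')
flat = (16,)                            # single 'world' dimension

# Every domain covers the full 16-rank world, just grouped along different axes.
assert all(math.prod(shape) == 16 for shape in (regular, dense, expert, batch, flat))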
d9d/core/dist_context/log.py ADDED
@@ -0,0 +1,30 @@
+import logging
+import sys
+
+
+def build_dist_logger(qualifier: str, level: int) -> logging.Logger:
+    """
+    Configures and returns a logger instance for d9d.
+
+    The logger is configured to write to stdout with a formatter that includes
+    the provided rank qualifier, allowing for easier debugging in distributed logs.
+
+    Args:
+        qualifier: A string identifying the current rank's position in the mesh.
+        level: Log level to set by default
+
+    Returns:
+        A configured logging.Logger instance.
+    """
+
+    dist_logger = logging.getLogger("d9d")
+    dist_logger.setLevel(level)
+    dist_logger.handlers.clear()
+    ch = logging.StreamHandler(sys.stdout)
+    ch.setLevel(level)
+    formatter = logging.Formatter(
+        f"[d9d] [{qualifier}] %(asctime)s - %(levelname)s - %(message)s"
+    )
+    ch.setFormatter(formatter)
+    dist_logger.addHandler(ch)
+    return dist_logger
d9d/core/dist_context/params.py ADDED
@@ -0,0 +1,113 @@
+import logging
+from typing import Self
+
+from pydantic import BaseModel, ConfigDict, model_validator
+
+from .configured import DistributedContext
+
+
+class DeviceMeshParameters(BaseModel):
+    """
+    Configuration parameters for initializing Distributed Device Meshes.
+
+    Attributes:
+        pipeline_parallel: Degree of pipeline parallelism (PP).
+        data_parallel_replicate: Degree of data parallel replication (DDP).
+        data_parallel_shard: Degree of data parallel sharding (FSDP).
+        context_parallel_replicate: Degree of context parallel (CP) replication.
+        context_parallel_shard: Degree of context parallel (FSCP) sharding.
+        tensor_parallel: Degree of tensor parallelism (TP).
+        expert_parallel: Degree of expert parallelism (EP/MoE).
+    """
+
+    model_config = ConfigDict(frozen=True)
+
+    pipeline_parallel: int = 1
+
+    data_parallel_replicate: int = 1
+    data_parallel_shard: int = 1
+
+    context_parallel_replicate: int = 1
+    context_parallel_shard: int = 1
+
+    tensor_parallel: int = 1
+
+    expert_parallel: int = 1
+
+    @property
+    def has_pipeline_parallel(self) -> bool:
+        """Checks if pipeline parallelism is enabled (degree > 1)."""
+
+        return self.pipeline_parallel > 1
+
+    @property
+    def has_data_parallel_replicate(self) -> bool:
+        """Checks if data parallel replication is enabled (degree > 1)."""
+
+        return self.data_parallel_replicate > 1
+
+    @property
+    def has_data_parallel_shard(self) -> bool:
+        """Checks if data parallel sharding is enabled (degree > 1)."""
+
+        return self.data_parallel_shard > 1
+
+    @property
+    def has_context_parallel_replicate(self) -> bool:
+        return self.context_parallel_replicate > 1
+
+    @property
+    def has_context_parallel_shard(self) -> bool:
+        return self.context_parallel_shard > 1
+
+    @property
+    def has_tensor_parallel(self) -> bool:
+        return self.tensor_parallel > 1
+
+    @property
+    def has_expert_parallel(self) -> bool:
+        """Checks if expert parallelism is enabled (degree > 1)."""
+        return self.expert_parallel > 1
+
+    @property
+    def is_distributed(self) -> bool:
+        """Checks if any form of parallelism is enabled."""
+
+        return (
+            self.has_pipeline_parallel or
+            self.has_data_parallel_replicate or
+            self.has_data_parallel_shard or
+            self.has_context_parallel_shard or
+            self.has_context_parallel_replicate or
+            self.has_expert_parallel or
+            self.has_tensor_parallel
+        )
+
+    @model_validator(mode="after")
+    def _check_ep_divisibility(self) -> Self:
+        """Validates that DP/CP/TP dimensions can support the requested EP/ETP degrees."""
+        dp_cp_tp_degree = (
+            self.data_parallel_shard *
+            self.data_parallel_replicate *
+            self.context_parallel_shard *
+            self.context_parallel_replicate *
+            self.tensor_parallel
+        )
+        ep_degree = self.expert_parallel
+
+        if dp_cp_tp_degree % ep_degree != 0:
+            raise ValueError(
+                f"Total data/context/tensor parallelism degree ({dp_cp_tp_degree}) must be divisible by "
+                f"total expert parallelism degree ({ep_degree})."
+            )
+        return self
+
+    def build(self, log_level: int = logging.INFO) -> "DistributedContext":
+        """
+        Initializes the DistributedContext using these parameters.
+
+        Returns:
+            A new DistributedContext instance containing the initialized device meshes.
+        """
+
+        return DistributedContext(self, log_level)
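A brief sketch of how these parameters behave (placeholder values, not taken from the package): the pydantic validator rejects configurations whose combined data/context/tensor degree is not divisible by the expert degree, while the convenience properties report which forms of parallelism are active.

from pydantic import ValidationError

from d9d.core.dist_context import DeviceMeshParameters

params = DeviceMeshParameters(data_parallel_shard=4, context_parallel_shard=2, expert_parallel=4)
assert params.is_distributed            # at least one degree is > 1
assert not params.has_tensor_parallel   # tensor_parallel stays at its default of 1

try:
    DeviceMeshParameters(data_parallel_shard=4, expert_parallel=3)
except ValidationError:
    pass  # 4 (dp*cp*tp degree) is not divisible by 3 (expert degree)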
d9d/core/dist_ops/__init__.py ADDED
@@ -0,0 +1,16 @@
+"""
+This module provides high-level wrappers around `torch.distributed` collective operations.
+"""
+
+
+from .object import all_gather_object, gather_object
+from .tensor import all_gather, all_gather_variadic_shape, gather, gather_variadic_shape
+
+__all__ = [
+    "all_gather",
+    "all_gather_object",
+    "all_gather_variadic_shape",
+    "gather",
+    "gather_object",
+    "gather_variadic_shape"
+]