d9d-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d-0.1.0/PKG-INFO +90 -0
  2. d9d-0.1.0/README.md +49 -0
  3. d9d-0.1.0/d9d/__init__.py +0 -0
  4. d9d-0.1.0/d9d/core/__init__.py +0 -0
  5. d9d-0.1.0/d9d/core/autograd/__init__.py +7 -0
  6. d9d-0.1.0/d9d/core/autograd/grad_context.py +85 -0
  7. d9d-0.1.0/d9d/core/dist_context/__init__.py +19 -0
  8. d9d-0.1.0/d9d/core/dist_context/configured.py +215 -0
  9. d9d-0.1.0/d9d/core/dist_context/device_mesh_domains.py +185 -0
  10. d9d-0.1.0/d9d/core/dist_context/log.py +30 -0
  11. d9d-0.1.0/d9d/core/dist_context/params.py +113 -0
  12. d9d-0.1.0/d9d/core/dist_ops/__init__.py +16 -0
  13. d9d-0.1.0/d9d/core/dist_ops/object.py +68 -0
  14. d9d-0.1.0/d9d/core/dist_ops/tensor.py +192 -0
  15. d9d-0.1.0/d9d/core/protocol/__init__.py +8 -0
  16. d9d-0.1.0/d9d/core/protocol/training.py +38 -0
  17. d9d-0.1.0/d9d/core/sharding/__init__.py +15 -0
  18. d9d-0.1.0/d9d/core/sharding/auto_spec.py +66 -0
  19. d9d-0.1.0/d9d/core/sharding/shard.py +154 -0
  20. d9d-0.1.0/d9d/core/sharding/spec.py +28 -0
  21. d9d-0.1.0/d9d/core/sharding/unshard.py +117 -0
  22. d9d-0.1.0/d9d/core/types/__init__.py +12 -0
  23. d9d-0.1.0/d9d/core/types/data.py +14 -0
  24. d9d-0.1.0/d9d/core/types/pytree.py +26 -0
  25. d9d-0.1.0/d9d/dataset/__init__.py +17 -0
  26. d9d-0.1.0/d9d/dataset/buffer_sorted.py +143 -0
  27. d9d-0.1.0/d9d/dataset/padding.py +79 -0
  28. d9d-0.1.0/d9d/dataset/sharded.py +195 -0
  29. d9d-0.1.0/d9d/internals/__init__.py +0 -0
  30. d9d-0.1.0/d9d/internals/determinism/__init__.py +10 -0
  31. d9d-0.1.0/d9d/internals/determinism/seed.py +63 -0
  32. d9d-0.1.0/d9d/internals/grad_norm/__init__.py +8 -0
  33. d9d-0.1.0/d9d/internals/grad_norm/group.py +87 -0
  34. d9d-0.1.0/d9d/internals/grad_norm/norm.py +169 -0
  35. d9d-0.1.0/d9d/internals/grad_sync/__init__.py +14 -0
  36. d9d-0.1.0/d9d/internals/grad_sync/bucket.py +317 -0
  37. d9d-0.1.0/d9d/internals/grad_sync/placement_helper.py +23 -0
  38. d9d-0.1.0/d9d/internals/grad_sync/synchronizer.py +257 -0
  39. d9d-0.1.0/d9d/internals/pipeline_state/__init__.py +14 -0
  40. d9d-0.1.0/d9d/internals/pipeline_state/api.py +45 -0
  41. d9d-0.1.0/d9d/internals/pipeline_state/handler.py +111 -0
  42. d9d-0.1.0/d9d/internals/pipeline_state/storage.py +236 -0
  43. d9d-0.1.0/d9d/internals/profiling/__init__.py +7 -0
  44. d9d-0.1.0/d9d/internals/profiling/profile.py +112 -0
  45. d9d-0.1.0/d9d/internals/state/__init__.py +6 -0
  46. d9d-0.1.0/d9d/internals/state/main_process.py +44 -0
  47. d9d-0.1.0/d9d/kernel/__init__.py +0 -0
  48. d9d-0.1.0/d9d/kernel/cce/__init__.py +5 -0
  49. d9d-0.1.0/d9d/kernel/cce/cce.py +298 -0
  50. d9d-0.1.0/d9d/kernel/cce/main.py +282 -0
  51. d9d-0.1.0/d9d/kernel/general/__init__.py +5 -0
  52. d9d-0.1.0/d9d/kernel/general/get_int_dtype.py +7 -0
  53. d9d-0.1.0/d9d/kernel/gmm/__init__.py +5 -0
  54. d9d-0.1.0/d9d/kernel/gmm/function.py +78 -0
  55. d9d-0.1.0/d9d/kernel/moe/__init__.py +8 -0
  56. d9d-0.1.0/d9d/kernel/moe/indices_to_multihot.py +268 -0
  57. d9d-0.1.0/d9d/kernel/moe/permute_with_probs.py +1035 -0
  58. d9d-0.1.0/d9d/kernel/stochastic/__init__.py +11 -0
  59. d9d-0.1.0/d9d/kernel/stochastic/adamw_step.py +204 -0
  60. d9d-0.1.0/d9d/kernel/stochastic/copy.py +104 -0
  61. d9d-0.1.0/d9d/kernel/stochastic/ops/__init__.py +5 -0
  62. d9d-0.1.0/d9d/kernel/stochastic/ops/round.py +22 -0
  63. d9d-0.1.0/d9d/kernel/swiglu/__init__.py +5 -0
  64. d9d-0.1.0/d9d/kernel/swiglu/function.py +36 -0
  65. d9d-0.1.0/d9d/kernel/swiglu/op.py +167 -0
  66. d9d-0.1.0/d9d/loop/__init__.py +0 -0
  67. d9d-0.1.0/d9d/loop/auto/__init__.py +9 -0
  68. d9d-0.1.0/d9d/loop/auto/auto_lr_scheduler.py +46 -0
  69. d9d-0.1.0/d9d/loop/auto/auto_optimizer.py +196 -0
  70. d9d-0.1.0/d9d/loop/component/__init__.py +35 -0
  71. d9d-0.1.0/d9d/loop/component/batch_maths.py +106 -0
  72. d9d-0.1.0/d9d/loop/component/checkpointer.py +172 -0
  73. d9d-0.1.0/d9d/loop/component/data_loader_factory.py +258 -0
  74. d9d-0.1.0/d9d/loop/component/garbage_collector.py +94 -0
  75. d9d-0.1.0/d9d/loop/component/gradient_clipper.py +89 -0
  76. d9d-0.1.0/d9d/loop/component/gradient_manager.py +149 -0
  77. d9d-0.1.0/d9d/loop/component/job_logger.py +146 -0
  78. d9d-0.1.0/d9d/loop/component/job_profiler.py +62 -0
  79. d9d-0.1.0/d9d/loop/component/loss_computer.py +86 -0
  80. d9d-0.1.0/d9d/loop/component/model_stage_exporter.py +37 -0
  81. d9d-0.1.0/d9d/loop/component/model_stage_factory.py +261 -0
  82. d9d-0.1.0/d9d/loop/component/optimizer_factory.py +88 -0
  83. d9d-0.1.0/d9d/loop/component/stepper.py +52 -0
  84. d9d-0.1.0/d9d/loop/component/timeout_manager.py +54 -0
  85. d9d-0.1.0/d9d/loop/component/train_task_operator.py +152 -0
  86. d9d-0.1.0/d9d/loop/config/__init__.py +36 -0
  87. d9d-0.1.0/d9d/loop/config/config.py +225 -0
  88. d9d-0.1.0/d9d/loop/config/types.py +24 -0
  89. d9d-0.1.0/d9d/loop/control/__init__.py +61 -0
  90. d9d-0.1.0/d9d/loop/control/dataset_provider.py +58 -0
  91. d9d-0.1.0/d9d/loop/control/lr_scheduler_provider.py +47 -0
  92. d9d-0.1.0/d9d/loop/control/model_provider.py +162 -0
  93. d9d-0.1.0/d9d/loop/control/optimizer_provider.py +45 -0
  94. d9d-0.1.0/d9d/loop/control/task.py +304 -0
  95. d9d-0.1.0/d9d/loop/run/__init__.py +6 -0
  96. d9d-0.1.0/d9d/loop/run/train.py +355 -0
  97. d9d-0.1.0/d9d/loop/state.py +143 -0
  98. d9d-0.1.0/d9d/lr_scheduler/__init__.py +9 -0
  99. d9d-0.1.0/d9d/lr_scheduler/piecewise/__init__.py +18 -0
  100. d9d-0.1.0/d9d/lr_scheduler/piecewise/builder.py +152 -0
  101. d9d-0.1.0/d9d/lr_scheduler/piecewise/config.py +176 -0
  102. d9d-0.1.0/d9d/lr_scheduler/piecewise/curves.py +75 -0
  103. d9d-0.1.0/d9d/lr_scheduler/piecewise/engine.py +76 -0
  104. d9d-0.1.0/d9d/lr_scheduler/visualizer.py +74 -0
  105. d9d-0.1.0/d9d/metric/__init__.py +10 -0
  106. d9d-0.1.0/d9d/metric/abc.py +79 -0
  107. d9d-0.1.0/d9d/metric/impl/__init__.py +7 -0
  108. d9d-0.1.0/d9d/metric/impl/compose.py +54 -0
  109. d9d-0.1.0/d9d/metric/impl/mean.py +94 -0
  110. d9d-0.1.0/d9d/model_state/__init__.py +0 -0
  111. d9d-0.1.0/d9d/model_state/io/__init__.py +21 -0
  112. d9d-0.1.0/d9d/model_state/io/dto.py +30 -0
  113. d9d-0.1.0/d9d/model_state/io/module_reader.py +75 -0
  114. d9d-0.1.0/d9d/model_state/io/module_writer.py +123 -0
  115. d9d-0.1.0/d9d/model_state/io/reader.py +125 -0
  116. d9d-0.1.0/d9d/model_state/io/writer.py +309 -0
  117. d9d-0.1.0/d9d/model_state/mapper/__init__.py +10 -0
  118. d9d-0.1.0/d9d/model_state/mapper/abc.py +70 -0
  119. d9d-0.1.0/d9d/model_state/mapper/adapters/__init__.py +12 -0
  120. d9d-0.1.0/d9d/model_state/mapper/adapters/mapper.py +27 -0
  121. d9d-0.1.0/d9d/model_state/mapper/adapters/module.py +22 -0
  122. d9d-0.1.0/d9d/model_state/mapper/compose/__init__.py +17 -0
  123. d9d-0.1.0/d9d/model_state/mapper/compose/helper.py +22 -0
  124. d9d-0.1.0/d9d/model_state/mapper/compose/parallel.py +58 -0
  125. d9d-0.1.0/d9d/model_state/mapper/compose/sequential.py +131 -0
  126. d9d-0.1.0/d9d/model_state/mapper/compose/shard.py +36 -0
  127. d9d-0.1.0/d9d/model_state/mapper/leaf/__init__.py +18 -0
  128. d9d-0.1.0/d9d/model_state/mapper/leaf/dtensor.py +56 -0
  129. d9d-0.1.0/d9d/model_state/mapper/leaf/identity.py +23 -0
  130. d9d-0.1.0/d9d/model_state/mapper/leaf/rename.py +26 -0
  131. d9d-0.1.0/d9d/model_state/mapper/leaf/select_child.py +37 -0
  132. d9d-0.1.0/d9d/model_state/mapper/leaf/stack.py +29 -0
  133. d9d-0.1.0/d9d/module/__init__.py +0 -0
  134. d9d-0.1.0/d9d/module/base/__init__.py +7 -0
  135. d9d-0.1.0/d9d/module/base/late_init.py +10 -0
  136. d9d-0.1.0/d9d/module/block/__init__.py +0 -0
  137. d9d-0.1.0/d9d/module/block/attention/__init__.py +7 -0
  138. d9d-0.1.0/d9d/module/block/attention/grouped_query.py +139 -0
  139. d9d-0.1.0/d9d/module/block/attention/sdpa/__init__.py +5 -0
  140. d9d-0.1.0/d9d/module/block/attention/sdpa/flash.py +52 -0
  141. d9d-0.1.0/d9d/module/block/embedding/__init__.py +7 -0
  142. d9d-0.1.0/d9d/module/block/embedding/shard_token_embedding.py +103 -0
  143. d9d-0.1.0/d9d/module/block/ffn/__init__.py +5 -0
  144. d9d-0.1.0/d9d/module/block/ffn/swiglu.py +60 -0
  145. d9d-0.1.0/d9d/module/block/head/__init__.py +6 -0
  146. d9d-0.1.0/d9d/module/block/head/language_modelling.py +87 -0
  147. d9d-0.1.0/d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  148. d9d-0.1.0/d9d/module/block/hidden_states_aggregator/base.py +35 -0
  149. d9d-0.1.0/d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  150. d9d-0.1.0/d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  151. d9d-0.1.0/d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  152. d9d-0.1.0/d9d/module/block/moe/__init__.py +13 -0
  153. d9d-0.1.0/d9d/module/block/moe/communications/__init__.py +11 -0
  154. d9d-0.1.0/d9d/module/block/moe/communications/base.py +58 -0
  155. d9d-0.1.0/d9d/module/block/moe/communications/deepep.py +300 -0
  156. d9d-0.1.0/d9d/module/block/moe/communications/naive.py +68 -0
  157. d9d-0.1.0/d9d/module/block/moe/grouped_experts.py +81 -0
  158. d9d-0.1.0/d9d/module/block/moe/grouped_linear.py +78 -0
  159. d9d-0.1.0/d9d/module/block/moe/layer.py +122 -0
  160. d9d-0.1.0/d9d/module/block/moe/router.py +103 -0
  161. d9d-0.1.0/d9d/module/block/positional/__init__.py +8 -0
  162. d9d-0.1.0/d9d/module/block/positional/rope.py +150 -0
  163. d9d-0.1.0/d9d/module/model/__init__.py +0 -0
  164. d9d-0.1.0/d9d/module/model/qwen3_moe/__init__.py +16 -0
  165. d9d-0.1.0/d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  166. d9d-0.1.0/d9d/module/model/qwen3_moe/model.py +373 -0
  167. d9d-0.1.0/d9d/module/model/qwen3_moe/params.py +69 -0
  168. d9d-0.1.0/d9d/module/parallelism/__init__.py +0 -0
  169. d9d-0.1.0/d9d/module/parallelism/api/__init__.py +18 -0
  170. d9d-0.1.0/d9d/module/parallelism/api/expert_parallel.py +36 -0
  171. d9d-0.1.0/d9d/module/parallelism/api/fully_sharded.py +43 -0
  172. d9d-0.1.0/d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  173. d9d-0.1.0/d9d/module/parallelism/api/replicate_parallel.py +33 -0
  174. d9d-0.1.0/d9d/module/parallelism/model/__init__.py +0 -0
  175. d9d-0.1.0/d9d/module/parallelism/model/qwen3_moe.py +99 -0
  176. d9d-0.1.0/d9d/module/parallelism/style/__init__.py +7 -0
  177. d9d-0.1.0/d9d/module/parallelism/style/shard_experts.py +60 -0
  178. d9d-0.1.0/d9d/module/parallelism/style/to_local.py +86 -0
  179. d9d-0.1.0/d9d/optim/__init__.py +0 -0
  180. d9d-0.1.0/d9d/optim/stochastic/__init__.py +5 -0
  181. d9d-0.1.0/d9d/optim/stochastic/adamw.py +158 -0
  182. d9d-0.1.0/d9d/peft/__init__.py +13 -0
  183. d9d-0.1.0/d9d/peft/all/__init__.py +12 -0
  184. d9d-0.1.0/d9d/peft/all/config.py +31 -0
  185. d9d-0.1.0/d9d/peft/all/method.py +76 -0
  186. d9d-0.1.0/d9d/peft/applicator.py +47 -0
  187. d9d-0.1.0/d9d/peft/base.py +70 -0
  188. d9d-0.1.0/d9d/peft/full_tune/__init__.py +11 -0
  189. d9d-0.1.0/d9d/peft/full_tune/config.py +20 -0
  190. d9d-0.1.0/d9d/peft/full_tune/method.py +46 -0
  191. d9d-0.1.0/d9d/peft/lora/__init__.py +15 -0
  192. d9d-0.1.0/d9d/peft/lora/config.py +35 -0
  193. d9d-0.1.0/d9d/peft/lora/layer.py +177 -0
  194. d9d-0.1.0/d9d/peft/lora/method.py +132 -0
  195. d9d-0.1.0/d9d/pipelining/__init__.py +0 -0
  196. d9d-0.1.0/d9d/pipelining/api/__init__.py +19 -0
  197. d9d-0.1.0/d9d/pipelining/api/module.py +149 -0
  198. d9d-0.1.0/d9d/pipelining/api/schedule.py +50 -0
  199. d9d-0.1.0/d9d/pipelining/api/sharding.py +9 -0
  200. d9d-0.1.0/d9d/pipelining/factory/__init__.py +21 -0
  201. d9d-0.1.0/d9d/pipelining/factory/config.py +89 -0
  202. d9d-0.1.0/d9d/pipelining/factory/factory.py +114 -0
  203. d9d-0.1.0/d9d/pipelining/factory/registry.py +82 -0
  204. d9d-0.1.0/d9d/pipelining/infra/__init__.py +0 -0
  205. d9d-0.1.0/d9d/pipelining/infra/schedule/__init__.py +0 -0
  206. d9d-0.1.0/d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  207. d9d-0.1.0/d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  208. d9d-0.1.0/d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  209. d9d-0.1.0/d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  210. d9d-0.1.0/d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  211. d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  212. d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  213. d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  214. d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  215. d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  216. d9d-0.1.0/d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  217. d9d-0.1.0/d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  218. d9d-0.1.0/d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  219. d9d-0.1.0/d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  220. d9d-0.1.0/d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  221. d9d-0.1.0/d9d/pipelining/infra/stage/__init__.py +5 -0
  222. d9d-0.1.0/d9d/pipelining/infra/stage/communications.py +274 -0
  223. d9d-0.1.0/d9d/pipelining/infra/stage/computations.py +317 -0
  224. d9d-0.1.0/d9d/pipelining/infra/stage/splitgrad.py +377 -0
  225. d9d-0.1.0/d9d/pipelining/infra/stage/stage.py +321 -0
  226. d9d-0.1.0/d9d/pipelining/infra/stage/struct_helper.py +46 -0
  227. d9d-0.1.0/d9d/pipelining/training/__init__.py +7 -0
  228. d9d-0.1.0/d9d/pipelining/training/optimizer.py +41 -0
  229. d9d-0.1.0/d9d/pipelining/training/scheduler.py +34 -0
  230. d9d-0.1.0/d9d/tracker/__init__.py +14 -0
  231. d9d-0.1.0/d9d/tracker/base.py +124 -0
  232. d9d-0.1.0/d9d/tracker/factory.py +57 -0
  233. d9d-0.1.0/d9d/tracker/provider/__init__.py +0 -0
  234. d9d-0.1.0/d9d/tracker/provider/aim/__init__.py +0 -0
  235. d9d-0.1.0/d9d/tracker/provider/aim/config.py +23 -0
  236. d9d-0.1.0/d9d/tracker/provider/aim/tracker.py +114 -0
  237. d9d-0.1.0/d9d/tracker/provider/null.py +61 -0
  238. d9d-0.1.0/pyproject.toml +267 -0
d9d-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,90 @@
+ Metadata-Version: 2.4
+ Name: d9d
+ Version: 0.1.0
+ Summary: d9d - d[istribute]d - distributed training framework based on PyTorch that tries to be efficient yet hackable
+ License: Apache-2.0
+ Author: Maksim Afanasyev
+ Author-email: mr.applexz@gmail.com
+ Requires-Python: >=3.11,<3.15
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Topic :: Scientific/Engineering
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development
+ Classifier: Topic :: Software Development :: Libraries
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Provides-Extra: aim
+ Provides-Extra: cce
+ Provides-Extra: moe
+ Provides-Extra: visualization
+ Requires-Dist: aim (>=3.0.0,<4.0.0) ; extra == "aim"
+ Requires-Dist: cut-cross-entropy (>=25.9.3) ; extra == "cce"
+ Requires-Dist: deep-ep (>=1.2.1) ; extra == "moe"
+ Requires-Dist: nv-grouped-gemm (>=1.1.4) ; extra == "moe"
+ Requires-Dist: plotly (>=6.0.0) ; extra == "visualization"
+ Requires-Dist: pydantic (>=2.0.0)
+ Requires-Dist: safetensors (>=0.7.0)
+ Requires-Dist: setuptools (>=70.0.0) ; extra == "aim"
+ Requires-Dist: torch (>=2.10.0)
+ Requires-Dist: torchdata (>=0.11.0)
+ Requires-Dist: tqdm (>=4.0.0)
+ Requires-Dist: triton (>=3.6.0)
+ Project-URL: Documentation, https://d9d-project.github.io/d9d
+ Project-URL: Homepage, https://d9d-project.github.io/d9d
+ Project-URL: Issues, https://github.com/d9d-project/d9d/issues
+ Project-URL: Repository, https://github.com/d9d-project/d9d
+ Description-Content-Type: text/markdown
+
+ # The d9d Project
+
+ **d9d** is a distributed training framework built on top of PyTorch 2.0. It aims to be hackable, modular, and efficient, and is designed to scale from single-GPU debugging to massive clusters running 6D parallelism.
+
+ [LET'S START TRAINING 🚀](https://d9d-project.github.io/d9d/)
+
+ ## Why another framework?
+
+ Distributed training frameworks such as **Megatron-LM** are monolithic: you run a script from the command line to train one of a set of *predefined* models using *predefined* regimes. While powerful, these systems can be difficult to hack and integrate into novel research workflows. Their focus is often on providing a complete, end-to-end solution, which can limit flexibility for experimentally-driven research.
+
+ Conversely, creating your own distributed training solution from scratch is tricky. You have to implement many low-level components (like distributed checkpointing and synchronization) that are identical across setups, and manually tackle common performance bottlenecks.
+
+ **d9d** was designed to fill the gap between monolithic frameworks and homebrew setups, providing a modular yet effective solution for distributed training.
+
+ ## What d9d is and isn't
+
+ In terms of **core concept**:
+
+ * **IS** a pluggable framework for implementing distributed training regimes for your deep learning models.
+ * **IS** built on clear interfaces and building blocks that may be composed and implemented in your own way.
+ * **IS NOT** an all-in-one CLI platform for setting up pre-training and post-training like **torchtitan**, **Megatron-LM**, or **torchforge**.
+
+ In terms of **codebase & engineering**:
+
+ * **IS** built on a **strong engineering foundation**: We enforce strict type-checking and rigorous linting to catch errors before execution.
+ * **IS** reliable: The framework is backed by a suite of **over 450 tests**, covering unit logic, integration flows, and end-to-end distributed scenarios.
+ * **IS** eager to use performance hacks (like **DeepEP** or custom kernels) if they improve MFU, even if they aren't PyTorch-native.
+ * **IS NOT** for legacy setups: We do not maintain backward compatibility with older PyTorch versions or hardware. We prioritize simplicity and modern APIs (like `DTensor`).
+
+ ## Key Philosophies
+
+ To achieve the balance between hackability and performance, d9d adheres to specific design principles:
+
+ * **Composition over Monoliths**: We avoid "God Classes" like `DistributedDataParallel` or `ParallelDims` that assume ownership of the entire execution loop. Instead, we provide composable and extendable APIs: for instance, specific horizontal parallelism strategies for specific layers (`parallelize_replicate`, `parallelize_expert_parallel`, ...).
+ * **White-Box Modelling**: We encourage standard PyTorch code. Models are not wrapped in obscure metadata specifications; they are standard `nn.Module`s that implement lightweight protocols.
+ * **Pragmatic Efficiency**: While we prefer native PyTorch, we are eager to integrate non-native solutions if they improve MFU. For example, we implement MoE using **DeepEP** communications, reindexing kernels from **Megatron-LM**, and efficient grouped-GEMM implementations.
+ * **Graph-Based State Management**: Our IO system treats model checkpoints as directed acyclic graphs. This allows you to transform architectures (e.g., merging `q`, `k`, `v` into `qkv`) on the fly while streaming from disk, without massive memory overhead.
+ * **DTensors**: We mandate that distributed parameters be represented as `torch.distributed.tensor.DTensor`. This simplifies checkpointing by making them topology-aware automatically. We leverage modern PyTorch 2.0 APIs (`DeviceMesh`) as much as possible.
+
+ ---
+
+ ## Examples
+
+ ### Qwen3-MoE Pretraining
+ An example showing causal LM pretraining for the Qwen3-MoE model.
+
+ WIP: MoE load balancing is still under development.
+
+ [Link](https://github.com/d9d-project/d9d/blob/main/example/qwen3_moe/pretrain.py).
+
d9d-0.1.0/README.md ADDED
@@ -0,0 +1,49 @@
+ # The d9d Project
+
+ **d9d** is a distributed training framework built on top of PyTorch 2.0. It aims to be hackable, modular, and efficient, and is designed to scale from single-GPU debugging to massive clusters running 6D parallelism.
+
+ [LET'S START TRAINING 🚀](https://d9d-project.github.io/d9d/)
+
+ ## Why another framework?
+
+ Distributed training frameworks such as **Megatron-LM** are monolithic: you run a script from the command line to train one of a set of *predefined* models using *predefined* regimes. While powerful, these systems can be difficult to hack and integrate into novel research workflows. Their focus is often on providing a complete, end-to-end solution, which can limit flexibility for experimentally-driven research.
+
+ Conversely, creating your own distributed training solution from scratch is tricky. You have to implement many low-level components (like distributed checkpointing and synchronization) that are identical across setups, and manually tackle common performance bottlenecks.
+
+ **d9d** was designed to fill the gap between monolithic frameworks and homebrew setups, providing a modular yet effective solution for distributed training.
+
+ ## What d9d is and isn't
+
+ In terms of **core concept**:
+
+ * **IS** a pluggable framework for implementing distributed training regimes for your deep learning models.
+ * **IS** built on clear interfaces and building blocks that may be composed and implemented in your own way.
+ * **IS NOT** an all-in-one CLI platform for setting up pre-training and post-training like **torchtitan**, **Megatron-LM**, or **torchforge**.
+
+ In terms of **codebase & engineering**:
+
+ * **IS** built on a **strong engineering foundation**: We enforce strict type-checking and rigorous linting to catch errors before execution.
+ * **IS** reliable: The framework is backed by a suite of **over 450 tests**, covering unit logic, integration flows, and end-to-end distributed scenarios.
+ * **IS** eager to use performance hacks (like **DeepEP** or custom kernels) if they improve MFU, even if they aren't PyTorch-native.
+ * **IS NOT** for legacy setups: We do not maintain backward compatibility with older PyTorch versions or hardware. We prioritize simplicity and modern APIs (like `DTensor`).
+
+ ## Key Philosophies
+
+ To achieve the balance between hackability and performance, d9d adheres to specific design principles:
+
+ * **Composition over Monoliths**: We avoid "God Classes" like `DistributedDataParallel` or `ParallelDims` that assume ownership of the entire execution loop. Instead, we provide composable and extendable APIs: for instance, specific horizontal parallelism strategies for specific layers (`parallelize_replicate`, `parallelize_expert_parallel`, ...).
+ * **White-Box Modelling**: We encourage standard PyTorch code. Models are not wrapped in obscure metadata specifications; they are standard `nn.Module`s that implement lightweight protocols.
+ * **Pragmatic Efficiency**: While we prefer native PyTorch, we are eager to integrate non-native solutions if they improve MFU. For example, we implement MoE using **DeepEP** communications, reindexing kernels from **Megatron-LM**, and efficient grouped-GEMM implementations.
+ * **Graph-Based State Management**: Our IO system treats model checkpoints as directed acyclic graphs. This allows you to transform architectures (e.g., merging `q`, `k`, `v` into `qkv`) on the fly while streaming from disk, without massive memory overhead.
+ * **DTensors**: We mandate that distributed parameters be represented as `torch.distributed.tensor.DTensor`. This simplifies checkpointing by making them topology-aware automatically. We leverage modern PyTorch 2.0 APIs (`DeviceMesh`) as much as possible.
+
+ ---
+
+ ## Examples
+
+ ### Qwen3-MoE Pretraining
+ An example showing causal LM pretraining for the Qwen3-MoE model.
+
+ WIP: MoE load balancing is still under development.
+
+ [Link](https://github.com/d9d-project/d9d/blob/main/example/qwen3_moe/pretrain.py).
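To make the "DTensors" principle above concrete, here is a minimal sketch that uses only public PyTorch APIs (no d9d code). The mesh size, shapes, and 4-process `torchrun` launch are illustrative assumptions:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# A 2x2 mesh over 4 GPUs: one "dp" dimension and one "tp" dimension.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

weight = torch.randn(1024, 1024)
# Shard rows across "dp" and columns across "tp". The resulting DTensor carries
# its mesh and placements with it, which is what makes checkpoints topology-aware.
dweight = distribute_tensor(weight, mesh, placements=[Shard(0), Shard(1)])

print(dweight.placements, dweight.to_local().shape)  # each rank holds a (512, 512) local shard
```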
d9d-0.1.0/d9d/__init__.py: File without changes
d9d-0.1.0/d9d/core/__init__.py: File without changes
d9d-0.1.0/d9d/core/autograd/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .grad_context import GLOBAL_GRAD_CONTEXT, GlobalGradContext, GradDirection
+
+ __all__ = [
+     "GLOBAL_GRAD_CONTEXT",
+     "GlobalGradContext",
+     "GradDirection"
+ ]
d9d-0.1.0/d9d/core/autograd/grad_context.py ADDED
@@ -0,0 +1,85 @@
+ from contextlib import contextmanager
+ from enum import StrEnum
+
+
+ class GradDirection(StrEnum):
+     """
+     Enum representing the specific gradient edges to compute.
+
+     This is used to manually control gradient flow in custom autograd functions
+     during split backward passes.
+
+     Attributes:
+         inputs: Mark gradient edge as pointing to the module's inputs (activations).
+         weight: Mark gradient edge as pointing to the module's parameters (weights).
+     """
+
+     inputs = "inputs"
+     weight = "weights"
+
+
+ class GlobalGradContext:
+     """
+     Global state manager for controlling gradient computation in custom autograd functions.
+
+     This context addresses a limitation in PyTorch where custom `torch.autograd.Function`
+     implementations set `ctx.needs_input_grad` to True for all edges requiring grad,
+     even during partial backward passes (e.g., `torch.autograd.backward(inputs=...)`).
+
+     For additional information on this limitation, please refer to a
+     [related issue](https://github.com/pytorch/pytorch/issues/174017).
+
+     This class allows:
+
+     1. For the training code - to explicitly signal which gradient edges (inputs vs weights)
+        should currently be computed, allowing custom ops to skip unnecessary computations.
+     2. For module code - to check whether it's required to compute a gradient edge.
+     """
+
+     def __init__(self):
+         """Constructs a GlobalGradContext object with all directions enabled by default."""
+
+         # both directions by default
+         self._enabled_directions: set[GradDirection] = {GradDirection.inputs, GradDirection.weight}
+
+     def check_direction(self, direction: GradDirection | None) -> bool:
+         """
+         Checks if the gradient calculation for the given direction is currently enabled.
+
+         Args:
+             direction: The direction to check (inputs or weights). If None,
+                 returns True.
+
+         Returns:
+             True if the direction is enabled or None is passed, False otherwise.
+         """
+
+         if direction is None:
+             return True
+
+         return direction in self._enabled_directions
+
+     @contextmanager
+     def with_directions(self, *directions: GradDirection):
+         """
+         Context manager that sets the enabled gradient directions.
+
+         This overrides the current state for the duration of the context
+         and restores the previous state afterwards.
+
+         Args:
+             *directions: The gradient directions to enable.
+         """
+         prev_directions = self._enabled_directions
+         self._enabled_directions = set(directions)
+         yield
+         self._enabled_directions = prev_directions
+
+
+ GLOBAL_GRAD_CONTEXT = GlobalGradContext()
+ """
+ The singleton instance of GlobalGradContext.
+
+ This should be used by custom autograd functions to check `GLOBAL_GRAD_CONTEXT.check_direction()`
+ during their backward pass.
+ """
d9d-0.1.0/d9d/core/dist_context/__init__.py ADDED
@@ -0,0 +1,19 @@
+ """
+ This package configures the distributed environment and device meshes.
+ """
+
+ from .configured import DistributedContext
+ from .device_mesh_domains import BATCH_DOMAIN, DENSE_DOMAIN, EXPERT_DOMAIN, FLAT_DOMAIN, REGULAR_DOMAIN
+ from .log import build_dist_logger
+ from .params import DeviceMeshParameters
+
+ __all__ = [
+     "BATCH_DOMAIN",
+     "DENSE_DOMAIN",
+     "EXPERT_DOMAIN",
+     "FLAT_DOMAIN",
+     "REGULAR_DOMAIN",
+     "DeviceMeshParameters",
+     "DistributedContext",
+     "build_dist_logger"
+ ]
d9d-0.1.0/d9d/core/dist_context/configured.py ADDED
@@ -0,0 +1,215 @@
+ import datetime
+ import logging
+ import os
+ import socket
+ from contextlib import contextmanager
+ from typing import TYPE_CHECKING
+
+ import torch
+ from torch.distributed import DeviceMesh
+
+ from .device_mesh_domains import ALL_DOMAIN_PROVIDERS, REGULAR_DOMAIN
+ from .log import build_dist_logger
+
+ if TYPE_CHECKING:
+     from .params import DeviceMeshParameters
+
+
+ def _resolve_master_addr() -> str:
+     if "MASTER_ADDR" not in os.environ:
+         return "127.0.0.1"
+
+     master_addr = os.environ["MASTER_ADDR"]
+
+     try:
+         return socket.gethostbyname(master_addr)
+     except OSError:
+         return master_addr
+
+
+ def _build_mesh_domains(params: "DeviceMeshParameters") -> dict[str, DeviceMesh]:
+     return {
+         provider.name: provider.build_mesh(params)
+         for provider in ALL_DOMAIN_PROVIDERS
+     }
+
+
+ class DistributedContext:
+     """
+     Acts as the single source of truth for the distributed execution environment.
+
+     It acts as the central repository for the distributed configuration, managing the creation
+     and synchronization of PyTorch DeviceMeshes for different domains (Regular domain, Expert Parallel domain, ...).
+
+     All assertions regarding rank placement, group memberships, and parallel topology
+     must be derived from this context to ensure consistency.
+     """
+
+     def __init__(self, params: "DeviceMeshParameters", log_level: int):
+         self._params = params
+
+         if params.is_distributed:
+             meshes = _build_mesh_domains(params)
+             regular_mesh = meshes[REGULAR_DOMAIN]
+
+             self._meshes = meshes
+             self._num_nodes = regular_mesh.size() // torch.cuda.device_count()
+             self._logger = build_dist_logger(
+                 f'pp:{regular_mesh.get_local_rank("pp")}-'
+                 f'dpr:{regular_mesh.get_local_rank("dp_replicate")}-'
+                 f'dps:{regular_mesh.get_local_rank("dp_shard")}-'
+                 f'cps:{regular_mesh.get_local_rank("cp_shard")}-'
+                 f'cpr:{regular_mesh.get_local_rank("cp_replicate")}-'
+                 f'tp:{regular_mesh.get_local_rank("tp")}',
+                 level=log_level
+             )
+         else:
+             self._meshes = {}
+             self._num_nodes = 1
+             self._logger = build_dist_logger("local", level=log_level)
+
+         self._local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+         self._global_rank = int(os.environ.get("RANK", "0"))
+
+         self._node_rank = self._global_rank // torch.cuda.device_count()
+
+         self._master_addr = _resolve_master_addr()
+         self._current_device = torch.device("cuda")
+
+         torch.cuda.set_device(self._local_rank)
+
+     @property
+     def logger(self) -> logging.Logger:
+         """Returns the logger instance configured for distributed logging."""
+
+         return self._logger
+
+     def mesh_for(self, domain: str) -> DeviceMesh:
+         """
+         Returns the device mesh view associated with a specific logical domain.
+
+         Available Domains and Dimensions:
+         * `regular` (`REGULAR_DOMAIN`): The most granular mesh for fully decomposed parallelism.
+           Dimensions: ``('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp')``
+         * `expert` (`EXPERT_DOMAIN`): Mesh optimized for distributing MoE (Mixture of Experts) layers.
+           Dimensions: ``('pp', 'ep_replicate', 'ep_shard')``
+         * `dense` (`DENSE_DOMAIN`): Mesh optimized for distributing dense layers.
+           Dimensions: ``('pp', 'dp_replicate', 'dp_cp_shard', 'cp_replicate', 'tp')``
+         * `batch` (`BATCH_DOMAIN`): Mesh optimized for distributing input data.
+           Dimensions: ``('pp', 'dp', 'cp', 'tp')``
+         * `flat` (`FLAT_DOMAIN`): Mesh containing a single dimension with all the processes.
+           Dimensions: ``('world',)``
+
+         Args:
+             domain: The name of the domain to retrieve.
+
+         Returns:
+             The PyTorch DeviceMesh configured for the requested domain.
+
+         Raises:
+             ValueError: If the specified domain does not exist.
+         """
+
+         if domain not in self._meshes:
+             raise ValueError(f"Domain {domain} does not exist")
+         return self._meshes[domain]
+
+     @property
+     def is_main_process(self) -> bool:
+         """Checks if the current process is the global rank 0."""
+
+         return self._global_rank == 0
+
+     @property
+     def is_local_main_process(self) -> bool:
+         """Checks if the current process is the rank 0 on the specific node."""
+
+         return self._local_rank == 0
+
+     def wait_world(self):
+         """Blocks process execution until all ranks reach this point."""
+
+         torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
+         torch.cuda.synchronize()
+
+     def set_timeout(self, timeout_seconds: float):
+         """
+         Updates the NCCL/process group timeout for all underlying meshes.
+
+         Args:
+             timeout_seconds: New timeout duration in seconds.
+         """
+
+         self.logger.info(f"Setting global timeout to {timeout_seconds} seconds")
+         self.wait_world()
+
+         groups: list[torch.distributed.ProcessGroup | None] = [None]
+         for mesh in self._meshes.values():
+             for dim in range(mesh.ndim):
+                 groups.append(mesh.get_group(dim))
+
+         for group in groups:
+             torch.distributed.distributed_c10d._set_pg_timeout(datetime.timedelta(seconds=timeout_seconds), group)  # noqa: SLF001
+
+     @contextmanager
+     def local_main_process_first(self):
+         """
+         Context manager that executes the block on the local main process first.
+
+         Other local ranks wait at the entrance. The local main process waits at the
+         exit to synchronize before continuing.
+         """
+         if not self.is_local_main_process:
+             self.wait_world()
+
+         yield
+
+         if self.is_local_main_process:
+             self.wait_world()
+
+     @contextmanager
+     def main_process_first(self):
+         """
+         Context manager that executes the block on the global main process first.
+
+         All other ranks wait at the entrance. The global main process waits at the
+         exit to synchronize before continuing.
+         """
+
+         if not self.is_main_process:
+             self.wait_world()
+
+         yield
+
+         if self.is_main_process:
+             self.wait_world()
+
+     @property
+     def current_device(self) -> torch.device:
+         """Returns the CUDA device associated with this rank."""
+
+         return self._current_device
+
+     @property
+     def mesh_params(self) -> "DeviceMeshParameters":
+         """Returns the parameters used to initialize this context."""
+
+         return self._params
+
+     @property
+     def master_addr(self) -> str:
+         """Returns the IP address or domain name of the master node."""
+
+         return self._master_addr
+
+     @property
+     def node_rank(self) -> int:
+         """Returns the index of the node this process is running on."""
+
+         return self._node_rank
+
+     @property
+     def num_nodes(self) -> int:
+         """Returns the total number of nodes in the cluster."""
+
+         return self._num_nodes
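A rough sketch of how this context might be constructed and queried. The exact `DeviceMeshParameters` constructor is not shown in this diff, so the keyword fields below are inferred from how the domain providers read it, and the snippet assumes a `torchrun` launch has already set `RANK`, `LOCAL_RANK`, and `MASTER_ADDR`:

```python
import logging

from d9d.core.dist_context import REGULAR_DOMAIN, DeviceMeshParameters, DistributedContext

# Field names inferred from the domain providers; the real constructor may differ.
params = DeviceMeshParameters(
    pipeline_parallel=2,
    data_parallel_replicate=1,
    data_parallel_shard=2,
    context_parallel_shard=1,
    context_parallel_replicate=1,
    tensor_parallel=2,
    expert_parallel=2,
)

ctx = DistributedContext(params, log_level=logging.INFO)

regular = ctx.mesh_for(REGULAR_DOMAIN)  # dims: pp, dp_replicate, dp_shard, cp_shard, cp_replicate, tp
ctx.logger.info(f"node {ctx.node_rank + 1}/{ctx.num_nodes}, tp rank {regular.get_local_rank('tp')}")

with ctx.main_process_first():
    ...  # e.g. download or preprocess data on global rank 0 before the other ranks proceed
```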
d9d-0.1.0/d9d/core/dist_context/device_mesh_domains.py ADDED
@@ -0,0 +1,185 @@
+ import abc
+ from typing import TYPE_CHECKING
+
+ from torch.distributed import DeviceMesh, init_device_mesh
+
+ if TYPE_CHECKING:
+     from .params import DeviceMeshParameters
+
+
+ class DeviceMeshDomain(abc.ABC):
+     """
+     Abstract base class for a Device Mesh provider.
+
+     A Domain defines a specific strategy for organizing available GPUs into a
+     multidimensional grid (Mesh) to support specific parallelism techniques.
+     """
+
+     @property
+     @abc.abstractmethod
+     def name(self) -> str:
+         """Returns the unique identifier for this mesh domain."""
+
+         ...
+
+     @abc.abstractmethod
+     def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+         """
+         Constructs the device mesh configuration.
+
+         Args:
+             params: Global configuration parameters for the distributed environment.
+
+         Returns:
+             The initialized PyTorch DeviceMesh for this specific domain.
+         """
+
+         ...
+
+
+ REGULAR_DOMAIN = "regular"
+
+
+ class RegularDomain(DeviceMeshDomain):
+     @property
+     def name(self) -> str:
+         return "regular"
+
+     def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+         return init_device_mesh(
+             device_type="cuda",
+             mesh_shape=(
+                 params.pipeline_parallel,
+                 params.data_parallel_replicate,
+                 params.data_parallel_shard,
+                 params.context_parallel_shard,
+                 params.context_parallel_replicate,
+                 params.tensor_parallel
+             ),
+             mesh_dim_names=(
+                 "pp",
+                 "dp_replicate",
+                 "dp_shard",
+                 "cp_shard",
+                 "cp_replicate",
+                 "tp"
+             )
+         )
+
+
+ EXPERT_DOMAIN = "expert"
+
+
+ class ExpertDomain(DeviceMeshDomain):
+     @property
+     def name(self) -> str:
+         return EXPERT_DOMAIN
+
+     def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+         replicate_degree = (
+             params.data_parallel_replicate *
+             params.context_parallel_replicate *
+             params.data_parallel_shard *
+             params.context_parallel_shard
+         )
+         return init_device_mesh(
+             device_type="cuda",
+             mesh_shape=(
+                 params.pipeline_parallel,
+                 replicate_degree // params.expert_parallel,
+                 params.expert_parallel
+             ),
+             mesh_dim_names=(
+                 "pp",
+                 "ep_replicate",
+                 "ep_shard"
+             )
+         )
+
+
+ DENSE_DOMAIN = "dense"
+
+
+ class DenseDomain(DeviceMeshDomain):
+     @property
+     def name(self) -> str:
+         return DENSE_DOMAIN
+
+     def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+         return init_device_mesh(
+             device_type="cuda",
+             mesh_shape=(
+                 params.pipeline_parallel,
+                 params.data_parallel_replicate,
+                 params.data_parallel_shard * params.context_parallel_shard,
+                 params.context_parallel_replicate,
+                 params.tensor_parallel
+             ),
+             mesh_dim_names=(
+                 "pp",
+                 "dp_replicate",
+                 "dp_cp_shard",
+                 "cp_replicate",
+                 "tp"
+             )
+         )
+
+
+ BATCH_DOMAIN = "batch"
+
+
+ class BatchDomain(DeviceMeshDomain):
+     @property
+     def name(self) -> str:
+         return BATCH_DOMAIN
+
+     def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+         return init_device_mesh(
+             device_type="cuda",
+             mesh_shape=(
+                 params.pipeline_parallel,
+                 params.data_parallel_replicate * params.data_parallel_shard,
+                 params.context_parallel_replicate * params.context_parallel_shard,
+                 params.tensor_parallel
+             ),
+             mesh_dim_names=(
+                 "pp",
+                 "dp",
+                 "cp",
+                 "tp"
+             )
+         )
+
+
+ FLAT_DOMAIN = "flat"
+
+
+ class FlatDomain(DeviceMeshDomain):
+     @property
+     def name(self) -> str:
+         return FLAT_DOMAIN
+
+     def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+         mesh_shape = (
+             params.pipeline_parallel *
+             params.data_parallel_replicate *
+             params.data_parallel_shard *
+             params.context_parallel_replicate *
+             params.context_parallel_shard *
+             params.tensor_parallel
+         )
+         return init_device_mesh(
+             device_type="cuda",
+             mesh_shape=(
+                 mesh_shape,
+             ),
+             mesh_dim_names=(
+                 "world",
+             )
+         )
+
+
+ ALL_DOMAIN_PROVIDERS: list[DeviceMeshDomain] = [
+     RegularDomain(), DenseDomain(), ExpertDomain(), BatchDomain(),
+     FlatDomain()
+ ]
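A small worked example of the shapes these providers produce, using illustrative parameter values for an 8-GPU job; the arithmetic simply mirrors the `build_mesh` implementations above:

```python
import math

# Illustrative parameter values (not defaults from the package).
pp, dp_rep, dp_shard, cp_shard, cp_rep, tp, ep = 2, 2, 2, 1, 1, 1, 4

regular_shape = (pp, dp_rep, dp_shard, cp_shard, cp_rep, tp)      # (2, 2, 2, 1, 1, 1)
dense_shape = (pp, dp_rep, dp_shard * cp_shard, cp_rep, tp)       # (2, 2, 2, 1, 1)
batch_shape = (pp, dp_rep * dp_shard, cp_rep * cp_shard, tp)      # (2, 4, 1, 1)
ep_replicate = (dp_rep * cp_rep * dp_shard * cp_shard) // ep      # 4 // 4 = 1
expert_shape = (pp, ep_replicate, ep)                              # (2, 1, 4)
flat_shape = (pp * dp_rep * dp_shard * cp_rep * cp_shard * tp,)   # (8,)

# Every domain arranges the same 8 ranks, just grouped along different dimensions.
assert all(
    math.prod(shape) == 8
    for shape in (regular_shape, dense_shape, batch_shape, expert_shape, flat_shape)
)
```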