d9d-0.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d-0.1.0/PKG-INFO +90 -0
- d9d-0.1.0/README.md +49 -0
- d9d-0.1.0/d9d/__init__.py +0 -0
- d9d-0.1.0/d9d/core/__init__.py +0 -0
- d9d-0.1.0/d9d/core/autograd/__init__.py +7 -0
- d9d-0.1.0/d9d/core/autograd/grad_context.py +85 -0
- d9d-0.1.0/d9d/core/dist_context/__init__.py +19 -0
- d9d-0.1.0/d9d/core/dist_context/configured.py +215 -0
- d9d-0.1.0/d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d-0.1.0/d9d/core/dist_context/log.py +30 -0
- d9d-0.1.0/d9d/core/dist_context/params.py +113 -0
- d9d-0.1.0/d9d/core/dist_ops/__init__.py +16 -0
- d9d-0.1.0/d9d/core/dist_ops/object.py +68 -0
- d9d-0.1.0/d9d/core/dist_ops/tensor.py +192 -0
- d9d-0.1.0/d9d/core/protocol/__init__.py +8 -0
- d9d-0.1.0/d9d/core/protocol/training.py +38 -0
- d9d-0.1.0/d9d/core/sharding/__init__.py +15 -0
- d9d-0.1.0/d9d/core/sharding/auto_spec.py +66 -0
- d9d-0.1.0/d9d/core/sharding/shard.py +154 -0
- d9d-0.1.0/d9d/core/sharding/spec.py +28 -0
- d9d-0.1.0/d9d/core/sharding/unshard.py +117 -0
- d9d-0.1.0/d9d/core/types/__init__.py +12 -0
- d9d-0.1.0/d9d/core/types/data.py +14 -0
- d9d-0.1.0/d9d/core/types/pytree.py +26 -0
- d9d-0.1.0/d9d/dataset/__init__.py +17 -0
- d9d-0.1.0/d9d/dataset/buffer_sorted.py +143 -0
- d9d-0.1.0/d9d/dataset/padding.py +79 -0
- d9d-0.1.0/d9d/dataset/sharded.py +195 -0
- d9d-0.1.0/d9d/internals/__init__.py +0 -0
- d9d-0.1.0/d9d/internals/determinism/__init__.py +10 -0
- d9d-0.1.0/d9d/internals/determinism/seed.py +63 -0
- d9d-0.1.0/d9d/internals/grad_norm/__init__.py +8 -0
- d9d-0.1.0/d9d/internals/grad_norm/group.py +87 -0
- d9d-0.1.0/d9d/internals/grad_norm/norm.py +169 -0
- d9d-0.1.0/d9d/internals/grad_sync/__init__.py +14 -0
- d9d-0.1.0/d9d/internals/grad_sync/bucket.py +317 -0
- d9d-0.1.0/d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d-0.1.0/d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d-0.1.0/d9d/internals/pipeline_state/__init__.py +14 -0
- d9d-0.1.0/d9d/internals/pipeline_state/api.py +45 -0
- d9d-0.1.0/d9d/internals/pipeline_state/handler.py +111 -0
- d9d-0.1.0/d9d/internals/pipeline_state/storage.py +236 -0
- d9d-0.1.0/d9d/internals/profiling/__init__.py +7 -0
- d9d-0.1.0/d9d/internals/profiling/profile.py +112 -0
- d9d-0.1.0/d9d/internals/state/__init__.py +6 -0
- d9d-0.1.0/d9d/internals/state/main_process.py +44 -0
- d9d-0.1.0/d9d/kernel/__init__.py +0 -0
- d9d-0.1.0/d9d/kernel/cce/__init__.py +5 -0
- d9d-0.1.0/d9d/kernel/cce/cce.py +298 -0
- d9d-0.1.0/d9d/kernel/cce/main.py +282 -0
- d9d-0.1.0/d9d/kernel/general/__init__.py +5 -0
- d9d-0.1.0/d9d/kernel/general/get_int_dtype.py +7 -0
- d9d-0.1.0/d9d/kernel/gmm/__init__.py +5 -0
- d9d-0.1.0/d9d/kernel/gmm/function.py +78 -0
- d9d-0.1.0/d9d/kernel/moe/__init__.py +8 -0
- d9d-0.1.0/d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d-0.1.0/d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d-0.1.0/d9d/kernel/stochastic/__init__.py +11 -0
- d9d-0.1.0/d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d-0.1.0/d9d/kernel/stochastic/copy.py +104 -0
- d9d-0.1.0/d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d-0.1.0/d9d/kernel/stochastic/ops/round.py +22 -0
- d9d-0.1.0/d9d/kernel/swiglu/__init__.py +5 -0
- d9d-0.1.0/d9d/kernel/swiglu/function.py +36 -0
- d9d-0.1.0/d9d/kernel/swiglu/op.py +167 -0
- d9d-0.1.0/d9d/loop/__init__.py +0 -0
- d9d-0.1.0/d9d/loop/auto/__init__.py +9 -0
- d9d-0.1.0/d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d-0.1.0/d9d/loop/auto/auto_optimizer.py +196 -0
- d9d-0.1.0/d9d/loop/component/__init__.py +35 -0
- d9d-0.1.0/d9d/loop/component/batch_maths.py +106 -0
- d9d-0.1.0/d9d/loop/component/checkpointer.py +172 -0
- d9d-0.1.0/d9d/loop/component/data_loader_factory.py +258 -0
- d9d-0.1.0/d9d/loop/component/garbage_collector.py +94 -0
- d9d-0.1.0/d9d/loop/component/gradient_clipper.py +89 -0
- d9d-0.1.0/d9d/loop/component/gradient_manager.py +149 -0
- d9d-0.1.0/d9d/loop/component/job_logger.py +146 -0
- d9d-0.1.0/d9d/loop/component/job_profiler.py +62 -0
- d9d-0.1.0/d9d/loop/component/loss_computer.py +86 -0
- d9d-0.1.0/d9d/loop/component/model_stage_exporter.py +37 -0
- d9d-0.1.0/d9d/loop/component/model_stage_factory.py +261 -0
- d9d-0.1.0/d9d/loop/component/optimizer_factory.py +88 -0
- d9d-0.1.0/d9d/loop/component/stepper.py +52 -0
- d9d-0.1.0/d9d/loop/component/timeout_manager.py +54 -0
- d9d-0.1.0/d9d/loop/component/train_task_operator.py +152 -0
- d9d-0.1.0/d9d/loop/config/__init__.py +36 -0
- d9d-0.1.0/d9d/loop/config/config.py +225 -0
- d9d-0.1.0/d9d/loop/config/types.py +24 -0
- d9d-0.1.0/d9d/loop/control/__init__.py +61 -0
- d9d-0.1.0/d9d/loop/control/dataset_provider.py +58 -0
- d9d-0.1.0/d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d-0.1.0/d9d/loop/control/model_provider.py +162 -0
- d9d-0.1.0/d9d/loop/control/optimizer_provider.py +45 -0
- d9d-0.1.0/d9d/loop/control/task.py +304 -0
- d9d-0.1.0/d9d/loop/run/__init__.py +6 -0
- d9d-0.1.0/d9d/loop/run/train.py +355 -0
- d9d-0.1.0/d9d/loop/state.py +143 -0
- d9d-0.1.0/d9d/lr_scheduler/__init__.py +9 -0
- d9d-0.1.0/d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d-0.1.0/d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d-0.1.0/d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d-0.1.0/d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d-0.1.0/d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d-0.1.0/d9d/lr_scheduler/visualizer.py +74 -0
- d9d-0.1.0/d9d/metric/__init__.py +10 -0
- d9d-0.1.0/d9d/metric/abc.py +79 -0
- d9d-0.1.0/d9d/metric/impl/__init__.py +7 -0
- d9d-0.1.0/d9d/metric/impl/compose.py +54 -0
- d9d-0.1.0/d9d/metric/impl/mean.py +94 -0
- d9d-0.1.0/d9d/model_state/__init__.py +0 -0
- d9d-0.1.0/d9d/model_state/io/__init__.py +21 -0
- d9d-0.1.0/d9d/model_state/io/dto.py +30 -0
- d9d-0.1.0/d9d/model_state/io/module_reader.py +75 -0
- d9d-0.1.0/d9d/model_state/io/module_writer.py +123 -0
- d9d-0.1.0/d9d/model_state/io/reader.py +125 -0
- d9d-0.1.0/d9d/model_state/io/writer.py +309 -0
- d9d-0.1.0/d9d/model_state/mapper/__init__.py +10 -0
- d9d-0.1.0/d9d/model_state/mapper/abc.py +70 -0
- d9d-0.1.0/d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d-0.1.0/d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d-0.1.0/d9d/model_state/mapper/adapters/module.py +22 -0
- d9d-0.1.0/d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d-0.1.0/d9d/model_state/mapper/compose/helper.py +22 -0
- d9d-0.1.0/d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d-0.1.0/d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d-0.1.0/d9d/model_state/mapper/compose/shard.py +36 -0
- d9d-0.1.0/d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d-0.1.0/d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d-0.1.0/d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d-0.1.0/d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d-0.1.0/d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d-0.1.0/d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d-0.1.0/d9d/module/__init__.py +0 -0
- d9d-0.1.0/d9d/module/base/__init__.py +7 -0
- d9d-0.1.0/d9d/module/base/late_init.py +10 -0
- d9d-0.1.0/d9d/module/block/__init__.py +0 -0
- d9d-0.1.0/d9d/module/block/attention/__init__.py +7 -0
- d9d-0.1.0/d9d/module/block/attention/grouped_query.py +139 -0
- d9d-0.1.0/d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d-0.1.0/d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d-0.1.0/d9d/module/block/embedding/__init__.py +7 -0
- d9d-0.1.0/d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d-0.1.0/d9d/module/block/ffn/__init__.py +5 -0
- d9d-0.1.0/d9d/module/block/ffn/swiglu.py +60 -0
- d9d-0.1.0/d9d/module/block/head/__init__.py +6 -0
- d9d-0.1.0/d9d/module/block/head/language_modelling.py +87 -0
- d9d-0.1.0/d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d-0.1.0/d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d-0.1.0/d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d-0.1.0/d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d-0.1.0/d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d-0.1.0/d9d/module/block/moe/__init__.py +13 -0
- d9d-0.1.0/d9d/module/block/moe/communications/__init__.py +11 -0
- d9d-0.1.0/d9d/module/block/moe/communications/base.py +58 -0
- d9d-0.1.0/d9d/module/block/moe/communications/deepep.py +300 -0
- d9d-0.1.0/d9d/module/block/moe/communications/naive.py +68 -0
- d9d-0.1.0/d9d/module/block/moe/grouped_experts.py +81 -0
- d9d-0.1.0/d9d/module/block/moe/grouped_linear.py +78 -0
- d9d-0.1.0/d9d/module/block/moe/layer.py +122 -0
- d9d-0.1.0/d9d/module/block/moe/router.py +103 -0
- d9d-0.1.0/d9d/module/block/positional/__init__.py +8 -0
- d9d-0.1.0/d9d/module/block/positional/rope.py +150 -0
- d9d-0.1.0/d9d/module/model/__init__.py +0 -0
- d9d-0.1.0/d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d-0.1.0/d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d-0.1.0/d9d/module/model/qwen3_moe/model.py +373 -0
- d9d-0.1.0/d9d/module/model/qwen3_moe/params.py +69 -0
- d9d-0.1.0/d9d/module/parallelism/__init__.py +0 -0
- d9d-0.1.0/d9d/module/parallelism/api/__init__.py +18 -0
- d9d-0.1.0/d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d-0.1.0/d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d-0.1.0/d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d-0.1.0/d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d-0.1.0/d9d/module/parallelism/model/__init__.py +0 -0
- d9d-0.1.0/d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d-0.1.0/d9d/module/parallelism/style/__init__.py +7 -0
- d9d-0.1.0/d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d-0.1.0/d9d/module/parallelism/style/to_local.py +86 -0
- d9d-0.1.0/d9d/optim/__init__.py +0 -0
- d9d-0.1.0/d9d/optim/stochastic/__init__.py +5 -0
- d9d-0.1.0/d9d/optim/stochastic/adamw.py +158 -0
- d9d-0.1.0/d9d/peft/__init__.py +13 -0
- d9d-0.1.0/d9d/peft/all/__init__.py +12 -0
- d9d-0.1.0/d9d/peft/all/config.py +31 -0
- d9d-0.1.0/d9d/peft/all/method.py +76 -0
- d9d-0.1.0/d9d/peft/applicator.py +47 -0
- d9d-0.1.0/d9d/peft/base.py +70 -0
- d9d-0.1.0/d9d/peft/full_tune/__init__.py +11 -0
- d9d-0.1.0/d9d/peft/full_tune/config.py +20 -0
- d9d-0.1.0/d9d/peft/full_tune/method.py +46 -0
- d9d-0.1.0/d9d/peft/lora/__init__.py +15 -0
- d9d-0.1.0/d9d/peft/lora/config.py +35 -0
- d9d-0.1.0/d9d/peft/lora/layer.py +177 -0
- d9d-0.1.0/d9d/peft/lora/method.py +132 -0
- d9d-0.1.0/d9d/pipelining/__init__.py +0 -0
- d9d-0.1.0/d9d/pipelining/api/__init__.py +19 -0
- d9d-0.1.0/d9d/pipelining/api/module.py +149 -0
- d9d-0.1.0/d9d/pipelining/api/schedule.py +50 -0
- d9d-0.1.0/d9d/pipelining/api/sharding.py +9 -0
- d9d-0.1.0/d9d/pipelining/factory/__init__.py +21 -0
- d9d-0.1.0/d9d/pipelining/factory/config.py +89 -0
- d9d-0.1.0/d9d/pipelining/factory/factory.py +114 -0
- d9d-0.1.0/d9d/pipelining/factory/registry.py +82 -0
- d9d-0.1.0/d9d/pipelining/infra/__init__.py +0 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d-0.1.0/d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d-0.1.0/d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d-0.1.0/d9d/pipelining/infra/stage/communications.py +274 -0
- d9d-0.1.0/d9d/pipelining/infra/stage/computations.py +317 -0
- d9d-0.1.0/d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d-0.1.0/d9d/pipelining/infra/stage/stage.py +321 -0
- d9d-0.1.0/d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d-0.1.0/d9d/pipelining/training/__init__.py +7 -0
- d9d-0.1.0/d9d/pipelining/training/optimizer.py +41 -0
- d9d-0.1.0/d9d/pipelining/training/scheduler.py +34 -0
- d9d-0.1.0/d9d/tracker/__init__.py +14 -0
- d9d-0.1.0/d9d/tracker/base.py +124 -0
- d9d-0.1.0/d9d/tracker/factory.py +57 -0
- d9d-0.1.0/d9d/tracker/provider/__init__.py +0 -0
- d9d-0.1.0/d9d/tracker/provider/aim/__init__.py +0 -0
- d9d-0.1.0/d9d/tracker/provider/aim/config.py +23 -0
- d9d-0.1.0/d9d/tracker/provider/aim/tracker.py +114 -0
- d9d-0.1.0/d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0/pyproject.toml +267 -0
d9d-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,90 @@
+Metadata-Version: 2.4
+Name: d9d
+Version: 0.1.0
+Summary: d9d - d[istribute]d - distributed training framework based on PyTorch that tries to be efficient yet hackable
+License: Apache-2.0
+Author: Maksim Afanasyev
+Author-email: mr.applexz@gmail.com
+Requires-Python: >=3.11,<3.15
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Education
+Classifier: Intended Audience :: Science/Research
+Classifier: Topic :: Scientific/Engineering
+Classifier: Topic :: Scientific/Engineering :: Mathematics
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: Software Development
+Classifier: Topic :: Software Development :: Libraries
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Provides-Extra: aim
+Provides-Extra: cce
+Provides-Extra: moe
+Provides-Extra: visualization
+Requires-Dist: aim (>=3.0.0,<4.0.0) ; extra == "aim"
+Requires-Dist: cut-cross-entropy (>=25.9.3) ; extra == "cce"
+Requires-Dist: deep-ep (>=1.2.1) ; extra == "moe"
+Requires-Dist: nv-grouped-gemm (>=1.1.4) ; extra == "moe"
+Requires-Dist: plotly (>=6.0.0) ; extra == "visualization"
+Requires-Dist: pydantic (>=2.0.0)
+Requires-Dist: safetensors (>=0.7.0)
+Requires-Dist: setuptools (>=70.0.0) ; extra == "aim"
+Requires-Dist: torch (>=2.10.0)
+Requires-Dist: torchdata (>=0.11.0)
+Requires-Dist: tqdm (>=4.0.0)
+Requires-Dist: triton (>=3.6.0)
+Project-URL: Documentation, https://d9d-project.github.io/d9d
+Project-URL: Homepage, https://d9d-project.github.io/d9d
+Project-URL: Issues, https://github.com/d9d-project/d9d/issues
+Project-URL: Repository, https://github.com/d9d-project/d9d
+Description-Content-Type: text/markdown
+
+# The d9d Project
+
+**d9d** is a distributed training framework built on top of PyTorch 2.0. It aims to be hackable, modular, and efficient, designed to scale from single-GPU debugging to massive clusters running 6D-Parallelism.
+
+[LET'S START TRAINING 🚀](https://d9d-project.github.io/d9d/)
+
+## Why another framework?
+
+Distributed training frameworks such as **Megatron-LM** are monolithic: you run a script from the command line to train one of a set of *predefined* models, using *predefined* regimes. While powerful, these systems can be difficult to hack and integrate into novel research workflows. Their focus is often on providing a complete, end-to-end solution, which can limit flexibility for experimentally-driven research.
+
+Conversely, creating your own distributed training solution from scratch is tricky. You have to implement many low-level components (like distributed checkpoints and synchronization) that are identical across setups, and manually tackle common performance bottlenecks.
+
+**d9d** was designed to fill the gap between monolithic frameworks and homebrew setups, providing a modular yet effective solution for distributed training.
+
+## What d9d is and isn't
+
+In terms of **core concept**, d9d:
+
+* **IS** a pluggable framework for implementing distributed training regimes for your deep learning models.
+* **IS** built on clear interfaces and building blocks that may be composed and implemented in your own way.
+* **IS NOT** an all-in-one CLI platform for setting up pre-training and post-training like **torchtitan**, **Megatron-LM**, or **torchforge**.
+
+In terms of **codebase & engineering**, d9d:
+
+* **IS** built on a **strong engineering foundation**: We enforce strict type-checking and rigorous linting to catch errors before execution.
+* **IS** reliable: The framework is backed by a suite of **over 450 tests**, covering unit logic, integration flows, and End-to-End distributed scenarios.
+* **IS** eager to use performance hacks (like **DeepEp** or custom kernels) if they improve MFU, even if they aren't PyTorch-native.
+* **IS NOT** for legacy setups: We do not maintain backward compatibility with older PyTorch versions or hardware. We prioritize simplicity and modern APIs (like `DTensor`).
+
+## Key Philosophies
+
+To achieve the balance between hackability and performance, d9d adheres to specific design principles:
+
+* **Composition over Monoliths**: We avoid "God Classes" like `DistributedDataParallel` or `ParallelDims` that assume ownership of the entire execution loop. Instead, we provide composable and extendable APIs: for instance, specific horizontal parallelism strategies for specific layers (`parallelize_replicate`, `parallelize_expert_parallel`, ...).
+* **White-Box Modelling**: We encourage standard PyTorch code. Models are not wrapped in obscure metadata specifications; they are standard `nn.Module`s that implement lightweight protocols.
+* **Pragmatic Efficiency**: While we prefer native PyTorch, we are eager to integrate non-native solutions if they improve MFU. For example, we implement MoE using **DeepEp** communications, reindexing kernels from **Megatron-LM**, and efficient grouped-GEMM implementations.
+* **Graph-Based State Management**: Our IO system treats model checkpoints as directed acyclic graphs. This allows you to transform architectures (e.g., merging `q`, `k`, `v` into `qkv`) on-the-fly while streaming from disk, without massive memory overhead.
+* **DTensors**: We mandate that distributed parameters be represented as `torch.distributed.tensor.DTensor`. This simplifies checkpointing by making them topology-aware automatically. We leverage modern PyTorch 2.0 APIs (`DeviceMesh`) as much as possible.
+
+---
+
+## Examples
+
+### Qwen3-MoE Pretraining
+An example showing causal LM pretraining for the Qwen3-MoE model.
+
+WIP: MoE load balancing is still in progress.
+
+[Link](https://github.com/d9d-project/d9d/blob/main/example/qwen3_moe/pretrain.py).
+
d9d-0.1.0/README.md
ADDED
@@ -0,0 +1,49 @@
+# The d9d Project
+
+**d9d** is a distributed training framework built on top of PyTorch 2.0. It aims to be hackable, modular, and efficient, designed to scale from single-GPU debugging to massive clusters running 6D-Parallelism.
+
+[LET'S START TRAINING 🚀](https://d9d-project.github.io/d9d/)
+
+## Why another framework?
+
+Distributed training frameworks such as **Megatron-LM** are monolithic: you run a script from the command line to train one of a set of *predefined* models, using *predefined* regimes. While powerful, these systems can be difficult to hack and integrate into novel research workflows. Their focus is often on providing a complete, end-to-end solution, which can limit flexibility for experimentally-driven research.
+
+Conversely, creating your own distributed training solution from scratch is tricky. You have to implement many low-level components (like distributed checkpoints and synchronization) that are identical across setups, and manually tackle common performance bottlenecks.
+
+**d9d** was designed to fill the gap between monolithic frameworks and homebrew setups, providing a modular yet effective solution for distributed training.
+
+## What d9d is and isn't
+
+In terms of **core concept**, d9d:
+
+* **IS** a pluggable framework for implementing distributed training regimes for your deep learning models.
+* **IS** built on clear interfaces and building blocks that may be composed and implemented in your own way.
+* **IS NOT** an all-in-one CLI platform for setting up pre-training and post-training like **torchtitan**, **Megatron-LM**, or **torchforge**.
+
+In terms of **codebase & engineering**, d9d:
+
+* **IS** built on a **strong engineering foundation**: We enforce strict type-checking and rigorous linting to catch errors before execution.
+* **IS** reliable: The framework is backed by a suite of **over 450 tests**, covering unit logic, integration flows, and End-to-End distributed scenarios.
+* **IS** eager to use performance hacks (like **DeepEp** or custom kernels) if they improve MFU, even if they aren't PyTorch-native.
+* **IS NOT** for legacy setups: We do not maintain backward compatibility with older PyTorch versions or hardware. We prioritize simplicity and modern APIs (like `DTensor`).
+
+## Key Philosophies
+
+To achieve the balance between hackability and performance, d9d adheres to specific design principles:
+
+* **Composition over Monoliths**: We avoid "God Classes" like `DistributedDataParallel` or `ParallelDims` that assume ownership of the entire execution loop. Instead, we provide composable and extendable APIs: for instance, specific horizontal parallelism strategies for specific layers (`parallelize_replicate`, `parallelize_expert_parallel`, ...).
+* **White-Box Modelling**: We encourage standard PyTorch code. Models are not wrapped in obscure metadata specifications; they are standard `nn.Module`s that implement lightweight protocols.
+* **Pragmatic Efficiency**: While we prefer native PyTorch, we are eager to integrate non-native solutions if they improve MFU. For example, we implement MoE using **DeepEp** communications, reindexing kernels from **Megatron-LM**, and efficient grouped-GEMM implementations.
+* **Graph-Based State Management**: Our IO system treats model checkpoints as directed acyclic graphs. This allows you to transform architectures (e.g., merging `q`, `k`, `v` into `qkv`) on-the-fly while streaming from disk, without massive memory overhead.
+* **DTensors**: We mandate that distributed parameters be represented as `torch.distributed.tensor.DTensor`. This simplifies checkpointing by making them topology-aware automatically. We leverage modern PyTorch 2.0 APIs (`DeviceMesh`) as much as possible.
+
+---
+
+## Examples
+
+### Qwen3-MoE Pretraining
+An example showing causal LM pretraining for the Qwen3-MoE model.
+
+WIP: MoE load balancing is still in progress.
+
+[Link](https://github.com/d9d-project/d9d/blob/main/example/qwen3_moe/pretrain.py).
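
The `DTensor` philosophy stated in the README above is easiest to see with a small, self-contained sketch. This is illustrative only: it uses stock `torch.distributed.tensor` APIs rather than any d9d code, and the mesh shape and placements are assumptions, not values taken from the package.

```python
# Illustrative sketch (not part of the package): how a DTensor parameter stays
# "topology-aware". Stock PyTorch APIs only; launch with `torchrun --nproc-per-node=4`.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


def main() -> None:
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # A 1-D mesh over 4 GPUs, analogous to a single "dp_shard" dimension.
    mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("dp_shard",))

    # Shard a weight matrix row-wise across the mesh. Each rank holds a
    # 256 x 1024 local shard, yet the DTensor still reports the global
    # 1024 x 1024 shape together with its placements.
    weight = torch.randn(1024, 1024, device="cuda")
    dweight = distribute_tensor(weight, mesh, placements=[Shard(0)])

    print(dweight.shape, dweight.to_local().shape, dweight.placements)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Because the sharded parameter carries its global shape and placements with it, a checkpoint writer can save or reshard it without any out-of-band knowledge of the cluster topology, which is the property the README describes as checkpoints being "topology-aware automatically".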
d9d-0.1.0/d9d/__init__.py
File without changes
d9d-0.1.0/d9d/core/__init__.py
File without changes
d9d-0.1.0/d9d/core/autograd/grad_context.py
ADDED
@@ -0,0 +1,85 @@
+from contextlib import contextmanager
+from enum import StrEnum
+
+
+class GradDirection(StrEnum):
+    """
+    Enum representing the specific gradient edges to compute.
+
+    This is used to manually control gradient flow in custom autograd functions
+    during split backward passes.
+
+    Attributes:
+        inputs: Mark gradient edge as pointing to the module's inputs (activations).
+        weight: Mark gradient edge as pointing to the module's parameters (weights).
+    """
+
+    inputs = "inputs"
+    weight = "weights"
+
+
+class GlobalGradContext:
+    """
+    Global state manager for controlling gradient computation in custom autograd functions.
+
+    This context addresses a limitation in PyTorch where custom `torch.autograd.Function`
+    implementations set `ctx.needs_input_grad` to True for all edges requiring grad,
+    even during partial backward passes (e.g., `torch.autograd.backward(inputs=...)`).
+
+    For additional information on this limitation, please refer to a
+    [related issue](https://github.com/pytorch/pytorch/issues/174017).
+
+    This class allows:
+
+    1. For the training code - to explicitly signal which gradient edges (inputs vs weights)
+       should currently be computed, allowing custom ops to skip unnecessary computations.
+    2. For module code - to check whether it's required to compute a gradient edge.
+    """
+
+    def __init__(self):
+        """Constructs a GlobalGradContext object with all directions enabled by default."""
+
+        # both directions by default
+        self._enabled_directions: set[GradDirection] = {GradDirection.inputs, GradDirection.weight}
+
+    def check_direction(self, direction: GradDirection | None) -> bool:
+        """
+        Checks if the gradient calculation for the given direction is currently enabled.
+
+        Args:
+            direction: The direction to check (inputs or weights). If None,
+                returns True.
+
+        Returns:
+            True if the direction is enabled or None is passed, False otherwise.
+        """
+
+        if direction is None:
+            return True
+
+        return direction in self._enabled_directions
+
+    @contextmanager
+    def with_directions(self, *directions: GradDirection):
+        """
+        Context manager that sets the enabled gradient directions.
+
+        This overrides the current state for the duration of the context
+        and restores the previous state afterwards.
+
+        Args:
+            *directions: The gradient directions to enable.
+        """
+        prev_directions = self._enabled_directions
+        self._enabled_directions = set(directions)
+        yield
+        self._enabled_directions = prev_directions
+
+
+GLOBAL_GRAD_CONTEXT = GlobalGradContext()
+"""
+The singleton instance of GlobalGradContext.
+
+This should be used by custom autograd functions to check `GLOBAL_GRAD_CONTEXT.check_direction()`
+during their backward pass.
+"""
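
A minimal usage sketch for the API defined in `grad_context.py` above. The `SplitMatmul` function and the two-phase backward below are hypothetical illustrations; only `GradDirection`, `check_direction`, `with_directions`, and the module path `d9d.core.autograd.grad_context` come from the file itself.

```python
# Hypothetical usage sketch (not from the package): a custom matmul autograd
# function that consults GLOBAL_GRAD_CONTEXT to skip gradient edges during a
# split backward pass.
import torch

from d9d.core.autograd.grad_context import GLOBAL_GRAD_CONTEXT, GradDirection


class SplitMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x, weight)
        return x @ weight

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        x, weight = ctx.saved_tensors
        grad_x = grad_w = None
        # ctx.needs_input_grad alone is not reliable during partial backward passes,
        # so also ask the global context which edges are currently enabled.
        if ctx.needs_input_grad[0] and GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.inputs):
            grad_x = grad_out @ weight.t()
        if ctx.needs_input_grad[1] and GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.weight):
            grad_w = x.t() @ grad_out
        return grad_x, grad_w


x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 32, requires_grad=True)
loss = SplitMatmul.apply(x, w).sum()

# Phase 1: propagate only activation gradients (e.g. to unblock an upstream
# pipeline stage as early as possible) ...
with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.inputs):
    torch.autograd.backward(loss, inputs=[x], retain_graph=True)

# Phase 2: ... then compute the weight gradients separately.
with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.weight):
    torch.autograd.backward(loss, inputs=[w])
```

This input/weight split is presumably what the zero-bubble and DualPipe-style schedules listed under `d9d/pipelining/infra/schedule/program/` build on, but that is an inference from the file listing, not from code shown in this excerpt.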
d9d-0.1.0/d9d/core/dist_context/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""
+This package configures the distributed environment and device meshes.
+"""
+
+from .configured import DistributedContext
+from .device_mesh_domains import BATCH_DOMAIN, DENSE_DOMAIN, EXPERT_DOMAIN, FLAT_DOMAIN, REGULAR_DOMAIN
+from .log import build_dist_logger
+from .params import DeviceMeshParameters
+
+__all__ = [
+    "BATCH_DOMAIN",
+    "DENSE_DOMAIN",
+    "EXPERT_DOMAIN",
+    "FLAT_DOMAIN",
+    "REGULAR_DOMAIN",
+    "DeviceMeshParameters",
+    "DistributedContext",
+    "build_dist_logger"
+]
d9d-0.1.0/d9d/core/dist_context/configured.py
ADDED
@@ -0,0 +1,215 @@
+import datetime
+import logging
+import os
+import socket
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+import torch
+from torch.distributed import DeviceMesh
+
+from .device_mesh_domains import ALL_DOMAIN_PROVIDERS, REGULAR_DOMAIN
+from .log import build_dist_logger
+
+if TYPE_CHECKING:
+    from .params import DeviceMeshParameters
+
+
+def _resolve_master_addr() -> str:
+    if "MASTER_ADDR" not in os.environ:
+        return "127.0.0.1"
+
+    master_addr = os.environ["MASTER_ADDR"]
+
+    try:
+        return socket.gethostbyname(master_addr)
+    except OSError:
+        return master_addr
+
+
+def _build_mesh_domains(params: "DeviceMeshParameters") -> dict[str, DeviceMesh]:
+    return {
+        provider.name: provider.build_mesh(params)
+        for provider in ALL_DOMAIN_PROVIDERS
+    }
+
+
+class DistributedContext:
+    """
+    Acts as the single source of truth for the distributed execution environment.
+
+    It acts as the central repository for the distributed configuration, managing the creation
+    and synchronization of PyTorch DeviceMeshes for different domains (Regular domain, Expert Parallel domain, ...).
+
+    All assertions regarding rank placement, group memberships, and parallel topology
+    must be derived from this context to ensure consistency.
+    """
+
+    def __init__(self, params: "DeviceMeshParameters", log_level: int):
+        self._params = params
+
+        if params.is_distributed:
+            meshes = _build_mesh_domains(params)
+            regular_mesh = meshes[REGULAR_DOMAIN]
+
+            self._meshes = meshes
+            self._num_nodes = regular_mesh.size() // torch.cuda.device_count()
+            self._logger = build_dist_logger(
+                f'pp:{regular_mesh.get_local_rank("pp")}-'
+                f'dpr:{regular_mesh.get_local_rank("dp_replicate")}-'
+                f'dps:{regular_mesh.get_local_rank("dp_shard")}-'
+                f'cps:{regular_mesh.get_local_rank("cp_shard")}-'
+                f'cpr:{regular_mesh.get_local_rank("cp_replicate")}-'
+                f'tp:{regular_mesh.get_local_rank("tp")}',
+                level=log_level
+            )
+        else:
+            self._meshes = {}
+            self._num_nodes = 1
+            self._logger = build_dist_logger("local", level=log_level)
+
+        self._local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        self._global_rank = int(os.environ.get("RANK", "0"))
+
+        self._node_rank = self._global_rank // torch.cuda.device_count()
+
+        self._master_addr = _resolve_master_addr()
+        self._current_device = torch.device("cuda")
+
+        torch.cuda.set_device(self._local_rank)
+
+    @property
+    def logger(self) -> logging.Logger:
+        """Returns the logger instance configured for distributed logging."""
+
+        return self._logger
+
+    def mesh_for(self, domain: str) -> DeviceMesh:
+        """
+        Returns the device mesh view associated with a specific logical domain.
+
+        Available Domains and Dimensions:
+            * `regular` (`REGULAR_DOMAIN`): The most granular mesh for fully decomposed parallelism.
+              Dimensions: ``('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp')``
+            * `expert` (`EXPERT_DOMAIN`): Mesh optimized for distributing MoE (Mixture of Experts) layers.
+              Dimensions: ``('pp', 'replicate', 'ep')``
+            * `dense` (`DENSE_DOMAIN`): Mesh optimized for distributing dense layers.
+              Dimensions: ``('pp', 'dp_replicate', 'dp_cp_shard', 'cp_replicate', 'tp')``
+            * `batch` (`BATCH_DOMAIN`): Mesh optimized for distributing input data.
+              Dimensions: ``('pp', 'dp', 'cp', 'tp')``
+            * `flat` (`FLAT_DOMAIN`): Mesh containing a single dimension with all the processes.
+              Dimensions: ``('world')``
+
+        Args:
+            domain: The name of the domain to retrieve.
+
+        Returns:
+            The PyTorch DeviceMesh configured for the requested domain.
+
+        Raises:
+            ValueError: If the specified domain does not exist.
+        """
+
+        if domain not in self._meshes:
+            raise ValueError(f"Domain {domain} does not exist")
+        return self._meshes[domain]
+
+    @property
+    def is_main_process(self) -> bool:
+        """Checks if the current process is the global rank 0."""
+
+        return self._global_rank == 0
+
+    @property
+    def is_local_main_process(self) -> bool:
+        """Checks if the current process is the rank 0 on the specific node."""
+
+        return self._local_rank == 0
+
+    def wait_world(self):
+        """Blocks process execution until all ranks reach this point."""
+
+        torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
+        torch.cuda.synchronize()
+
+    def set_timeout(self, timeout_seconds: float):
+        """
+        Updates the NCCL/process group timeout for all underlying meshes.
+
+        Args:
+            timeout_seconds: New timeout duration in seconds.
+        """
+
+        self.logger.info(f"Setting global timeout to {timeout_seconds} seconds")
+        self.wait_world()
+
+        groups: list[torch.distributed.ProcessGroup | None] = [None]
+        for mesh in self._meshes.values():
+            for dim in range(mesh.ndim):
+                groups.append(mesh.get_group(dim))
+
+        for group in groups:
+            torch.distributed.distributed_c10d._set_pg_timeout(datetime.timedelta(seconds=timeout_seconds), group)  # noqa: SLF001
+
+    @contextmanager
+    def local_main_process_first(self):
+        """
+        Context manager that executes the block on the local main process first.
+
+        Other local ranks wait at the entrance. The local main process waits at the
+        exit to synchronize before continuing.
+        """
+        if not self.is_local_main_process:
+            self.wait_world()
+
+        yield
+
+        if self.is_local_main_process:
+            self.wait_world()
+
+    @contextmanager
+    def main_process_first(self):
+        """
+        Context manager that executes the block on the global main process first.
+
+        All other ranks wait at the entrance. The global main process waits at the
+        exit to synchronize before continuing.
+        """
+
+        if not self.is_main_process:
+            self.wait_world()
+
+        yield
+
+        if self.is_main_process:
+            self.wait_world()
+
+    @property
+    def current_device(self) -> torch.device:
+        """Returns the CUDA device associated with this rank."""
+
+        return self._current_device
+
+    @property
+    def mesh_params(self) -> "DeviceMeshParameters":
+        """Returns the parameters used to initialize this context."""
+
+        return self._params
+
+    @property
+    def master_addr(self) -> str:
+        """Returns the IP address or domain name of the master node."""
+
+        return self._master_addr
+
+    @property
+    def node_rank(self) -> int:
+        """Returns the index of the node this process is running on."""
+
+        return self._node_rank
+
+    @property
+    def num_nodes(self) -> int:
+        """Returns the total number of nodes in the cluster."""
+
+        return self._num_nodes
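
A usage sketch for the `DistributedContext` API above. The `DeviceMeshParameters` constructor arguments are assumed from the attribute names referenced in this file (`params.py` itself is not shown in this excerpt), and `prepare_dataset_cache` is a hypothetical helper; the rest follows the methods defined above.

```python
# Hypothetical usage sketch (not from the package). Assumes launch via torchrun
# on 8 GPUs arranged as 2-way pipeline x 2-way data-parallel sharding x 2-way TP.
import logging

from d9d.core.dist_context import DENSE_DOMAIN, DeviceMeshParameters, DistributedContext

params = DeviceMeshParameters(  # field names assumed from usage in configured.py
    pipeline_parallel=2,
    data_parallel_replicate=1,
    data_parallel_shard=2,
    context_parallel_shard=1,
    context_parallel_replicate=1,
    tensor_parallel=2,
    expert_parallel=1,
)

ctx = DistributedContext(params, log_level=logging.INFO)
ctx.logger.info("node %d of %d", ctx.node_rank, ctx.num_nodes)

# Fetch the mesh view used for dense layers and inspect this rank's coordinate.
dense_mesh = ctx.mesh_for(DENSE_DOMAIN)
ctx.logger.info("tp rank: %d", dense_mesh.get_local_rank("tp"))

# Download / preprocess data once per node while the other local ranks wait.
with ctx.local_main_process_first():
    prepare_dataset_cache()  # hypothetical helper
```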
d9d-0.1.0/d9d/core/dist_context/device_mesh_domains.py
ADDED
@@ -0,0 +1,185 @@
+import abc
+from typing import TYPE_CHECKING
+
+from torch.distributed import DeviceMesh, init_device_mesh
+
+if TYPE_CHECKING:
+    from .params import DeviceMeshParameters
+
+
+class DeviceMeshDomain(abc.ABC):
+    """
+    Abstract base class for a Device Mesh provider.
+
+    A Domain defines a specific strategy for organizing available GPUs into a
+    multidimensional grid (Mesh) to support specific parallelism techniques.
+    """
+
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        """Returns the unique identifier for this mesh domain."""
+
+        ...
+
+    @abc.abstractmethod
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        """
+        Constructs the device mesh configuration.
+
+        Args:
+            params: Global configuration parameters for the distributed environment.
+
+        Returns:
+            The initialized PyTorch DeviceMesh for this specific domain.
+        """
+
+        ...
+
+
+REGULAR_DOMAIN = "regular"
+
+
+class RegularDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return "regular"
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate,
+                params.data_parallel_shard,
+                params.context_parallel_shard,
+                params.context_parallel_replicate,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp_replicate",
+                "dp_shard",
+                "cp_shard",
+                "cp_replicate",
+                "tp"
+            )
+        )
+
+
+EXPERT_DOMAIN = "expert"
+
+
+class ExpertDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return EXPERT_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        replicate_degree = (
+            params.data_parallel_replicate *
+            params.context_parallel_replicate *
+            params.data_parallel_shard *
+            params.context_parallel_shard
+        )
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                replicate_degree // params.expert_parallel,
+                params.expert_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "ep_replicate",
+                "ep_shard"
+            )
+        )
+
+
+DENSE_DOMAIN = "dense"
+
+
+class DenseDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return DENSE_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate,
+                params.data_parallel_shard * params.context_parallel_shard,
+                params.context_parallel_replicate,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp_replicate",
+                "dp_cp_shard",
+                "cp_replicate",
+                "tp"
+            )
+        )
+
+
+BATCH_DOMAIN = "batch"
+
+
+class BatchDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return BATCH_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate * params.data_parallel_shard,
+                params.context_parallel_replicate * params.context_parallel_shard,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp",
+                "cp",
+                "tp"
+            )
+        )
+
+
+FLAT_DOMAIN = "flat"
+
+
+class FlatDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return FLAT_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        mesh_shape = (
+            params.pipeline_parallel *
+            params.data_parallel_replicate *
+            params.data_parallel_shard *
+            params.context_parallel_replicate *
+            params.context_parallel_shard *
+            params.tensor_parallel
+        )
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                mesh_shape,
+            ),
+            mesh_dim_names=(
+                "world",
+            )
+        )
+
+
+ALL_DOMAIN_PROVIDERS: list[DeviceMeshDomain] = [
+    RegularDomain(), DenseDomain(), ExpertDomain(), BatchDomain(),
+    FlatDomain()
+]
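
A short worked example of how the five domains above reshape the same set of ranks. The configuration values are assumptions chosen for illustration, not defaults from the package; `tensor_parallel` is set to 1 here because the expert domain's replicate degree is built only from the dp/cp factors in the code above.

```python
# Worked example (assumed configuration, not from the package): mesh shapes
# produced by each domain for one 16-rank world.
pp, dp_rep, dp_shard, cp_shard, cp_rep, tp = 2, 1, 4, 2, 1, 1   # 2*1*4*2*1*1 = 16 ranks
ep = 4

regular = (pp, dp_rep, dp_shard, cp_shard, cp_rep, tp)        # (2, 1, 4, 2, 1, 1)
dense = (pp, dp_rep, dp_shard * cp_shard, cp_rep, tp)         # (2, 1, 8, 1, 1)
batch = (pp, dp_rep * dp_shard, cp_rep * cp_shard, tp)        # (2, 4, 2, 1)

replicate_degree = dp_rep * cp_rep * dp_shard * cp_shard      # 8
expert = (pp, replicate_degree // ep, ep)                     # (2, 2, 4)

flat = (pp * dp_rep * dp_shard * cp_shard * cp_rep * tp,)     # (16,)

for name, shape in [("regular", regular), ("dense", dense), ("batch", batch),
                    ("expert", expert), ("flat", flat)]:
    total = 1
    for dim in shape:
        total *= dim
    print(f"{name:8s} shape={shape} covers {total} ranks")
```

All five products come out to 16 in this configuration, i.e. each domain is just a different factorization of the same world, which is what lets `DistributedContext` hand out per-domain views of one cluster.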