d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/__init__.py
ADDED
File without changes

d9d/core/__init__.py
ADDED
File without changes
d9d/core/autograd/grad_context.py
ADDED
@@ -0,0 +1,85 @@
+from contextlib import contextmanager
+from enum import StrEnum
+
+
+class GradDirection(StrEnum):
+    """
+    Enum representing the specific gradient edges to compute.
+
+    This is used to manually control gradient flow in custom autograd functions
+    during split backward passes.
+
+    Attributes:
+        inputs: Mark gradient edge as pointing to the module's inputs (activations).
+        weight: Mark gradient edge as pointing to the module's parameters (weights).
+    """
+
+    inputs = "inputs"
+    weight = "weights"
+
+
+class GlobalGradContext:
+    """
+    Global state manager for controlling gradient computation in custom autograd functions.
+
+    This context addresses a limitation in PyTorch where custom `torch.autograd.Function`
+    implementations set `ctx.needs_input_grad` to True for all edges requiring grad,
+    even during partial backward passes (e.g., `torch.autograd.backward(inputs=...)`).
+
+    For additional information on this limitation, please refer to a
+    [related issue](https://github.com/pytorch/pytorch/issues/174017).
+
+    This class allows:
+
+    1. For the training code - to explicitly signal which gradient edges (inputs vs weights)
+       should currently be computed, allowing custom ops to skip unnecessary computations.
+    2. For module code - to check whether it's required to compute a gradient edge.
+    """
+
+    def __init__(self):
+        """Constructs a GlobalGradContext object with all directions enabled by default."""
+
+        # both directions by default
+        self._enabled_directions: set[GradDirection] = {GradDirection.inputs, GradDirection.weight}
+
+    def check_direction(self, direction: GradDirection | None) -> bool:
+        """
+        Checks if the gradient calculation for the given direction is currently enabled.
+
+        Args:
+            direction: The direction to check (inputs or weights). If None,
+                returns True.
+
+        Returns:
+            True if the direction is enabled or None is passed, False otherwise.
+        """
+
+        if direction is None:
+            return True
+
+        return direction in self._enabled_directions
+
+    @contextmanager
+    def with_directions(self, *directions: GradDirection):
+        """
+        Context manager that sets the enabled gradient directions.
+
+        This overrides the current state for the duration of the context
+        and restores the previous state afterwards.
+
+        Args:
+            *directions: The gradient directions to enable.
+        """
+        prev_directions = self._enabled_directions
+        self._enabled_directions = set(directions)
+        yield
+        self._enabled_directions = prev_directions
+
+
+GLOBAL_GRAD_CONTEXT = GlobalGradContext()
+"""
+The singleton instance of GlobalGradContext.
+
+This should be used by custom autograd functions to check `GLOBAL_GRAD_CONTEXT.check_direction()`
+during their backward pass.
+"""
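Note (added for review): a minimal sketch of how the context above is intended to be consumed, based solely on the docstrings in grad_context.py. The `_SketchMatmul` op and the commented partial-backward call are hypothetical illustrations, not part of the package.

import torch

from d9d.core.autograd.grad_context import GLOBAL_GRAD_CONTEXT, GradDirection


class _SketchMatmul(torch.autograd.Function):
    """Hypothetical custom op used only to illustrate GLOBAL_GRAD_CONTEXT."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_w = None
        # Consult the global context rather than ctx.needs_input_grad, which
        # stays True for every edge even during a partial backward pass.
        if GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.inputs):
            grad_x = grad_out @ weight
        if GLOBAL_GRAD_CONTEXT.check_direction(GradDirection.weight):
            grad_w = grad_out.t() @ x
        return grad_x, grad_w


# Training code would scope a split backward roughly like this (illustrative):
# with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.inputs):
#     torch.autograd.backward(loss, inputs=[activation_tensor])
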
d9d/core/dist_context/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""
+This package configures the distributed environment and device meshes.
+"""
+
+from .configured import DistributedContext
+from .device_mesh_domains import BATCH_DOMAIN, DENSE_DOMAIN, EXPERT_DOMAIN, FLAT_DOMAIN, REGULAR_DOMAIN
+from .log import build_dist_logger
+from .params import DeviceMeshParameters
+
+__all__ = [
+    "BATCH_DOMAIN",
+    "DENSE_DOMAIN",
+    "EXPERT_DOMAIN",
+    "FLAT_DOMAIN",
+    "REGULAR_DOMAIN",
+    "DeviceMeshParameters",
+    "DistributedContext",
+    "build_dist_logger"
+]
d9d/core/dist_context/configured.py
ADDED
@@ -0,0 +1,215 @@
+import datetime
+import logging
+import os
+import socket
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+import torch
+from torch.distributed import DeviceMesh
+
+from .device_mesh_domains import ALL_DOMAIN_PROVIDERS, REGULAR_DOMAIN
+from .log import build_dist_logger
+
+if TYPE_CHECKING:
+    from .params import DeviceMeshParameters
+
+
+def _resolve_master_addr() -> str:
+    if "MASTER_ADDR" not in os.environ:
+        return "127.0.0.1"
+
+    master_addr = os.environ["MASTER_ADDR"]
+
+    try:
+        return socket.gethostbyname(master_addr)
+    except OSError:
+        return master_addr
+
+
+def _build_mesh_domains(params: "DeviceMeshParameters") -> dict[str, DeviceMesh]:
+    return {
+        provider.name: provider.build_mesh(params)
+        for provider in ALL_DOMAIN_PROVIDERS
+    }
+
+
+class DistributedContext:
+    """
+    Acts as the single source of truth for the distributed execution environment.
+
+    It acts as the central repository for the distributed configuration, managing the creation
+    and synchronization of PyTorch DeviceMeshes for different domains (Regular domain, Expert Parallel domain, ...).
+
+    All assertions regarding rank placement, group memberships, and parallel topology
+    must be derived from this context to ensure consistency.
+    """
+
+    def __init__(self, params: "DeviceMeshParameters", log_level: int):
+        self._params = params
+
+        if params.is_distributed:
+            meshes = _build_mesh_domains(params)
+            regular_mesh = meshes[REGULAR_DOMAIN]
+
+            self._meshes = meshes
+            self._num_nodes = regular_mesh.size() // torch.cuda.device_count()
+            self._logger = build_dist_logger(
+                f'pp:{regular_mesh.get_local_rank("pp")}-'
+                f'dpr:{regular_mesh.get_local_rank("dp_replicate")}-'
+                f'dps:{regular_mesh.get_local_rank("dp_shard")}-'
+                f'cps:{regular_mesh.get_local_rank("cp_shard")}-'
+                f'cpr:{regular_mesh.get_local_rank("cp_replicate")}-'
+                f'tp:{regular_mesh.get_local_rank("tp")}',
+                level=log_level
+            )
+        else:
+            self._meshes = {}
+            self._num_nodes = 1
+            self._logger = build_dist_logger("local", level=log_level)
+
+        self._local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        self._global_rank = int(os.environ.get("RANK", "0"))
+
+        self._node_rank = self._global_rank // torch.cuda.device_count()
+
+        self._master_addr = _resolve_master_addr()
+        self._current_device = torch.device("cuda")
+
+        torch.cuda.set_device(self._local_rank)
+
+    @property
+    def logger(self) -> logging.Logger:
+        """Returns the logger instance configured for distributed logging."""
+
+        return self._logger
+
+    def mesh_for(self, domain: str) -> DeviceMesh:
+        """
+        Returns the device mesh view associated with a specific logical domain.
+
+        Available Domains and Dimensions:
+        * `regular` (`REGULAR_DOMAIN`): The most granular mesh for fully decomposed parallelism.
+          Dimensions: ``('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp')``
+        * `expert` (`EXPERT_DOMAIN`): Mesh optimized for distributing MoE (Mixture of Experts) layers.
+          Dimensions: ``('pp', 'replicate', 'ep')``
+        * `dense` (`DENSE_DOMAIN`): Mesh optimized for distributing dense layers.
+          Dimensions: ``('pp', 'dp_replicate', 'dp_cp_shard', 'cp_replicate', 'tp')``
+        * `batch` (`BATCH_DOMAIN`): Mesh optimized for distributing input data.
+          Dimensions: ``('pp', 'dp', 'cp', 'tp')``
+        * `flat` (`FLAT_DOMAIN`): Mesh containing a single dimension with all the processes.
+          Dimensions: ``('world')``
+
+        Args:
+            domain: The name of the domain to retrieve.
+
+        Returns:
+            The PyTorch DeviceMesh configured for the requested domain.
+
+        Raises:
+            ValueError: If the specified domain does not exist.
+        """
+
+        if domain not in self._meshes:
+            raise ValueError(f"Domain {domain} does not exist")
+        return self._meshes[domain]
+
+    @property
+    def is_main_process(self) -> bool:
+        """Checks if the current process is the global rank 0."""
+
+        return self._global_rank == 0
+
+    @property
+    def is_local_main_process(self) -> bool:
+        """Checks if the current process is the rank 0 on the specific node."""
+
+        return self._local_rank == 0
+
+    def wait_world(self):
+        """Blocks process execution until all ranks reach this point."""
+
+        torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
+        torch.cuda.synchronize()
+
+    def set_timeout(self, timeout_seconds: float):
+        """
+        Updates the NCCL/process group timeout for all underlying meshes.
+
+        Args:
+            timeout_seconds: New timeout duration in seconds.
+        """
+
+        self.logger.info(f"Setting global timeout to {timeout_seconds} seconds")
+        self.wait_world()
+
+        groups: list[torch.distributed.ProcessGroup | None] = [None]
+        for mesh in self._meshes.values():
+            for dim in range(mesh.ndim):
+                groups.append(mesh.get_group(dim))
+
+        for group in groups:
+            torch.distributed.distributed_c10d._set_pg_timeout(datetime.timedelta(seconds=timeout_seconds), group)  # noqa: SLF001
+
+    @contextmanager
+    def local_main_process_first(self):
+        """
+        Context manager that executes the block on the local main process first.
+
+        Other local ranks wait at the entrance. The local main process waits at the
+        exit to synchronize before continuing.
+        """
+        if not self.is_local_main_process:
+            self.wait_world()
+
+        yield
+
+        if self.is_local_main_process:
+            self.wait_world()
+
+    @contextmanager
+    def main_process_first(self):
+        """
+        Context manager that executes the block on the global main process first.
+
+        All other ranks wait at the entrance. The global main process waits at the
+        exit to synchronize before continuing.
+        """
+
+        if not self.is_main_process:
+            self.wait_world()
+
+        yield
+
+        if self.is_main_process:
+            self.wait_world()
+
+    @property
+    def current_device(self) -> torch.device:
+        """Returns the CUDA device associated with this rank."""
+
+        return self._current_device
+
+    @property
+    def mesh_params(self) -> "DeviceMeshParameters":
+        """Returns the parameters used to initialize this context."""
+
+        return self._params
+
+    @property
+    def master_addr(self) -> str:
+        """Returns the IP address or domain name of the master node."""
+
+        return self._master_addr
+
+    @property
+    def node_rank(self) -> int:
+        """Returns the index of the node this process is running on."""
+
+        return self._node_rank
+
+    @property
+    def num_nodes(self) -> int:
+        """Returns the total number of nodes in the cluster."""
+
+        return self._num_nodes
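Note (added for review): a hedged sketch of driving DistributedContext, assuming a torchrun-style launch where the world size matches the mesh product; the parallelism degrees below are illustrative only.

import logging

from d9d.core.dist_context import DeviceMeshParameters, REGULAR_DOMAIN

# Illustrative 8-GPU layout: 2-way pipeline x 2-way FSDP x 2-way tensor parallel.
params = DeviceMeshParameters(pipeline_parallel=2, data_parallel_shard=2, tensor_parallel=2)
ctx = params.build(log_level=logging.INFO)

regular = ctx.mesh_for(REGULAR_DOMAIN)  # dims: pp, dp_replicate, dp_shard, cp_shard, cp_replicate, tp
ctx.logger.info("node %d/%d, tp rank %d", ctx.node_rank, ctx.num_nodes, regular.get_local_rank("tp"))

with ctx.main_process_first():
    pass  # e.g. materialize a dataset cache once before the other ranks proceed
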
d9d/core/dist_context/device_mesh_domains.py
ADDED
@@ -0,0 +1,185 @@
+import abc
+from typing import TYPE_CHECKING
+
+from torch.distributed import DeviceMesh, init_device_mesh
+
+if TYPE_CHECKING:
+    from .params import DeviceMeshParameters
+
+
+class DeviceMeshDomain(abc.ABC):
+    """
+    Abstract base class for a Device Mesh provider.
+
+    A Domain defines a specific strategy for organizing available GPUs into a
+    multidimensional grid (Mesh) to support specific parallelism techniques.
+    """
+
+    @property
+    @abc.abstractmethod
+    def name(self) -> str:
+        """Returns the unique identifier for this mesh domain."""
+
+        ...
+
+    @abc.abstractmethod
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        """
+        Constructs the device mesh configuration.
+
+        Args:
+            params: Global configuration parameters for the distributed environment.
+
+        Returns:
+            The initialized PyTorch DeviceMesh for this specific domain.
+        """
+
+        ...
+
+
+REGULAR_DOMAIN = "regular"
+
+
+class RegularDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return "regular"
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate,
+                params.data_parallel_shard,
+                params.context_parallel_shard,
+                params.context_parallel_replicate,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp_replicate",
+                "dp_shard",
+                "cp_shard",
+                "cp_replicate",
+                "tp"
+            )
+        )
+
+
+EXPERT_DOMAIN = "expert"
+
+
+class ExpertDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return EXPERT_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        replicate_degree = (
+            params.data_parallel_replicate *
+            params.context_parallel_replicate *
+            params.data_parallel_shard *
+            params.context_parallel_shard
+        )
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                replicate_degree // params.expert_parallel,
+                params.expert_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "ep_replicate",
+                "ep_shard"
+            )
+        )
+
+
+DENSE_DOMAIN = "dense"
+
+
+class DenseDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return DENSE_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate,
+                params.data_parallel_shard * params.context_parallel_shard,
+                params.context_parallel_replicate,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp_replicate",
+                "dp_cp_shard",
+                "cp_replicate",
+                "tp"
+            )
+        )
+
+
+BATCH_DOMAIN = "batch"
+
+
+class BatchDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return BATCH_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                params.pipeline_parallel,
+                params.data_parallel_replicate * params.data_parallel_shard,
+                params.context_parallel_replicate * params.context_parallel_shard,
+                params.tensor_parallel
+            ),
+            mesh_dim_names=(
+                "pp",
+                "dp",
+                "cp",
+                "tp"
+            )
+        )
+
+
+FLAT_DOMAIN = "flat"
+
+
+class FlatDomain(DeviceMeshDomain):
+    @property
+    def name(self) -> str:
+        return FLAT_DOMAIN
+
+    def build_mesh(self, params: "DeviceMeshParameters") -> DeviceMesh:
+        mesh_shape = (
+            params.pipeline_parallel *
+            params.data_parallel_replicate *
+            params.data_parallel_shard *
+            params.context_parallel_replicate *
+            params.context_parallel_shard *
+            params.tensor_parallel
+        )
+        return init_device_mesh(
+            device_type="cuda",
+            mesh_shape=(
+                mesh_shape,
+            ),
+            mesh_dim_names=(
+                "world",
+            )
+        )
+
+
+ALL_DOMAIN_PROVIDERS: list[DeviceMeshDomain] = [
+    RegularDomain(), DenseDomain(), ExpertDomain(), BatchDomain(),
+    FlatDomain()
+]
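Note (added for review): to make the decomposition concrete, here are the shapes each provider above would build for one illustrative configuration, worked out by hand from the constructors rather than taken from package documentation.

from d9d.core.dist_context import DeviceMeshParameters

# Illustrative: 8 GPUs, pp=2, dp_shard=4, ep=4, all other degrees left at 1.
p = DeviceMeshParameters(pipeline_parallel=2, data_parallel_shard=4, expert_parallel=4)

# regular: ('pp', 'dp_replicate', 'dp_shard', 'cp_shard', 'cp_replicate', 'tp') -> (2, 1, 4, 1, 1, 1)
# dense:   ('pp', 'dp_replicate', 'dp_cp_shard', 'cp_replicate', 'tp')          -> (2, 1, 4, 1, 1)
# expert:  ('pp', 'ep_replicate', 'ep_shard')                                   -> (2, 1, 4)
# batch:   ('pp', 'dp', 'cp', 'tp')                                             -> (2, 4, 1, 1)
# flat:    ('world',)                                                           -> (8,)
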
d9d/core/dist_context/log.py
ADDED
@@ -0,0 +1,30 @@
+import logging
+import sys
+
+
+def build_dist_logger(qualifier: str, level: int) -> logging.Logger:
+    """
+    Configures and returns a logger instance for d9d.
+
+    The logger is configured to write to stdout with a formatter that includes
+    the provided rank qualifier, allowing for easier debugging in distributed logs.
+
+    Args:
+        qualifier: A string identifying the current rank's position in the mesh.
+        level: Log level to set by default
+
+    Returns:
+        A configured logging.Logger instance.
+    """
+
+    dist_logger = logging.getLogger("d9d")
+    dist_logger.setLevel(level)
+    dist_logger.handlers.clear()
+    ch = logging.StreamHandler(sys.stdout)
+    ch.setLevel(level)
+    formatter = logging.Formatter(
+        f"[d9d] [{qualifier}] %(asctime)s - %(levelname)s - %(message)s"
+    )
+    ch.setFormatter(formatter)
+    dist_logger.addHandler(ch)
+    return dist_logger
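Note (added for review): a tiny hedged usage sketch; the qualifier string mirrors the one DistributedContext assembles, and the commented output line is only an approximation of the configured format.

import logging

from d9d.core.dist_context import build_dist_logger

log = build_dist_logger("pp:0-dpr:0-dps:0-cps:0-cpr:0-tp:0", level=logging.INFO)
log.info("loading checkpoint shard")
# Emits roughly: [d9d] [pp:0-dpr:0-dps:0-cps:0-cpr:0-tp:0] <timestamp> - INFO - loading checkpoint shard
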
d9d/core/dist_context/params.py
ADDED
@@ -0,0 +1,113 @@
+import logging
+from typing import Self
+
+from pydantic import BaseModel, ConfigDict, model_validator
+
+from .configured import DistributedContext
+
+
+class DeviceMeshParameters(BaseModel):
+    """
+    Configuration parameters for initializing Distributed Device Meshes.
+
+    Attributes:
+        pipeline_parallel: Degree of pipeline parallelism (PP).
+        data_parallel_replicate: Degree of data parallel replication (DDP).
+        data_parallel_shard: Degree of data parallel sharding (FSDP).
+        context_parallel_replicate: Degree of context parallel (CP) replication.
+        context_parallel_shard: Degree of context parallel (FSCP) sharding.
+        tensor_parallel: Degree of tensor parallelism (TP).
+        expert_parallel: Degree of expert parallelism (EP/MoE).
+    """
+
+    model_config = ConfigDict(frozen=True)
+
+    pipeline_parallel: int = 1
+
+    data_parallel_replicate: int = 1
+    data_parallel_shard: int = 1
+
+    context_parallel_replicate: int = 1
+    context_parallel_shard: int = 1
+
+    tensor_parallel: int = 1
+
+    expert_parallel: int = 1
+
+    @property
+    def has_pipeline_parallel(self) -> bool:
+        """Checks if pipeline parallelism is enabled (degree > 1)."""
+
+        return self.pipeline_parallel > 1
+
+    @property
+    def has_data_parallel_replicate(self) -> bool:
+        """Checks if data parallel replication is enabled (degree > 1)."""
+
+        return self.data_parallel_replicate > 1
+
+    @property
+    def has_data_parallel_shard(self) -> bool:
+        """Checks if data parallel sharding is enabled (degree > 1)."""
+
+        return self.data_parallel_shard > 1
+
+    @property
+    def has_context_parallel_replicate(self) -> bool:
+        return self.context_parallel_replicate > 1
+
+    @property
+    def has_context_parallel_shard(self) -> bool:
+        return self.context_parallel_shard > 1
+
+    @property
+    def has_tensor_parallel(self) -> bool:
+        return self.tensor_parallel > 1
+
+    @property
+    def has_expert_parallel(self) -> bool:
+        """Checks if expert parallelism is enabled (degree > 1)."""
+        return self.expert_parallel > 1
+
+    @property
+    def is_distributed(self) -> bool:
+        """Checks if any form of parallelism is enabled."""
+
+        return (
+            self.has_pipeline_parallel or
+            self.has_data_parallel_replicate or
+            self.has_data_parallel_shard or
+            self.has_context_parallel_shard or
+            self.has_context_parallel_replicate or
+            self.has_expert_parallel or
+            self.has_tensor_parallel
+        )
+
+    @model_validator(mode="after")
+    def _check_ep_divisibility(self) -> Self:
+        """Validates that DP/CP/TP dimensions can support the requested EP/ETP degrees."""
+        dp_cp_tp_degree = (
+            self.data_parallel_shard *
+            self.data_parallel_replicate *
+            self.context_parallel_shard *
+            self.context_parallel_replicate *
+            self.tensor_parallel
+        )
+        ep_degree = self.expert_parallel
+
+        if dp_cp_tp_degree % ep_degree != 0:
+            raise ValueError(
+                f"Total data/context/tensor parallelism degree ({dp_cp_tp_degree}) must be divisible by "
+                f"total expert parallelism degree ({ep_degree})."
+            )
+        return self
+
+    def build(self, log_level: int = logging.INFO) -> "DistributedContext":
+        """
+        Initializes the DistributedContext using these parameters.
+
+        Returns:
+            A new DistributedContext instance containing the initialized device meshes.
+        """
+
+        return DistributedContext(self, log_level)
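Note (added for review): a short sketch of the divisibility validator's effect; the degrees are illustrative, and the ValidationError wrapping is standard pydantic v2 behaviour for a ValueError raised in a model validator.

from pydantic import ValidationError

from d9d.core.dist_context import DeviceMeshParameters

# dp_shard * dp_replicate * cp_shard * cp_replicate * tp = 6, not divisible by ep = 4 -> rejected.
try:
    DeviceMeshParameters(data_parallel_shard=6, expert_parallel=4)
except ValidationError as err:
    print(err)

# 8 is divisible by 4, so this layout validates.
ok = DeviceMeshParameters(data_parallel_shard=8, expert_parallel=4)
assert ok.is_distributed and ok.has_expert_parallel
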
d9d/core/dist_ops/__init__.py
ADDED
@@ -0,0 +1,16 @@
+"""
+This module provides high-level wrappers around `torch.distributed` collective operations.
+"""
+
+
+from .object import all_gather_object, gather_object
+from .tensor import all_gather, all_gather_variadic_shape, gather, gather_variadic_shape
+
+__all__ = [
+    "all_gather",
+    "all_gather_object",
+    "all_gather_variadic_shape",
+    "gather",
+    "gather_object",
+    "gather_variadic_shape"
+]