d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/pipelining/infra/stage/stage.py
ADDED

@@ -0,0 +1,321 @@
from typing import Any

import torch
import torch.distributed as dist
from torch import nn

from d9d.pipelining.api import ModuleSupportsPipelining, PipelineStageInfo

from .communications import StageCommunicationHandler
from .computations import BackwardComputeHandler, ForwardComputeHandler


class PipelineStage:
    """
    Represents a single structural stage in a Pipelined Model.

    This class acts as an orchestrator that combines `StageCommunicationHandler` (for I/O)
    and `Forward/BackwardComputeHandler` (for execution). It abstracts away the complexity
    of buffer management, distributed communication, and gradient calculation from the scheduler.
    """

    def __init__(
        self,
        info: PipelineStageInfo,
        module: nn.Module,
        group: dist.ProcessGroup,
        stage_to_host_topology: dict[int, int]
    ):
        """
        Constructs a PipelineStage object.

        Args:
            info: Metadata about the stage (index, total stages).
            module: The PyTorch module executed by this stage.
            group: The distributed process group for pipeline communications.
            stage_to_host_topology: Dict mapping stage ID to the PP rank hosting it.
        """

        self._info = info
        self._module = module
        self._group = group
        self._stage_to_host_topology = stage_to_host_topology

        self._has_backward = False

        self._forward_comm: StageCommunicationHandler | None = None
        self._backward_comm: StageCommunicationHandler | None = None

        self._forward_comp = ForwardComputeHandler(
            stage_index=info.current_stage,
            module=module
        )
        self._backward_comp = BackwardComputeHandler(
            stage_index=info.current_stage,
            module=module
        )

    @property
    def info(self) -> PipelineStageInfo:
        return self._info

    def configure_buffers(
        self,
        num_microbatches: int,
        has_backward: bool,
        pipeline_inputs: dict[str, torch.Tensor]
    ):
        """
        Initializes the communication handlers and buffers for the stage.

        This must be called before execution to establish P2P buffer sizes and directions.

        Args:
            num_microbatches: Total number of microbatches to process.
            has_backward: Whether this pipeline stage should store state for a backward pass.
            pipeline_inputs: Pipeline input data.
        """

        self._has_backward = has_backward

        prev_stage_idx = None if self._info.is_current_stage_first else self._info.current_stage - 1
        next_stage_idx = None if self._info.is_current_stage_last else self._info.current_stage + 1

        with torch.device("meta"):
            if not isinstance(self._module, ModuleSupportsPipelining):
                raise TypeError("Module does not implement ModuleSupportsPipelining protocol")
            inputs_meta = self._module.infer_stage_inputs_from_pipeline_inputs(
                inputs=pipeline_inputs,
                n_microbatches=num_microbatches
            )
            outputs_meta = self._module.infer_stage_outputs_from_pipeline_inputs(
                inputs=pipeline_inputs,
                n_microbatches=num_microbatches
            )

        self._forward_comm = StageCommunicationHandler(
            name="fwd",
            stage_index=self._info.current_stage,
            num_microbatches=num_microbatches,
            input_stage_index=prev_stage_idx,
            input_args=inputs_meta,
            output_stage_index=next_stage_idx,
            output_args=outputs_meta,
            group=self._group,
            stage_idx_to_host_rank=self._stage_to_host_topology
        )
        self._forward_comm.set_input_requires_grad_(requires_grad=has_backward)

        if has_backward:
            # For gradients the current stage receives OUTPUTS as inputs and sends INPUTS as outputs,
            # because the backward pass is the forward pass reversed.
            self._backward_comm = StageCommunicationHandler(
                name="bwd",
                stage_index=self._info.current_stage,
                num_microbatches=num_microbatches,
                input_stage_index=next_stage_idx,
                input_args=outputs_meta,
                output_stage_index=prev_stage_idx,
                output_args=inputs_meta,
                group=self._group,
                stage_idx_to_host_rank=self._stage_to_host_topology
            )
        else:
            self._backward_comm = None

    def set_local_fwd_input(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
        """
        Sets local forward inputs manually.

        Used for the V-shape schedulers.
        """

        if self._forward_comm is None:
            raise ValueError("You must configure stage buffers first")

        self._forward_comm.set_inputs_local(inputs, microbatch_index)

    def get_local_fwd_output(self, microbatch_index: int) -> dict[str, torch.Tensor]:
        return self._forward_comp.get_outputs(microbatch_index)

    def pop_local_bwd_output(self, microbatch_index: int) -> dict[str, torch.Tensor]:
        """
        Retrieves local backward outputs (gradients).
        """

        if not self._has_backward:
            raise ValueError()

        return self._backward_comp.pop_for_sending(microbatch_index)

    def set_local_bwd_input(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
        """
        Sets local backward inputs (output gradients) manually.
        """

        if not self._has_backward:
            raise ValueError()

        if self._backward_comm is None:
            raise ValueError("You must configure stage buffers first")

        self._backward_comm.set_inputs_local(inputs, microbatch_index)

    def get_fwd_recv_ops(self, microbatch_index: int) -> list[dist.P2POp]:
        """Returns P2P ops to receive forward inputs for the given microbatch."""

        if self._forward_comm is None:
            raise ValueError("You must configure stage buffers first")

        return self._forward_comm.create_receive_ops(microbatch_index)

    def get_fwd_send_ops(self, microbatch_index: int) -> list[dist.P2POp]:
        """Returns P2P ops to send forward outputs for the given microbatch."""

        if self._forward_comm is None:
            raise ValueError("You must configure stage buffers first")

        fwd_result = self._forward_comp.get_outputs(microbatch_index)
        return self._forward_comm.create_send_ops(fwd_result)

    def get_bwd_recv_ops(self, microbatch_index: int) -> list[dist.P2POp]:
        """Returns P2P ops to receive backward gradients for the given microbatch."""

        if not self._has_backward:
            return []

        if self._backward_comm is None:
            raise ValueError("You must configure stage buffers first")

        return self._backward_comm.create_receive_ops(microbatch_index)

    def get_bwd_send_ops(self, microbatch_index: int) -> list[dist.P2POp]:
        """Returns P2P ops to send backward gradients for the given microbatch."""

        if not self._has_backward:
            return []

        if self._backward_comm is None:
            raise ValueError("You must configure stage buffers first")

        bwd_result = self._backward_comp.pop_for_sending(microbatch_index)
        return self._backward_comm.create_send_ops(bwd_result)

    def forward_one_chunk(
        self,
        microbatch_index: int,
        pipeline_inputs: dict[str, torch.Tensor],
        pipeline_kwargs: dict[str, Any] | None = None,
    ):
        """
        Executes a forward pass for a single microbatch chunk.

        Fetches inputs from the communication buffer (or `pipeline_inputs` if first stage),
        runs the computation, and caches the result.

        Args:
            microbatch_index: The microbatch index.
            pipeline_inputs: Inputs provided locally (only used if this is the first stage).
            pipeline_kwargs: Additional arguments for the module.
        """

        if self._forward_comm is None:
            raise ValueError("You must configure stage buffers first")

        if self._info.is_current_stage_first:
            inputs = pipeline_inputs
        else:
            inputs = self._forward_comm.get_inputs(microbatch_index)

        kwargs = pipeline_kwargs or {}

        self._forward_comp.run(
            microbatch_index=microbatch_index,
            inputs=inputs,
            kwargs=kwargs
        )

    def backward_one_chunk(
        self,
        microbatch_index: int,
        loss: torch.Tensor | None = None,
        full_backward: bool = True
    ):
        """
        Executes a backward pass for a single microbatch chunk.

        Can perform either a full backward or just the input gradients (if `full_backward=False`).
        It fetches the required data from the forward cache and communication buffers.

        Args:
            microbatch_index: The microbatch index.
            loss: The loss tensor (only used if this is the last stage).
            full_backward: If True, computes grads for inputs and weights. If False, only for inputs.
        """

        if not self._has_backward:
            raise ValueError()

        if self._backward_comm is None:
            raise ValueError("You must configure stage buffers first")

        inputs, fwd_outputs = self._forward_comp.pop_inputs_outputs(microbatch_index)

        outputs: dict[str, torch.Tensor]
        outputs_grad: dict[str, torch.Tensor] | None

        if self._info.is_current_stage_last:
            if loss is None:
                raise ValueError("Cannot perform backward on last stage without loss specified")
            outputs = {"loss": loss}
            outputs_grad = None
        else:
            outputs = fwd_outputs
            outputs_grad = self._backward_comm.get_inputs(microbatch_index)

        if full_backward:
            self._backward_comp.backward_full(
                microbatch_index=microbatch_index,
                inputs=inputs,
                outputs=outputs,
                outputs_grad=outputs_grad
            )
        else:
            self._backward_comp.backward_input(
                microbatch_index=microbatch_index,
                inputs=inputs,
                outputs=outputs,
                outputs_grad=outputs_grad
            )

        if self._info.is_current_stage_last and not self._info.is_current_stage_first:
            for t in fwd_outputs.values():
                if not t._is_view():  # noqa: SLF001
                    t.detach_()

    def backward_weight_one_chunk(self, microbatch_index: int):
        """
        Executes the weight gradient accumulation part of the backward pass.

        This assumes `backward_one_chunk(..., full_backward=False)` was already called
        for this microbatch.

        Args:
            microbatch_index: The microbatch index.
        """

        if not self._has_backward:
            raise ValueError()

        self._backward_comp.backward_weight(microbatch_index=microbatch_index)

    def reset(self):
        """Resets the internal state of communication handlers, clearing gradients on buffers."""

        if self._forward_comm is not None:
            self._forward_comm.reset()
        if self._backward_comm is not None:
            self._backward_comm.reset()

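To make the orchestration above concrete, here is a hedged sketch of the per-microbatch call sequence a schedule runtime could drive against a stage that already had `configure_buffers(...)` called. The `run_p2p` helper, the `loss_fn` argument, and the import path are illustrative assumptions based on the file layout; the package's real executors live under d9d/pipelining/infra/schedule/.

import torch
import torch.distributed as dist

from d9d.pipelining.infra.stage.stage import PipelineStage  # import path assumed from the file layout


def run_p2p(ops: list[dist.P2POp]) -> None:
    # Hypothetical helper: issue the batched point-to-point ops and wait for completion.
    if ops:
        for work in dist.batch_isend_irecv(ops):
            work.wait()


def run_microbatch(
    stage: PipelineStage,
    mb: int,
    pipeline_inputs: dict[str, torch.Tensor],
    loss_fn,
) -> None:
    # One forward plus full backward for a single microbatch (1F1B-style, illustrative only).
    run_p2p(stage.get_fwd_recv_ops(mb))           # receive activations from the previous stage
    stage.forward_one_chunk(mb, pipeline_inputs)  # compute and cache this stage's outputs
    run_p2p(stage.get_fwd_send_ops(mb))           # ship activations to the next stage

    loss = None
    if stage.info.is_current_stage_last:
        loss = loss_fn(stage.get_local_fwd_output(mb))

    run_p2p(stage.get_bwd_recv_ops(mb))           # receive output gradients (empty list if not needed)
    stage.backward_one_chunk(mb, loss=loss)       # backward for both inputs and weights
    run_p2p(stage.get_bwd_send_ops(mb))           # ship input gradients to the previous stage
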
d9d/pipelining/infra/stage/struct_helper.py
ADDED

@@ -0,0 +1,46 @@
from collections.abc import Iterable, Sequence
from typing import TypeVar

T = TypeVar("T")


class DictFlattener:
    """
    Helper class to flatten and unflatten dictionaries into sequences deterministically.
    """

    def __init__(self, keys: Iterable[str]):
        """
        Constructs a DictFlattener object.

        Args:
            keys: The collection of dictionary keys to manage. They will be sorted internally.
        """

        self._order_to_key = {i: x for i, x in enumerate(sorted(keys))}

    def flatten(self, inputs: dict[str, T]) -> list[T]:
        """
        Converts a dictionary into a list based on the sorted internal key order.

        Args:
            inputs: The dictionary to flatten. Must contain all keys provided at init.

        Returns:
            A list of values sorted by their corresponding keys.
        """

        return [inputs[self._order_to_key[i]] for i in range(len(inputs))]

    def unflatten(self, outputs: Sequence[T]) -> dict[str, T]:
        """
        Reconstructs a dictionary from a sequence of values.

        Args:
            outputs: A sequence of values corresponding to the sorted internal key order.

        Returns:
            A dictionary mapping original keys to the provided values.
        """

        return {self._order_to_key[i]: out for i, out in enumerate(outputs)}

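A short usage sketch of the round-trip described above; the key names and values are arbitrary examples.

flattener = DictFlattener(keys=["logits", "hidden"])

flat = flattener.flatten({"hidden": 1, "logits": 2})
# keys are ordered alphabetically ("hidden", "logits"), so flat == [1, 2]

restored = flattener.unflatten(flat)
# restored == {"hidden": 1, "logits": 2}
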
d9d/pipelining/training/optimizer.py
ADDED

@@ -0,0 +1,41 @@
from typing import Any

from torch.distributed import DeviceMesh

from d9d.core.protocol import OptimizerProtocol


class PipelinedOptimizer(OptimizerProtocol):
    """
    Wrapper that manages multiple optimizers for a pipeline parallel rank.

    In a pipeline parallel setup, a single rank might host multiple stages, each having its own parameters
    and optimizer.
    This class aggregates them into a single interface.
    """

    def __init__(self, mesh_pp: DeviceMesh, optimizers: list[OptimizerProtocol]):
        super().__init__()

        self._mesh_pp = mesh_pp
        self._optimizers = optimizers

    def state_dict(self) -> dict[str, Any]:
        pp_rank = self._mesh_pp.get_local_rank()
        return {
            f"pp_{pp_rank}_stage_{i}": optimizer.state_dict()
            for i, optimizer in enumerate(self._optimizers)
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pp_rank = self._mesh_pp.get_local_rank()
        for i, optimizer in enumerate(self._optimizers):
            optimizer.load_state_dict(state_dict[f"pp_{pp_rank}_stage_{i}"])

    def step(self) -> None:
        for optimizer in self._optimizers:
            optimizer.step()

    def zero_grad(self) -> None:
        for optimizer in self._optimizers:
            optimizer.zero_grad()

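A hedged usage sketch of the wrapper; `mesh_pp` and the per-stage optimizers are assumed to come from the surrounding factory code, and the key layout in the comment follows the `state_dict` implementation above.

# mesh_pp: DeviceMesh for the pipeline-parallel dimension; opt_stage_a / opt_stage_b: one
# optimizer per locally hosted stage (all assumed to be built elsewhere).
pipelined_opt = PipelinedOptimizer(mesh_pp=mesh_pp, optimizers=[opt_stage_a, opt_stage_b])

pipelined_opt.step()       # steps every wrapped optimizer
pipelined_opt.zero_grad()  # clears gradients on every wrapped optimizer

# On PP rank 3, state_dict() namespaces each stage separately:
#   {"pp_3_stage_0": {...}, "pp_3_stage_1": {...}}
state = pipelined_opt.state_dict()
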
d9d/pipelining/training/scheduler.py
ADDED

@@ -0,0 +1,34 @@
from typing import Any

from torch.distributed import DeviceMesh

from d9d.core.protocol import LRSchedulerProtocol


class PipelinedLRScheduler(LRSchedulerProtocol):
    """
    Wrapper that manages multiple LR schedulers for a pipeline parallel rank.

    Similar to `PipelinedOptimizer`, this aggregates schedulers corresponding to
    multiple model stages hosted on the current rank.
    """

    def __init__(self, mesh_pp: DeviceMesh, schedulers: list[LRSchedulerProtocol]):
        self._mesh_pp = mesh_pp
        self._schedulers = schedulers

    def state_dict(self) -> dict[str, Any]:
        pp_rank = self._mesh_pp.get_local_rank()
        return {
            f"pp_{pp_rank}_stage_{i}": scheduler.state_dict()
            for i, scheduler in enumerate(self._schedulers)
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pp_rank = self._mesh_pp.get_local_rank()
        for i, scheduler in enumerate(self._schedulers):
            scheduler.load_state_dict(state_dict[f"pp_{pp_rank}_stage_{i}"])

    def step(self) -> None:
        for scheduler in self._schedulers:
            scheduler.step()

d9d/tracker/__init__.py
ADDED

@@ -0,0 +1,14 @@
"""
Package providing a unified interface for experiment tracking and logging.
"""

from .base import BaseTracker, BaseTrackerRun, RunConfig
from .factory import AnyTrackerConfig, tracker_from_config

__all__ = [
    "AnyTrackerConfig",
    "BaseTracker",
    "BaseTrackerRun",
    "RunConfig",
    "tracker_from_config"
]

d9d/tracker/base.py
ADDED

@@ -0,0 +1,124 @@
import abc
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Generic, Self, TypeVar

import torch
from pydantic import BaseModel, Field
from torch.distributed.checkpoint.stateful import Stateful


class BaseTrackerRun(abc.ABC):
    """
    Abstract base class representing an active tracking session (run).

    This object is responsible for the actual logging of metrics and parameters
    during a training or inference run.
    """

    @abc.abstractmethod
    def set_step(self, step: int):
        """
        Updates the global step counter for subsequent logs.

        Args:
            step: The current step index (e.g., iteration number).
        """
        ...

    @abc.abstractmethod
    def set_context(self, context: dict[str, str]):
        """
        Sets a persistent context dictionary for subsequent logs.

        These context values (tags) will be attached to every metric logged
        until changed.

        Args:
            context: A dictionary of tag names and values.
        """
        ...

    @abc.abstractmethod
    def scalar(self, name: str, value: float, context: dict[str, str] | None = None):
        """
        Logs a scalar value.

        Args:
            name: The name of the metric.
            value: The scalar value to log.
            context: Optional ephemeral context specific to this metric event.
                Merged with the global context if present.
        """
        ...

    @abc.abstractmethod
    def bins(self, name: str, values: torch.Tensor, context: dict[str, str] | None = None):
        """
        Logs a distribution/histogram of values.

        Args:
            name: The name of the metric.
            values: A tensor containing the population of values to bin.
            context: Optional ephemeral context specific to this metric event.
                Merged with the global context if present.
        """
        ...


class RunConfig(BaseModel):
    """
    Configuration for initializing a specific logged run.

    Attributes:
        name: The display name of the experiment.
        description: An optional description of the experiment.
        hparams: A dictionary of hyperparameters to log at the start of the run.
    """

    name: str
    description: str | None
    hparams: dict[str, Any] = Field(default_factory=dict)


TConfig = TypeVar("TConfig", bound=BaseModel)


class BaseTracker(abc.ABC, Stateful, Generic[TConfig]):
    """
    Abstract base class for a tracker backend factory.

    This class manages the lifecycle of runs and integration with the
    distributed checkpointing system to ensure experiment continuity
    (e.g., resuming the same run hash after a restart).
    """

    @contextmanager
    @abc.abstractmethod
    def open(self, properties: RunConfig) -> Generator[BaseTrackerRun, None, None]:
        """
        Context manager that initiates and manages an experiment run.

        Args:
            properties: Configuration metadata for the run.

        Yields:
            An active BaseTrackerRun instance for logging metrics.
        """

        ...

    @classmethod
    @abc.abstractmethod
    def from_config(cls, config: TConfig) -> Self:
        """
        Factory method to create a tracker instance from a configuration object.

        Args:
            config: The backend-specific configuration object.

        Returns:
            An initialized instance of the tracker.
        """

        ...

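A minimal sketch of what a concrete backend has to provide. The in-memory classes below are hypothetical illustrations, not part of d9d; only the abstract interface and the re-exported import path shown above are assumed.

from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Self

import torch
from pydantic import BaseModel

from d9d.tracker import BaseTracker, BaseTrackerRun, RunConfig


class InMemoryTrackerConfig(BaseModel):
    # Hypothetical config for this sketch only.
    provider: str = "in_memory"


class InMemoryRun(BaseTrackerRun):
    # Hypothetical run that keeps every logged event in a list.
    def __init__(self) -> None:
        self._step = 0
        self._context: dict[str, str] = {}
        self.events: list[dict[str, Any]] = []

    def set_step(self, step: int):
        self._step = step

    def set_context(self, context: dict[str, str]):
        self._context = dict(context)

    def scalar(self, name: str, value: float, context: dict[str, str] | None = None):
        self.events.append({
            "step": self._step,
            "name": name,
            "value": value,
            "context": {**self._context, **(context or {})},
        })

    def bins(self, name: str, values: torch.Tensor, context: dict[str, str] | None = None):
        # Store a coarse histogram instead of the raw tensor.
        self.events.append({
            "step": self._step,
            "name": name,
            "hist": torch.histc(values.detach().float().cpu(), bins=16).tolist(),
            "context": {**self._context, **(context or {})},
        })


class InMemoryTracker(BaseTracker[InMemoryTrackerConfig]):
    @contextmanager
    def open(self, properties: RunConfig) -> Generator[BaseTrackerRun, None, None]:
        run = InMemoryRun()
        run.set_context({"run_name": properties.name})
        try:
            yield run
        finally:
            pass  # a real backend would flush and close its client here

    @classmethod
    def from_config(cls, config: InMemoryTrackerConfig) -> Self:
        return cls()

    # Stateful protocol: nothing needs to persist for this toy backend.
    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pass
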
d9d/tracker/factory.py
ADDED

@@ -0,0 +1,57 @@
import dataclasses
from typing import Annotated

from pydantic import Field

from .base import BaseTracker
from .provider.aim.config import AimConfig
from .provider.null import NullTracker, NullTrackerConfig

AnyTrackerConfig = Annotated[AimConfig | NullTrackerConfig, Field(discriminator="provider")]


@dataclasses.dataclass
class _TrackerImportFailed:
    dependency: str
    exception: ImportError


_MAP: dict[type[AnyTrackerConfig], type[BaseTracker] | _TrackerImportFailed] = {
    NullTrackerConfig: NullTracker
}

try:
    from .provider.aim.tracker import AimTracker

    _MAP[AimConfig] = AimTracker
except ImportError as e:
    _MAP[AimConfig] = _TrackerImportFailed(dependency="aim", exception=e)


def tracker_from_config(config: AnyTrackerConfig) -> BaseTracker:
    """
    Instantiates a specific tracker implementation based on the configuration.

    Based on the 'provider' field in the config, this function selects the
    appropriate backend (e.g., Aim, Null). It handles checking for missing
    dependencies for optional backends.

    Args:
        config: A specific tracker configuration object.

    Returns:
        An initialized BaseTracker instance.

    Raises:
        ImportError: If the dependencies for the requested provider are not installed.
    """

    tracker_type = _MAP[type(config)]

    if isinstance(tracker_type, _TrackerImportFailed):
        raise ImportError(
            f"The tracker configuration {config.provider} could not be loaded - "
            f"ensure these dependencies are installed: {tracker_type.dependency}"
        ) from tracker_type.exception

    return tracker_type.from_config(config)

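A hedged usage sketch of the factory; it assumes `NullTrackerConfig` can be constructed without additional fields and relies on the re-exports from d9d/tracker/__init__.py shown earlier.

from d9d.tracker import RunConfig, tracker_from_config
from d9d.tracker.provider.null import NullTrackerConfig

tracker = tracker_from_config(NullTrackerConfig())  # selects the no-op backend

with tracker.open(RunConfig(name="smoke-test", description=None)) as run:
    run.set_step(0)
    run.scalar("loss", 1.234)
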
d9d/tracker/provider/__init__.py
File without changes

d9d/tracker/provider/aim/__init__.py
File without changes

d9d/tracker/provider/aim/config.py
ADDED

@@ -0,0 +1,23 @@
from typing import Literal

from pydantic import BaseModel


class AimConfig(BaseModel):
    """
    Configuration for the Aim tracker backend.

    Attributes:
        provider: Discriminator field, must be 'aim'.
        repo: Path to the Aim repository directory or URL.
        log_system_params: Whether to log system resource usage (CPU/GPU/Memory).
        capture_terminal_logs: Whether to capture stdout/stderr.
        system_tracking_interval: Interval in seconds for system monitoring.
    """

    provider: Literal["aim"] = "aim"

    repo: str
    log_system_params: bool = True
    capture_terminal_logs: bool = True
    system_tracking_interval: int = 10
