d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/auto/auto_optimizer.py
@@ -0,0 +1,196 @@
+import abc
+from abc import ABC
+from collections.abc import Iterable
+from typing import Annotated, Literal
+
+import torch
+from pydantic import BaseModel, Field
+from torch import nn
+from torch.optim import SGD, Adam, AdamW, Optimizer
+
+from d9d.loop.control import InitializeOptimizerStageContext, OptimizerProvider
+from d9d.optim.stochastic import StochasticAdamW
+
+
+class BaseAutoOptimizerConfig(BaseModel, ABC):
+    """
+    Abstract base class for optimizer configurations.
+    """
+
+    @abc.abstractmethod
+    def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+        """
+        Creates the PyTorch optimizer instance.
+
+        Args:
+            params: An iterable of model parameters to optimize.
+
+        Returns:
+            The instantiated optimizer.
+        """
+        ...
+
+
+class StochasticAdamWOptimizerConfig(BaseAutoOptimizerConfig):
+    """
+    Configuration for the Stochastic AdamW optimizer.
+
+    Attributes:
+        name: Discriminator tag.
+        lr: Learning rate.
+        betas: Coefficients used for computing running averages of gradient and its square.
+        eps: Term added to the denominator to improve numerical stability.
+        weight_decay: Weight decay coefficient.
+        state_dtype: Data Type to use for the optimizer states.
+    """
+    name: Literal["stochastic_adamw"] = "stochastic_adamw"
+
+    lr: float
+    betas: tuple[float, float] = (0.9, 0.999)
+    eps: float = 1e-8
+    weight_decay: float = 1e-2
+    state_dtype: str
+
+    def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+        """Builds StochasticAdamW with the configured parameters."""
+        return StochasticAdamW(
+            params=params,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            state_dtype=getattr(torch, self.state_dtype)
+        )
+
+
+class AdamWOptimizerConfig(BaseAutoOptimizerConfig):
+    """
+    Configuration for the PyTorch AdamW optimizer.
+
+    Attributes:
+        name: Discriminator tag.
+        lr: The learning rate.
+        betas: Coefficients for computing running averages of gradient and its square.
+        eps: Term added to the denominator to improve numerical stability.
+        weight_decay: Weight decay coefficient.
+        amsgrad: Whether to use the AMSGrad variant.
+        maximize: Whether to maximize the params based on the objective (as opposed to minimizing).
+    """
+    name: Literal["adamw"] = "adamw"
+
+    lr: float
+    betas: tuple[float, float] = (0.9, 0.999)
+    eps: float = 1e-8
+    weight_decay: float = 1e-2
+    amsgrad: bool = False
+    maximize: bool = False
+
+    def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+        """Builds fused AdamW with the configured parameters."""
+        return AdamW(
+            params=params,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            amsgrad=self.amsgrad,
+            maximize=self.maximize,
+            fused=True
+        )
+
+
+class AdamOptimizerConfig(BaseAutoOptimizerConfig):
+    """
+    Configuration for the PyTorch Adam optimizer.
+
+    Attributes:
+        name: Discriminator tag.
+        lr: The learning rate.
+        betas: Coefficients for computing running averages of gradient and its square.
+        eps: Term added to the denominator to improve numerical stability.
+        weight_decay: Weight decay coefficient.
+        decoupled_weight_decay: Whether to apply decoupled weight decay.
+        amsgrad: Whether to use the AMSGrad variant.
+        maximize: Whether to maximize the params based on the objective.
+    """
+    name: Literal["adam"] = "adam"
+
+    lr: float
+    betas: tuple[float, float] = (0.9, 0.999)
+    eps: float = 1e-8
+    weight_decay: float = 1e-2
+    decoupled_weight_decay: bool = False
+    amsgrad: bool = False
+    maximize: bool = False
+
+    def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+        """Builds fused Adam with the configured parameters."""
+        return Adam(
+            params=params,
+            lr=self.lr,
+            betas=self.betas,
+            eps=self.eps,
+            weight_decay=self.weight_decay,
+            decoupled_weight_decay=self.decoupled_weight_decay,
+            amsgrad=self.amsgrad,
+            maximize=self.maximize,
+            fused=True
+        )
+
+
+class SGDOptimizerConfig(BaseAutoOptimizerConfig):
+    """
+    Configuration for the PyTorch SGD optimizer.
+
+    Attributes:
+        name: Discriminator tag.
+        lr: The learning rate.
+        momentum: Momentum factor.
+        dampening: Dampening for momentum.
+        weight_decay: Weight decay (L2 penalty).
+        nesterov: Enables Nesterov momentum.
+        maximize: Whether to maximize the params based on the objective.
+    """
+    name: Literal["sgd"] = "sgd"
+
+    lr: float
+    momentum: float = 0
+    dampening: float = 0
+    weight_decay: float = 0
+    nesterov: bool = False
+    maximize: bool = False
+
+    def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+        """Builds fused SGD with the configured parameters."""
+        return SGD(
+            params,
+            lr=self.lr,
+            momentum=self.momentum,
+            dampening=self.dampening,
+            weight_decay=self.weight_decay,
+            nesterov=self.nesterov,
+            maximize=self.maximize,
+            fused=True
+        )
+
+
+AutoOptimizerConfig = Annotated[
+    StochasticAdamWOptimizerConfig |
+    AdamWOptimizerConfig |
+    AdamOptimizerConfig |
+    SGDOptimizerConfig,
+    Field(discriminator="name")
+]
+
+
+class AutoOptimizerProvider(OptimizerProvider):
+    """
+    OptimizerProvider that builds a PyTorch optimizer based on a configuration object.
+    """
+
+    def __init__(self, config: AutoOptimizerConfig):
+        """Constructs the provider with the given configuration."""
+        self._config = config
+
+    def __call__(self, context: InitializeOptimizerStageContext) -> Optimizer:
+        return self._config.build(context.model.parameters())
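Since the hunks above are plain source listings, a short usage sketch may help. The following is illustrative only and not part of the package: it parses a plain dict into the `AutoOptimizerConfig` discriminated union via pydantic's `TypeAdapter` (assuming pydantic v2), then builds an optimizer from it. The import path `d9d.loop.auto` is an assumption based on the file layout in the manifest.

```python
# Illustrative sketch (not part of the package).
import torch
from torch import nn
from pydantic import TypeAdapter

from d9d.loop.auto import AutoOptimizerConfig  # assumed re-export path

# "name" selects the concrete config class through Field(discriminator="name").
raw = {"name": "adamw", "lr": 3e-4, "weight_decay": 0.1}
config = TypeAdapter(AutoOptimizerConfig).validate_python(raw)
print(type(config).__name__)  # AdamWOptimizerConfig

# config.build(...) returns torch.optim.AdamW(..., fused=True); fused optimizers
# generally expect CUDA parameters, so only build against a GPU model here.
if torch.cuda.is_available():
    model = nn.Linear(16, 4).cuda()
    optimizer = config.build(model.parameters())
```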
d9d/loop/component/__init__.py
@@ -0,0 +1,35 @@
+from .batch_maths import BatchMaths
+from .checkpointer import StateCheckpointer
+from .data_loader_factory import DataLoaderFactory
+from .garbage_collector import ManualGarbageCollector
+from .gradient_clipper import GradientClipper
+from .gradient_manager import GradientManager
+from .job_logger import JobLogger
+from .job_profiler import JobProfiler
+from .loss_computer import LossComputer
+from .model_stage_exporter import ModelStageExporter
+from .model_stage_factory import ModelStageFactory, TrackedModules
+from .optimizer_factory import OptimizerFactory
+from .stepper import Stepper
+from .timeout_manager import TimeoutManager
+from .train_task_operator import ForwardResult, TrainTaskOperator
+
+__all__ = [
+    "BatchMaths",
+    "DataLoaderFactory",
+    "ForwardResult",
+    "GradientClipper",
+    "GradientManager",
+    "JobLogger",
+    "JobProfiler",
+    "LossComputer",
+    "ManualGarbageCollector",
+    "ModelStageExporter",
+    "ModelStageFactory",
+    "OptimizerFactory",
+    "StateCheckpointer",
+    "Stepper",
+    "TimeoutManager",
+    "TrackedModules",
+    "TrainTaskOperator"
+]
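This `__init__.py` only flattens the component submodules into a single namespace. Illustrative only (requires the d9d wheel installed):

```python
# Thanks to the re-exports above, loop code can import components from the
# package root instead of their defining submodules.
from d9d.loop.component import BatchMaths, StateCheckpointer, Stepper

for cls in (BatchMaths, StateCheckpointer, Stepper):
    print(cls.__module__)  # resolves to e.g. d9d.loop.component.batch_maths
```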
d9d/loop/component/batch_maths.py
@@ -0,0 +1,106 @@
+from d9d.core.dist_context import BATCH_DOMAIN, DistributedContext
+from d9d.loop.config import BatchingConfig, PipeliningConfig
+
+
+class BatchMaths:
+    """
+    Calculates derived batching dimensions and iteration counts for distributed training loops.
+
+    This class bridges the gap between global configuration (Global Batch Size) and
+    local execution constraints (Microbatch Size, Data Parallel World Size).
+    """
+
+    def __init__(
+        self,
+        dist_context: DistributedContext,
+        config_batching: BatchingConfig,
+        config_pipelining: PipeliningConfig | None
+    ):
+        """
+        Constructs the batch mathematics calculator.
+
+        Validates that the Global Batch Size is perfectly divisible by the
+        effective parallel microbatch capacity (DP size * Microbatch size).
+
+        Args:
+            dist_context: The distributed context containing mesh layout information.
+            config_batching: Configuration detailing batch sizes.
+            config_pipelining: Optional configuration for pipeline parallelism capabilities.
+
+        Raises:
+            ValueError: If global batch size is not divisible by the product of
+                Data Parallel size and Microbatch size.
+        """
+
+        self._dist_context = dist_context
+        self._config_batching = config_batching
+        self._config_pipelining = config_pipelining
+
+        global_batch = self._config_batching.global_batch_size
+        dp_size = self._dist_context.mesh_for(BATCH_DOMAIN)["dp"].size()
+        microbatch_size = self._config_batching.microbatch_size
+
+        global_microbatch = dp_size * microbatch_size
+
+        if global_batch % global_microbatch != 0:
+            raise ValueError("Global Batch Size must be divisible by (Data Parallel cardinality * Microbatch Size)")
+
+        self._global_microbatch_size = global_microbatch
+
+    @property
+    def global_batch_size(self) -> int:
+        """
+        Returns the global batch size across the world.
+        """
+
+        return self._config_batching.global_batch_size
+
+    @property
+    def num_microbatches_pipelining(self) -> int:
+        """
+        Returns the number of microbatches handled by the pipeline scheduler per step.
+
+        If pipeline parallelism is enabled, this is the total number of microbatches
+        processed to form one global batch. If disabled, this returns 1.
+        """
+
+        if not self._dist_context.mesh_params.has_pipeline_parallel:
+            return 1
+
+        return self._config_batching.global_batch_size // self._global_microbatch_size
+
+    @property
+    def num_microbatches_gradient_accumulation(self) -> int:
+        """
+        Returns the number of gradient accumulation iterations for non-pipelined training.
+
+        If pipeline parallelism is enabled, this returns 1 (as accumulation is handled
+        internally by the pipeline schedule). If disabled, this is the number of
+        forward/backward passes the training loop must execute before an optimizer step.
+        """
+
+        if self._dist_context.mesh_params.has_pipeline_parallel:
+            return 1
+
+        return self._config_batching.global_batch_size // self._global_microbatch_size
+
+    @property
+    def data_loader_batch_size(self) -> int:
+        """
+        Returns the quantity of samples this local rank needs to fetch for one optimizer step.
+
+        This is calculated as `microbatch_size * total_microbatches_per_step`.
+        """
+
+        return self._config_batching.microbatch_size * self.num_microbatches_pipelining
+
+    @property
+    def num_backward_calls(self) -> int:
+        """
+        Returns the total number of backward passes executed per optimizer step.
+
+        This represents the total gradient accumulation factor, regardless of whether
+        it is handled by a pipeline schedule or a simple loop.
+        """
+
+        return self.num_microbatches_pipelining * self.num_microbatches_gradient_accumulation
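A worked numeric example of the arithmetic `BatchMaths` encodes, standalone and with hypothetical sizes, assuming pipeline parallelism is enabled so accumulation is driven by the schedule:

```python
# Hypothetical sizes; mirrors the divisibility check and derived quantities above.
global_batch_size = 512
dp_size = 8            # size of the "dp" dimension of the batch-domain mesh
microbatch_size = 4

global_microbatch = dp_size * microbatch_size        # 32 samples per microbatch across DP ranks
assert global_batch_size % global_microbatch == 0    # otherwise BatchMaths raises ValueError

num_microbatches_pipelining = global_batch_size // global_microbatch    # 16
num_microbatches_gradient_accumulation = 1                              # handled inside the pipeline schedule
data_loader_batch_size = microbatch_size * num_microbatches_pipelining  # 64 samples fetched per rank per step
num_backward_calls = num_microbatches_pipelining * num_microbatches_gradient_accumulation  # 16
```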
d9d/loop/component/checkpointer.py
@@ -0,0 +1,172 @@
+import re
+import shutil
+from pathlib import Path
+
+import torch
+import torch.distributed.checkpoint as dcp
+from torch.distributed.checkpoint.stateful import Stateful
+
+from d9d.core.dist_context import DistributedContext
+from d9d.loop.config import CheckpointingConfig
+
+from .garbage_collector import ManualGarbageCollector
+from .stepper import Stepper
+
+# TODO feat(max): async checkpointing may break everything up, but I guess we still have to support it
+
+_SAVE_RE = re.compile(r"^save-(\d+)$")
+
+
+def _save_iter_predicate(x: Path) -> int:
+    match = _SAVE_RE.fullmatch(x.stem)
+    if match is None:
+        raise ValueError("Malformed checkpoint name")
+    return int(match.group(1))
+
+
+class StateCheckpointer:
+    """
+    Manages the lifecycle of distributed training checkpoints.
+
+    This class handles saving and loading the training state (JobState object)
+    using PyTorch Distributed Checkpoint (DCP). It manages checkpoint versioning,
+    storage rotation (keeping only N latest), and synchronization across distributed ranks.
+    """
+
+    def __init__(
+        self,
+        dist_context: DistributedContext,
+        stepper: Stepper,
+        config: CheckpointingConfig,
+        gc: ManualGarbageCollector,
+        run_name: str | None
+    ):
+        """
+        Constructs the StateCheckpoint object.
+
+        Args:
+            dist_context: The distributed context.
+            stepper: The training stepper tracking the current iteration/step.
+            config: Configuration object containing checkpointing parameters.
+            gc: Garbage collector for manual memory management during IO.
+            run_name: Optional specific run name to append to the save directory.
+        """
+        self._dist_context = dist_context
+        self._stepper = stepper
+        self._gc = gc
+
+        if run_name:
+            self._save_dir = config.save_dir / run_name
+        else:
+            self._save_dir = config.save_dir
+
+        self._config = config
+
+    def _free_memory(self):
+        self._gc.collect_forced()
+        torch.cuda.empty_cache()
+
+    def _get_sorted_checkpoint_dirs(self) -> list[Path]:
+        if not self._save_dir:
+            return []
+
+        if not self._save_dir.is_dir():
+            return []
+
+        checkpoint_dirs = [x for x in self._save_dir.iterdir() if x.is_dir() and _SAVE_RE.fullmatch(x.stem)]
+        checkpoint_dirs = sorted(checkpoint_dirs, key=_save_iter_predicate)
+        return checkpoint_dirs
+
+    def _next_checkpoint_id(self) -> Path:
+        next_name = f"save-{self._stepper.current_step}"
+        return self._save_dir / next_name
+
+    def _purge_old_checkpoints(self):
+        if not self._dist_context.is_main_process:
+            return
+        if not self._config.num_to_keep:
+            return
+
+        to_delete = self._get_sorted_checkpoint_dirs()[:-self._config.num_to_keep]
+
+        for delete_dir in to_delete:
+            self._dist_context.logger.info(f"Purging checkpoint {delete_dir}")
+            shutil.rmtree(delete_dir)
+
+    def _checkpoint(self, state: Stateful):
+        next_checkpoint_id = self._next_checkpoint_id()
+
+        self._dist_context.logger.info("Freeing up memory before checkpointing")
+        self._free_memory()
+        self._dist_context.logger.info("Waiting for world before saving checkpoint")
+        self._dist_context.wait_world()
+        self._dist_context.logger.info(f"Saving checkpoint {next_checkpoint_id}")
+
+        save_from = {"state": state}
+        dcp.save(
+            state_dict=save_from,
+            checkpoint_id=next_checkpoint_id
+        )
+
+        self._purge_old_checkpoints()
+        self._free_memory()
+
+        self._dist_context.logger.info("Waiting for world after saving checkpoint")
+        self._dist_context.wait_world()
+        self._dist_context.logger.info("Checkpoint successfully saved across the world")
+
+    def checkpoint_if_needed(self, state: Stateful):
+        """
+        Checks if a checkpoint is due based on the configuration and saves if necessary.
+
+        This checks the stepper to see if the current step matches the configured
+        saving period (or if it is the final step).
+
+        Args:
+            state: The Stateful object to save.
+        """
+
+        if self._stepper.should_do_action(self._config.period_steps, enable_on_last_step_if_periodic=True):
+            self._checkpoint(state)
+
+    def _last_checkpoint_id(self) -> Path | None:
+        checkpoints = self._get_sorted_checkpoint_dirs()
+        if len(checkpoints) == 0:
+            return None
+        return checkpoints[-1]
+
+    def _load(self, state: Stateful):
+        last_checkpoint = self._last_checkpoint_id()
+
+        if last_checkpoint is None:
+            self._dist_context.logger.info("Starting job from scratch")
+            return
+
+        self._dist_context.logger.info("Waiting for world before loading checkpoint")
+        self._dist_context.wait_world()
+        self._dist_context.logger.info(f"Loading checkpoint {last_checkpoint}")
+
+        load_into = {
+            "state": state
+        }
+        dcp.load(
+            state_dict=load_into,
+            checkpoint_id=last_checkpoint
+        )
+        self._free_memory()
+
+        self._dist_context.logger.info("Waiting for world after loading checkpoint")
+        self._dist_context.wait_world()
+        self._dist_context.logger.info("Checkpoint successfully loaded across the world")
+
+    def load_last_checkpoint(self, state: Stateful):
+        """
+        Attempts to load the most recent checkpoint available in the save directory.
+
+        If no checkpoint is found, the state remains unchanged (starting from scratch).
+
+        Args:
+            state: The stateful object to which loaded parameters will be applied.
+        """
+
+        self._load(state)
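The checkpoint layout is a flat directory of `save-<step>` folders, with the oldest purged on the main process once more than `num_to_keep` exist. A standalone sketch of that rotation policy (hypothetical paths, not the class above):

```python
import re
from pathlib import Path

_SAVE_RE = re.compile(r"^save-(\d+)$")

def checkpoints_to_purge(save_dir: Path, num_to_keep: int) -> list[Path]:
    """Return the checkpoint directories rotation would delete, oldest first."""
    dirs = [d for d in save_dir.iterdir() if d.is_dir() and _SAVE_RE.fullmatch(d.stem)]
    # Sort numerically by step so save-1000 sorts after save-200.
    dirs.sort(key=lambda d: int(_SAVE_RE.fullmatch(d.stem).group(1)))
    return dirs[:-num_to_keep] if num_to_keep else []

# e.g. with save-100 ... save-500 on disk and num_to_keep=2,
# this returns [save-100, save-200, save-300].
```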