d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/run/train.py
ADDED
@@ -0,0 +1,355 @@
from pathlib import Path

from tqdm import tqdm

from d9d.core.dist_context import DeviceMeshParameters
from d9d.internals.determinism import set_seeds
from d9d.internals.pipeline_state import PipelineStateHandler
from d9d.loop.component import (
    BatchMaths,
    DataLoaderFactory,
    GradientClipper,
    GradientManager,
    JobLogger,
    JobProfiler,
    LossComputer,
    ManualGarbageCollector,
    ModelStageExporter,
    ModelStageFactory,
    OptimizerFactory,
    StateCheckpointer,
    Stepper,
    TimeoutManager,
    TrainTaskOperator,
)
from d9d.loop.config import TrainerConfig
from d9d.loop.control import (
    CreateMetricsContext,
    DatasetProvider,
    FinalizeContext,
    LRSchedulerProvider,
    ModelProvider,
    OptimizerProvider,
    TrainTaskProvider,
    TrainTaskProviderContext,
)
from d9d.loop.state import TrainJobState
from d9d.metric.impl import ComposeMetric


class TrainingConfigurator:
    """
    Orchestrates the assembly of the distributed training environment.

    This class binds the infrastructure configuration (DeviceMesh), the training
    parameters (TrainerConfig), and the user-defined logic (Providers) to create
    a fully initialized state object capable of running the training loop.
    """

    def __init__(
        self,
        mesh: DeviceMeshParameters,
        parameters: TrainerConfig,
        task_provider: TrainTaskProvider,
        model_provider: ModelProvider,
        data_provider: DatasetProvider,
        optimizer_provider: OptimizerProvider,
        lr_scheduler_provider: LRSchedulerProvider
    ):
        """
        Constructs a configurator capable of building the full training state.

        Args:
            mesh: Definition of the distributed device mesh topology.
            parameters: The global configuration object for the trainer.
            task_provider: Factory for creating the training task logic.
            model_provider: Factory for defining and creating model stages.
            data_provider: Factory for providing training datasets.
            optimizer_provider: Factory for creating the optimizer.
            lr_scheduler_provider: Factory for creating the learning rate scheduler.
        """
        self._mesh = mesh
        self._parameters = parameters
        self._task_provider = task_provider
        self._model_provider = model_provider
        self._data_provider = data_provider
        self._optimizer_provider = optimizer_provider
        self._lr_scheduler_provider = lr_scheduler_provider

    def _build_new_training_state(self) -> TrainJobState:
        dist_context = self._mesh.build()

        set_seeds(dist_context, seed=self._parameters.determinism.base_seed)

        task = self._task_provider(TrainTaskProviderContext(
            dist_context=dist_context
        ))

        batch_maths = BatchMaths(
            dist_context=dist_context,
            config_batching=self._parameters.batching,
            config_pipelining=self._parameters.pipelining
        )

        data_loader_factory = DataLoaderFactory(
            dist_context=dist_context,
            provider=self._data_provider,
            config_data_loading=self._parameters.data_loading,
            batch_maths=batch_maths
        )
        data_loader_train = data_loader_factory.build_dataloader_for_train_job()

        stepper = Stepper(
            initial_step=1,
            total_steps=len(data_loader_train)
        )

        pipeline_state_handler = PipelineStateHandler(
            sharding_spec={},
            num_shards=batch_maths.num_microbatches_pipelining
        )

        loss_computer = LossComputer(
            state=pipeline_state_handler,
            task=task,
            stepper=stepper
        )

        schedule, modules = ModelStageFactory(
            model_provider=self._model_provider,
            dist_context=dist_context,
            config_model=self._parameters.model_stage_factory,
            config_pipelining=self._parameters.pipelining,
            batch_maths=batch_maths,
            loss_computer=loss_computer
        ).build_pipeline_and_modules()

        metrics = ComposeMetric(task.create_metrics(CreateMetricsContext()).metrics)
        metrics.to("cuda")

        task_operator = TrainTaskOperator(
            dist_context=dist_context,
            task=task,
            pp_schedule=schedule,
            tracked_modules=modules,
            loss_computer=loss_computer,
            pipeline_state=pipeline_state_handler,
            metrics=metrics
        )

        grad_clipper = GradientClipper(
            dist_context=dist_context,
            tracked_modules=modules,
            config=self._parameters.gradient_clipping,
            stepper=stepper
        )

        optimizer, scheduler = OptimizerFactory(
            dist_context=dist_context,
            tracked_modules=modules,
            optimizer_provider=self._optimizer_provider,
            stepper=stepper,
            lr_scheduler_provider=self._lr_scheduler_provider
        ).build_optimizer_and_scheduler()

        gc = ManualGarbageCollector(
            dist_ctx=dist_context,
            config=self._parameters.gc,
            step=stepper
        )

        checkpointer = StateCheckpointer(
            dist_context=dist_context,
            stepper=stepper,
            config=self._parameters.checkpointing,
            gc=gc,
            run_name=self._parameters.run.name
        )

        profiler = JobProfiler(
            dist_context=dist_context,
            stepper=stepper,
            config=self._parameters.profiling
        )

        exporter = ModelStageExporter(
            model_provider=self._model_provider,
            dist_context=dist_context,
            modules=modules
        )

        gradient_manager = GradientManager(
            dist_context=dist_context,
            tracked_modules=modules,
            batch_maths=batch_maths,
            config=self._parameters.gradient_manager
        )

        timeout_manager = TimeoutManager(
            dist_context=dist_context,
            config=self._parameters.timeout
        )

        job_logger = JobLogger(
            dist_context=dist_context,
            config=self._parameters.logging,
            metrics=metrics,
            stepper=stepper,
            run_config=self._parameters.run,
            additional_hparams={
                "task": task.dump_hparams(),
                "model": self._model_provider.dump_hparams()
            }
        )

        return TrainJobState(
            dist_context=dist_context,
            data_loader=data_loader_train,
            stepper=stepper,
            tracked_modules=modules,
            garbage_collector=gc,
            batch_maths=batch_maths,
            checkpointer=checkpointer,
            optimizer=optimizer,
            task=task,
            lr_scheduler=scheduler,
            gradient_clipper=grad_clipper,
            profiler=profiler,
            exporter=exporter,
            metrics=metrics,
            logger=job_logger,
            gradient_manager=gradient_manager,
            timeout_manager=timeout_manager,
            task_operator=task_operator
        )

    def configure(self) -> "Trainer":
        """
        Instantiates all training components and returns a configured Trainer.

        This method triggers the creation of the distributed context, sets seeds,
        builds the model, optimizer, data loaders, and attaches all auxiliary
        components (logging, profiling, checkpointing).

        Returns:
            Trainer: A ready-to-use trainer instance encapsulating the job state.
        """
        state = self._build_new_training_state()

        return Trainer(state)


class Trainer:
    """
    The main execution engine for running a distributed training job.

    This class manages the training loop, lifecycle events, distributed synchronization,
    and periodic side-effects (logging, checkpointing).
    """

    def __init__(self, state: TrainJobState):
        """
        Constructs a Trainer from a pre-built job state.

        Args:
            state: The encapsulated state object containing all initialized
                components (model, optimizer, dist_context, etc.).
        """
        self._state = state

    def train(self):
        """
        Executes the full training workflow.
        """
        self._state.dist_context.logger.info("Waiting for the world to start training")
        self._state.dist_context.wait_world()
        self._state.dist_context.logger.info("Trying to load last checkpoint before doing anything else")
        self._state.checkpointer.load_last_checkpoint(self._state)

        if self._state.stepper.current_step >= self._state.stepper.total_steps:
            self._state.dist_context.logger.info("Already trained fully, will do nothing")
            return

        self._state.dist_context.wait_world()

        with (
            tqdm(
                desc="Training",
                total=self._state.stepper.total_steps,
                disable=not self._state.dist_context.is_local_main_process,
                initial=self._state.stepper.current_step
            ) as bar,
            self._state.logger.new_run() as run,
            self._state.garbage_collector as gc,
            self._state.profiler.open() as profiler,
            self._state.gradient_manager.install(),
            self._state.gradient_clipper.install()
        ):
            self._state.timeout_manager.step()
            run.set_context({"stage": "train"})

            for batch_group in self._state.data_loader:
                run.set_step(self._state.stepper.current_step)

                for batch in batch_group:
                    # we do both forward and backward passes;
                    # since GradientManager is installed, it should start
                    # synchronization, overlapping grad sync with compute
                    loss = self._state.task_operator.forward_backward(batch)

                    # add loss for the grad manager - it needs it for grad reduction
                    if loss is not None:
                        self._state.gradient_manager.add_loss_with_weight(loss.loss, loss.loss_weight)

                # metrics were accumulated during the forward passes - we can schedule their synchronization
                self._state.logger.trigger_sync()

                # wait for gradient synchronization to finish, then scale the grads
                self._state.gradient_manager.sync_and_scale()

                # clip grads after they are synced across the world
                self._state.gradient_clipper.clip_and_log(run)

                # optimize (it won't sync grads - they are already Replicate-d)
                self._state.optimizer.step()

                # update LR
                self._state.lr_scheduler.step()

                # log everything
                self._state.logger.log(
                    run,
                    loss_value=self._state.gradient_manager.compute_global_loss()
                )

                # reset grads
                self._state.gradient_manager.zero_grad()

                gc.collect_periodic()
                self._state.stepper.step()
                bar.update()

                # checkpoint at the end of the step
                self._state.checkpointer.checkpoint_if_needed(self._state)

                if profiler:
                    profiler.step()

            self._state.task.finalize(FinalizeContext())

    def export(self, export_to: Path, load_checkpoint: bool):
        """
        Exports the current model state to the specified directory.

        This handles the distributed saving logic, allowing the model to be
        reconstituted later or used for inference.

        Args:
            export_to: The directory path where the model artifacts will be saved.
            load_checkpoint: If True, attempts to load the latest checkpoint
                into the model before exporting.
        """
        if load_checkpoint:
            self._state.checkpointer.load_last_checkpoint(self._state)

        self._state.exporter.export(export_to)
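For orientation, the sketch below shows how a launch script might wire providers into TrainingConfigurator and run the resulting Trainer. It is not part of the package: the provider objects and the DeviceMeshParameters / TrainerConfig arguments are placeholders, since their real constructors live in other modules of this wheel, and the import path assumes TrainingConfigurator is used from d9d.loop.run.train directly.

# Hypothetical launch script -- a minimal sketch, not shipped with d9d.
# Provider objects and the "..." constructor arguments are placeholders.
from pathlib import Path

from d9d.core.dist_context import DeviceMeshParameters
from d9d.loop.config import TrainerConfig
from d9d.loop.run.train import TrainingConfigurator

trainer = TrainingConfigurator(
    mesh=DeviceMeshParameters(...),            # placeholder: mesh topology definition
    parameters=TrainerConfig(...),             # placeholder: full trainer configuration
    task_provider=my_task_provider,            # user-defined TrainTaskProvider
    model_provider=my_model_provider,          # user-defined ModelProvider
    data_provider=my_dataset_provider,         # user-defined DatasetProvider
    optimizer_provider=my_optimizer_provider,  # user-defined OptimizerProvider
    lr_scheduler_provider=my_lr_scheduler_provider
).configure()

trainer.train()                                # resumes from the last checkpoint if one exists
trainer.export(Path("./export"), load_checkpoint=True)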
d9d/loop/state.py
ADDED
@@ -0,0 +1,143 @@
import dataclasses
from typing import Any

from torch.distributed.checkpoint.stateful import Stateful
from torchdata.stateful_dataloader import StatefulDataLoader

from d9d.core.dist_context import DistributedContext
from d9d.core.protocol import LRSchedulerProtocol, OptimizerProtocol
from d9d.loop.component import (
    BatchMaths,
    GradientClipper,
    GradientManager,
    JobLogger,
    JobProfiler,
    ManualGarbageCollector,
    ModelStageExporter,
    StateCheckpointer,
    Stepper,
    TimeoutManager,
    TrackedModules,
    TrainTaskOperator,
)
from d9d.loop.control import InferenceTask, TrainTask
from d9d.metric.impl import ComposeMetric


@dataclasses.dataclass(kw_only=True)
class JobState(Stateful):
    """
    Base container for the state of a distributed execution job.

    This dataclass holds the common infrastructure components required for both
    training and inference loops. It implements the Stateful protocol to support
    checkpointing of its internal components.

    Attributes:
        dist_context: The distributed context.
        stepper: Component for tracking the current global step and total steps.
        garbage_collector: Component for manual control of Python garbage collection.
        checkpointer: Component responsible for saving and loading execution states.
        profiler: Component for performance profiling.
        tracked_modules: Container holding the model (or model parts) being executed.
        batch_maths: Helper for calculating batch sizes and gradient accumulation steps.
        data_loader: The input data stream.
        timeout_manager: Component for checking and refreshing distributed timeouts.
    """
    dist_context: DistributedContext

    stepper: Stepper
    garbage_collector: ManualGarbageCollector
    checkpointer: StateCheckpointer
    profiler: JobProfiler

    tracked_modules: TrackedModules
    batch_maths: BatchMaths

    data_loader: StatefulDataLoader

    timeout_manager: TimeoutManager

    def state_dict(self) -> dict[str, Any]:
        return {
            "stepper": self.stepper.state_dict(),
            "tracked_modules": self.tracked_modules.state_dict(),
            "data_loader": self.data_loader.state_dict()
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.stepper.load_state_dict(state_dict["stepper"])
        self.tracked_modules.load_state_dict(state_dict["tracked_modules"])
        self.data_loader.load_state_dict(state_dict["data_loader"])


@dataclasses.dataclass(kw_only=True)
class TrainJobState(JobState):
    """
    Container for the state of a training job.

    Extends JobState to include components specific to training, such as
    optimization, gradient management, and loss computation.

    Attributes:
        task: The specific training task logic definition.
        gradient_manager: Component handling gradient synchronization.
        metrics: Container for aggregating training metrics.
        task_operator: Executor for running forward and backward passes.
        logger: Component for logging metrics and system status.
        optimizer: The optimizer instance updating model parameters.
        lr_scheduler: The scheduler adjusting the learning rate.
        gradient_clipper: Component for clipping gradient norms.
        exporter: Component for exporting the final model artifacts.
    """
    task: TrainTask
    gradient_manager: GradientManager
    metrics: ComposeMetric
    task_operator: TrainTaskOperator

    logger: JobLogger

    optimizer: OptimizerProtocol
    lr_scheduler: LRSchedulerProtocol
    gradient_clipper: GradientClipper
    exporter: ModelStageExporter

    def state_dict(self) -> dict[str, Any]:
        return {
            **super().state_dict(),
            "logger": self.logger.state_dict(),
            "task": self.task.state_dict(),
            "metrics": self.metrics.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        super().load_state_dict(state_dict)

        self.logger.load_state_dict(state_dict["logger"])
        self.task.load_state_dict(state_dict["task"])
        self.metrics.load_state_dict(state_dict["metrics"])
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.lr_scheduler.load_state_dict(state_dict["lr_scheduler"])


@dataclasses.dataclass(kw_only=True)
class InferJobState(JobState):
    """
    Container for the state of an inference job.

    Attributes:
        task: The specific inference task logic definition.
    """
    task: InferenceTask

    def state_dict(self) -> dict[str, Any]:
        return {
            **super().state_dict(),
            "task": self.task.state_dict(),
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        super().load_state_dict(state_dict)
        self.task.load_state_dict(state_dict["task"])
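JobState and its subclasses implement the Stateful protocol from torch.distributed.checkpoint, so a checkpointing layer can treat the whole container as one nested state dict. A self-contained illustration of that composition pattern, using a toy component instead of the real trainer objects (which would need a full distributed setup), might look like this:

# Toy illustration of the Stateful composition pattern used by JobState above.
# "Counter" and "ToyJobState" are hypothetical stand-ins, not part of d9d.
from typing import Any

from torch.distributed.checkpoint.stateful import Stateful


class Counter(Stateful):
    """Stand-in for a component like Stepper."""

    def __init__(self) -> None:
        self.value = 0

    def state_dict(self) -> dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.value = state_dict["value"]


class ToyJobState(Stateful):
    """Container that aggregates its children's state, like JobState does."""

    def __init__(self) -> None:
        self.stepper = Counter()

    def state_dict(self) -> dict[str, Any]:
        return {"stepper": self.stepper.state_dict()}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.stepper.load_state_dict(state_dict["stepper"])


state = ToyJobState()
state.stepper.value = 42
snapshot = state.state_dict()       # what a checkpointer would persist

restored = ToyJobState()
restored.load_state_dict(snapshot)  # round-trips the nested state
assert restored.stepper.value == 42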
d9d/lr_scheduler/piecewise/__init__.py
ADDED
@@ -0,0 +1,18 @@
"""
Implements flexible piecewise learning rate schedules via a builder pattern.
"""

from .builder import piecewise_schedule
from .config import PiecewiseSchedulerConfig, piecewise_scheduler_from_config
from .curves import CurveBase, CurveCosine, CurveExponential, CurveLinear, CurvePoly

__all__ = [
    "CurveBase",
    "CurveCosine",
    "CurveExponential",
    "CurveLinear",
    "CurvePoly",
    "PiecewiseSchedulerConfig",
    "piecewise_schedule",
    "piecewise_scheduler_from_config"
]
d9d/lr_scheduler/piecewise/builder.py
ADDED
@@ -0,0 +1,152 @@
from typing import Self

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

from d9d.core.protocol import LRSchedulerProtocol

from .curves import CurveBase
from .engine import PiecewiseScheduleEngine, SchedulePhase


class PiecewiseScheduleBuilder:
    """
    Builder for constructing multiphase learning rate schedules.
    """

    def __init__(
        self,
        initial_multiplier: float,
        total_steps: int | None
    ):
        """
        Constructs a new PiecewiseScheduleBuilder.

        Args:
            initial_multiplier: The starting learning rate multiplier (usually 0.0 or 1.0).
            total_steps: The total number of training steps. Required if using percentage-based methods.
        """

        self._phases: list[SchedulePhase] = []
        self._total_steps = total_steps
        self._last_end_step = 0
        self._last_multiplier = initial_multiplier

    def for_steps(self, steps: int, target_multiplier: float, curve: CurveBase) -> Self:
        """
        Adds a schedule phase lasting for a specific number of steps.

        Args:
            steps: Duration of this phase in steps.
            target_multiplier: The value of the multiplier at the end of this phase.
            curve: The interpolation curve to use for bridging the start and end values.

        Returns:
            The builder instance for chaining.
        """

        self._phases.append(SchedulePhase(
            start_step=self._last_end_step,
            end_step=self._last_end_step + steps,
            curve=curve,
            start_value=self._last_multiplier,
            end_value=target_multiplier
        ))

        self._last_end_step += steps
        self._last_multiplier = target_multiplier

        return self

    def until_percentage(self, p: float, target_multiplier: float, curve: CurveBase) -> Self:
        """
        Adds a schedule phase lasting until a specific percentage of total training steps is reached.

        Args:
            p: The target percentage (0.0 to 1.0) of total_steps where this phase ends.
            target_multiplier: The value of the multiplier at the end of this phase.
            curve: The interpolation curve to use.

        Returns:
            The builder instance for chaining.

        Raises:
            ValueError: If total_steps was not provided in constructor or if the target
                percentage implies a step count earlier than the current cursor.
        """

        if self._total_steps is None:
            raise ValueError(
                "You must define 'total_steps' in the constructor to use percentage-based methods."
            )

        if not 0.0 <= p <= 1.0:
            raise ValueError("Percentage should be in range of [0.0, 1.0]")

        target_step_abs = int(self._total_steps * p)
        duration = target_step_abs - self._last_end_step

        if duration < 0:
            raise ValueError(
                f"Target percentage {p} (step {target_step_abs}) is behind "
                f"current cursor (step {self._last_end_step})."
            )

        return self.for_steps(duration, target_multiplier, curve)

    def fill_rest(self, target_multiplier: float, curve: CurveBase) -> Self:
        """
        Adds a schedule phase that lasts from the current cursor until the end of training.

        Args:
            target_multiplier: The value of the multiplier at the very end of training.
            curve: The interpolation curve to use.

        Returns:
            The builder instance for chaining.
        """

        return self.until_percentage(1.0, target_multiplier, curve)

    def build(self, optimizer: Optimizer) -> LRSchedulerProtocol:
        """
        Finalizes the schedule and returns a PyTorch LR Scheduler.

        Args:
            optimizer: The optimizer to wrap.

        Returns:
            A scheduler configured with the defined phases.

        Raises:
            ValueError: If the defined phases exceed the total_steps provided.
        """

        if self._total_steps is not None and self._last_end_step > self._total_steps:
            raise ValueError(
                f"Schedule defined for {self._last_end_step} steps, but total_steps is {self._total_steps}."
            )

        engine = PiecewiseScheduleEngine(self._phases)
        return LambdaLR(optimizer, engine.get_factor)


def piecewise_schedule(
    initial_multiplier: float,
    total_steps: int | None = None
) -> PiecewiseScheduleBuilder:
    """
    Entry point for creating a piecewise learning rate schedule.

    Args:
        initial_multiplier: The initial learning rate multiplier.
        total_steps: Total training steps. Required for percentage-based scheduling.

    Returns:
        A builder instance to configure phases.
    """

    return PiecewiseScheduleBuilder(
        initial_multiplier=initial_multiplier,
        total_steps=total_steps
    )
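To make the phase arithmetic concrete, here is a hedged usage sketch building a warmup / decay / cooldown schedule. The CurveLinear() and CurveCosine() calls are assumptions (their real constructors live in curves.py, which is not reproduced above), and the model and optimizer are stand-ins.

# Hedged usage sketch for the builder above -- not shipped with the package.
# Assumes CurveLinear / CurveCosine take no constructor arguments; their real
# signatures are defined in d9d/lr_scheduler/piecewise/curves.py.
import torch
from torch.optim import AdamW

from d9d.lr_scheduler.piecewise import CurveCosine, CurveLinear, piecewise_schedule

model = torch.nn.Linear(16, 16)
optimizer = AdamW(model.parameters(), lr=3e-4)

scheduler = (
    piecewise_schedule(initial_multiplier=0.0, total_steps=1_000)
    .for_steps(100, target_multiplier=1.0, curve=CurveLinear())         # linear warmup to the base LR
    .until_percentage(0.9, target_multiplier=0.1, curve=CurveCosine())  # decay to 10% by step 900
    .fill_rest(target_multiplier=0.0, curve=CurveLinear())              # cool down to zero by step 1000
    .build(optimizer)
)

for _ in range(1_000):
    optimizer.step()
    scheduler.step()   # LambdaLR multiplies the base LR by the phase-dependent factor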