d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/run/train.py ADDED
@@ -0,0 +1,355 @@
+ from pathlib import Path
+
+ from tqdm import tqdm
+
+ from d9d.core.dist_context import DeviceMeshParameters
+ from d9d.internals.determinism import set_seeds
+ from d9d.internals.pipeline_state import PipelineStateHandler
+ from d9d.loop.component import (
+     BatchMaths,
+     DataLoaderFactory,
+     GradientClipper,
+     GradientManager,
+     JobLogger,
+     JobProfiler,
+     LossComputer,
+     ManualGarbageCollector,
+     ModelStageExporter,
+     ModelStageFactory,
+     OptimizerFactory,
+     StateCheckpointer,
+     Stepper,
+     TimeoutManager,
+     TrainTaskOperator,
+ )
+ from d9d.loop.config import TrainerConfig
+ from d9d.loop.control import (
+     CreateMetricsContext,
+     DatasetProvider,
+     FinalizeContext,
+     LRSchedulerProvider,
+     ModelProvider,
+     OptimizerProvider,
+     TrainTaskProvider,
+     TrainTaskProviderContext,
+ )
+ from d9d.loop.state import TrainJobState
+ from d9d.metric.impl import ComposeMetric
+
+
+ class TrainingConfigurator:
+     """
+     Orchestrates the assembly of the distributed training environment.
+
+     This class binds the infrastructure configuration (DeviceMesh), the training
+     parameters (TrainerConfig), and the user-defined logic (Providers) to create
+     a fully initialized state object capable of running the training loop.
+     """
+
+     def __init__(
+         self,
+         mesh: DeviceMeshParameters,
+         parameters: TrainerConfig,
+         task_provider: TrainTaskProvider,
+         model_provider: ModelProvider,
+         data_provider: DatasetProvider,
+         optimizer_provider: OptimizerProvider,
+         lr_scheduler_provider: LRSchedulerProvider
+     ):
+         """
+         Constructs a configurator capable of building the full training state.
+
+         Args:
+             mesh: Definition of the distributed device mesh topology.
+             parameters: The global configuration object for the trainer.
+             task_provider: Factory for creating the training task logic.
+             model_provider: Factory for defining and creating model stages.
+             data_provider: Factory for providing training datasets.
+             optimizer_provider: Factory for creating the optimizer.
+             lr_scheduler_provider: Factory for creating the learning rate scheduler.
+         """
+         self._mesh = mesh
+         self._parameters = parameters
+         self._task_provider = task_provider
+         self._model_provider = model_provider
+         self._data_provider = data_provider
+         self._optimizer_provider = optimizer_provider
+         self._lr_scheduler_provider = lr_scheduler_provider
+
+     def _build_new_training_state(self) -> TrainJobState:
+         dist_context = self._mesh.build()
+
+         set_seeds(dist_context, seed=self._parameters.determinism.base_seed)
+
+         task = self._task_provider(TrainTaskProviderContext(
+             dist_context=dist_context
+         ))
+
+         batch_maths = BatchMaths(
+             dist_context=dist_context,
+             config_batching=self._parameters.batching,
+             config_pipelining=self._parameters.pipelining
+         )
+
+         data_loader_factory = DataLoaderFactory(
+             dist_context=dist_context,
+             provider=self._data_provider,
+             config_data_loading=self._parameters.data_loading,
+             batch_maths=batch_maths
+         )
+         data_loader_train = data_loader_factory.build_dataloader_for_train_job()
+
+         stepper = Stepper(
+             initial_step=1,
+             total_steps=len(data_loader_train)
+         )
+
+         pipeline_state_handler = PipelineStateHandler(
+             sharding_spec={},
+             num_shards=batch_maths.num_microbatches_pipelining
+         )
+
+         loss_computer = LossComputer(
+             state=pipeline_state_handler,
+             task=task,
+             stepper=stepper
+         )
+
+         schedule, modules = ModelStageFactory(
+             model_provider=self._model_provider,
+             dist_context=dist_context,
+             config_model=self._parameters.model_stage_factory,
+             config_pipelining=self._parameters.pipelining,
+             batch_maths=batch_maths,
+             loss_computer=loss_computer
+         ).build_pipeline_and_modules()
+
+         metrics = ComposeMetric(task.create_metrics(CreateMetricsContext()).metrics)
+         metrics.to("cuda")
+
+         task_operator = TrainTaskOperator(
+             dist_context=dist_context,
+             task=task,
+             pp_schedule=schedule,
+             tracked_modules=modules,
+             loss_computer=loss_computer,
+             pipeline_state=pipeline_state_handler,
+             metrics=metrics
+         )
+
+         grad_clipper = GradientClipper(
+             dist_context=dist_context,
+             tracked_modules=modules,
+             config=self._parameters.gradient_clipping,
+             stepper=stepper
+         )
+
+         optimizer, scheduler = OptimizerFactory(
+             dist_context=dist_context,
+             tracked_modules=modules,
+             optimizer_provider=self._optimizer_provider,
+             stepper=stepper,
+             lr_scheduler_provider=self._lr_scheduler_provider
+         ).build_optimizer_and_scheduler()
+
+         gc = ManualGarbageCollector(
+             dist_ctx=dist_context,
+             config=self._parameters.gc,
+             step=stepper
+         )
+
+         checkpointer = StateCheckpointer(
+             dist_context=dist_context,
+             stepper=stepper,
+             config=self._parameters.checkpointing,
+             gc=gc,
+             run_name=self._parameters.run.name
+         )
+
+         profiler = JobProfiler(
+             dist_context=dist_context,
+             stepper=stepper,
+             config=self._parameters.profiling
+         )
+
+         exporter = ModelStageExporter(
+             model_provider=self._model_provider,
+             dist_context=dist_context,
+             modules=modules
+         )
+
+         gradient_manager = GradientManager(
+             dist_context=dist_context,
+             tracked_modules=modules,
+             batch_maths=batch_maths,
+             config=self._parameters.gradient_manager
+         )
+
+         timeout_manager = TimeoutManager(
+             dist_context=dist_context,
+             config=self._parameters.timeout
+         )
+
+         job_logger = JobLogger(
+             dist_context=dist_context,
+             config=self._parameters.logging,
+             metrics=metrics,
+             stepper=stepper,
+             run_config=self._parameters.run,
+             additional_hparams={
+                 "task": task.dump_hparams(),
+                 "model": self._model_provider.dump_hparams()
+             }
+         )
+
+         return TrainJobState(
+             dist_context=dist_context,
+             data_loader=data_loader_train,
+             stepper=stepper,
+             tracked_modules=modules,
+             garbage_collector=gc,
+             batch_maths=batch_maths,
+             checkpointer=checkpointer,
+             optimizer=optimizer,
+             task=task,
+             lr_scheduler=scheduler,
+             gradient_clipper=grad_clipper,
+             profiler=profiler,
+             exporter=exporter,
+             metrics=metrics,
+             logger=job_logger,
+             gradient_manager=gradient_manager,
+             timeout_manager=timeout_manager,
+             task_operator=task_operator
+         )
+
+     def configure(self) -> "Trainer":
+         """
+         Instantiates all training components and returns a configured Trainer.
+
+         This method triggers the creation of the distributed context, sets seeds,
+         builds the model, optimizer, data loaders, and attaches all auxiliary
+         components (logging, profiling, checkpointing).
+
+         Returns:
+             Trainer: A ready-to-use trainer instance encapsulating the job state.
+         """
+         state = self._build_new_training_state()
+
+         return Trainer(state)
+
+
+ class Trainer:
+     """
+     The main execution engine for running a distributed training job.
+
+     This class manages the training loop, lifecycle events, distributed synchronization,
+     and periodic side-effects (logging, checkpointing).
+     """
+
+     def __init__(self, state: TrainJobState):
+         """
+         Constructs a Trainer from a pre-built job state.
+
+         Args:
+             state: The encapsulated state object containing all initialized
+                 components (model, optimizer, dist_context, etc.).
+         """
+         self._state = state
+
+     def train(self):
+         """
+         Executes the full training workflow.
+         """
+         self._state.dist_context.logger.info("Waiting for the world to start training")
+         self._state.dist_context.wait_world()
+         self._state.dist_context.logger.info("Trying to load last checkpoint before doing anything else")
+         self._state.checkpointer.load_last_checkpoint(self._state)
+
+         if self._state.stepper.current_step >= self._state.stepper.total_steps:
+             self._state.dist_context.logger.info("Already trained fully, will do nothing")
+             return
+
+         self._state.dist_context.wait_world()
+
+         with (
+             tqdm(
+                 desc="Training",
+                 total=self._state.stepper.total_steps,
+                 disable=not self._state.dist_context.is_local_main_process,
+                 initial=self._state.stepper.current_step
+             ) as bar,
+             self._state.logger.new_run() as run,
+             self._state.garbage_collector as gc,
+             self._state.profiler.open() as profiler,
+             self._state.gradient_manager.install(),
+             self._state.gradient_clipper.install()
+         ):
+             self._state.timeout_manager.step()
+             run.set_context({"stage": "train"})
+
+             for batch_group in self._state.data_loader:
+                 run.set_step(self._state.stepper.current_step)
+
+                 for batch in batch_group:
+                     # Run both the forward and the backward pass. Because the
+                     # GradientManager is installed, gradient synchronization
+                     # starts here and overlaps with compute.
+                     loss = self._state.task_operator.forward_backward(batch)
+
+                     # Hand the loss to the gradient manager; it needs it for gradient reduction.
+                     if loss is not None:
+                         self._state.gradient_manager.add_loss_with_weight(loss.loss, loss.loss_weight)
+
+                 # Metrics were accumulated during the forward passes, so their synchronization can be scheduled now.
+                 self._state.logger.trigger_sync()
+
+                 # Wait for gradient synchronization to finish, then scale the gradients.
+                 self._state.gradient_manager.sync_and_scale()
+
+                 # Clip gradients after they have been synchronized across the world.
+                 self._state.gradient_clipper.clip_and_log(run)
+
+                 # Optimizer step (it does not sync grads - they are already Replicate-d).
+                 self._state.optimizer.step()
+
+                 # Update the learning rate.
+                 self._state.lr_scheduler.step()
+
+                 # Log everything.
+                 self._state.logger.log(
+                     run,
+                     loss_value=self._state.gradient_manager.compute_global_loss()
+                 )
+
+                 # Reset gradients.
+                 self._state.gradient_manager.zero_grad()
+
+                 gc.collect_periodic()
+                 self._state.stepper.step()
+                 bar.update()
+
+                 # Checkpoint at the end of the step.
+                 self._state.checkpointer.checkpoint_if_needed(self._state)
+
+                 if profiler:
+                     profiler.step()
+
+         self._state.task.finalize(FinalizeContext())
+
+     def export(self, export_to: Path, load_checkpoint: bool):
+         """
+         Exports the current model state to the specified directory.
+
+         This handles the distributed saving logic, allowing the model to be
+         reconstituted later or used for inference.
+
+         Args:
+             export_to: The directory path where the model artifacts will be saved.
+             load_checkpoint: If True, attempts to load the latest checkpoint
+                 into the model before exporting.
+         """
+         if load_checkpoint:
+             self._state.checkpointer.load_last_checkpoint(self._state)
+
+         self._state.exporter.export(export_to)
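
The file above is the package's top-level assembly point: TrainingConfigurator.configure() wires the providers and the TrainerConfig into a TrainJobState and hands it to Trainer. As a rough orientation, a caller would drive it roughly as in the sketch below; the provider objects and the import path d9d.loop.run.train are assumptions inferred from the file layout, not documented API.

from pathlib import Path

from d9d.core.dist_context import DeviceMeshParameters
from d9d.loop.config import TrainerConfig
from d9d.loop.run.train import Trainer, TrainingConfigurator


def run_training(
    mesh: DeviceMeshParameters,
    config: TrainerConfig,
    task_provider,          # TrainTaskProvider implementation (user-defined)
    model_provider,         # ModelProvider implementation (user-defined)
    data_provider,          # DatasetProvider implementation (user-defined)
    optimizer_provider,     # OptimizerProvider implementation (user-defined)
    lr_scheduler_provider,  # LRSchedulerProvider implementation (user-defined)
) -> None:
    # configure() builds the distributed context, model stages, optimizer,
    # data loaders and auxiliary components, and returns a ready Trainer.
    trainer: Trainer = TrainingConfigurator(
        mesh=mesh,
        parameters=config,
        task_provider=task_provider,
        model_provider=model_provider,
        data_provider=data_provider,
        optimizer_provider=optimizer_provider,
        lr_scheduler_provider=lr_scheduler_provider,
    ).configure()

    # Run the training loop, then export the final artifacts.
    trainer.train()
    trainer.export(Path("./export"), load_checkpoint=True)
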
d9d/loop/state.py ADDED
@@ -0,0 +1,143 @@
+ import dataclasses
+ from typing import Any
+
+ from torch.distributed.checkpoint.stateful import Stateful
+ from torchdata.stateful_dataloader import StatefulDataLoader
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.core.protocol import LRSchedulerProtocol, OptimizerProtocol
+ from d9d.loop.component import (
+     BatchMaths,
+     GradientClipper,
+     GradientManager,
+     JobLogger,
+     JobProfiler,
+     ManualGarbageCollector,
+     ModelStageExporter,
+     StateCheckpointer,
+     Stepper,
+     TimeoutManager,
+     TrackedModules,
+     TrainTaskOperator,
+ )
+ from d9d.loop.control import InferenceTask, TrainTask
+ from d9d.metric.impl import ComposeMetric
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class JobState(Stateful):
+     """
+     Base container for the state of a distributed execution job.
+
+     This dataclass holds the common infrastructure components required for both
+     training and inference loops. It implements the Stateful protocol to support
+     checkpointing of its internal components.
+
+     Attributes:
+         dist_context: The distributed context.
+         stepper: Component for tracking the current global step and total steps.
+         garbage_collector: Component for manual control of Python garbage collection.
+         checkpointer: Component responsible for saving and loading execution states.
+         profiler: Component for performance profiling.
+         tracked_modules: Container holding the model (or model parts) being executed.
+         batch_maths: Helper for calculating batch sizes and gradient accumulation steps.
+         data_loader: The input data stream.
+         timeout_manager: Component for checking and refreshing distributed timeouts.
+     """
+     dist_context: DistributedContext
+
+     stepper: Stepper
+     garbage_collector: ManualGarbageCollector
+     checkpointer: StateCheckpointer
+     profiler: JobProfiler
+
+     tracked_modules: TrackedModules
+     batch_maths: BatchMaths
+
+     data_loader: StatefulDataLoader
+
+     timeout_manager: TimeoutManager
+
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             "stepper": self.stepper.state_dict(),
+             "tracked_modules": self.tracked_modules.state_dict(),
+             "data_loader": self.data_loader.state_dict()
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         self.stepper.load_state_dict(state_dict["stepper"])
+         self.tracked_modules.load_state_dict(state_dict["tracked_modules"])
+         self.data_loader.load_state_dict(state_dict["data_loader"])
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class TrainJobState(JobState):
+     """
+     Container for the state of a training job.
+
+     Extends JobState to include components specific to training, such as
+     optimization, gradient management, and loss computation.
+
+     Attributes:
+         task: The specific training task logic definition.
+         gradient_manager: Component handling gradient synchronization.
+         metrics: Container for aggregating training metrics.
+         task_operator: Executor for running forward and backward passes.
+         logger: Component for logging metrics and system status.
+         optimizer: The optimizer instance updating model parameters.
+         lr_scheduler: The scheduler adjusting the learning rate.
+         gradient_clipper: Component for clipping gradient norms.
+         exporter: Component for exporting the final model artifacts.
+     """
+     task: TrainTask
+     gradient_manager: GradientManager
+     metrics: ComposeMetric
+     task_operator: TrainTaskOperator
+
+     logger: JobLogger
+
+     optimizer: OptimizerProtocol
+     lr_scheduler: LRSchedulerProtocol
+     gradient_clipper: GradientClipper
+     exporter: ModelStageExporter
+
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             **super().state_dict(),
+             "logger": self.logger.state_dict(),
+             "task": self.task.state_dict(),
+             "metrics": self.metrics.state_dict(),
+             "optimizer": self.optimizer.state_dict(),
+             "lr_scheduler": self.lr_scheduler.state_dict(),
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         super().load_state_dict(state_dict)
+
+         self.logger.load_state_dict(state_dict["logger"])
+         self.task.load_state_dict(state_dict["task"])
+         self.metrics.load_state_dict(state_dict["metrics"])
+         self.optimizer.load_state_dict(state_dict["optimizer"])
+         self.lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class InferJobState(JobState):
+     """
+     Container for the state of an inference job.
+
+     Attributes:
+         task: The specific inference task logic definition.
+     """
+     task: InferenceTask
+
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             **super().state_dict(),
+             "task": self.task.state_dict(),
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         super().load_state_dict(state_dict)
+         self.task.load_state_dict(state_dict["task"])
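
JobState and its subclasses implement the Stateful protocol from torch.distributed.checkpoint, so a whole job can be checkpointed by passing the state object straight to the DCP save/load APIs. A minimal sketch of that round trip follows; the bundled StateCheckpointer presumably wraps something similar, and the "train_job" key and directory layout are assumptions, not the package's actual checkpoint format.

import torch.distributed.checkpoint as dcp

from d9d.loop.state import TrainJobState


def save_job(state: TrainJobState, checkpoint_dir: str) -> None:
    # DCP detects Stateful values and calls state.state_dict(), which composes
    # stepper, tracked modules, data loader, logger, task, metrics, optimizer
    # and lr_scheduler state into a single dictionary.
    dcp.save({"train_job": state}, checkpoint_id=checkpoint_dir)


def load_job(state: TrainJobState, checkpoint_dir: str) -> None:
    # On load, DCP restores the dictionary and calls state.load_state_dict(),
    # fanning the entries back out to the individual components.
    dcp.load({"train_job": state}, checkpoint_id=checkpoint_dir)
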
d9d/lr_scheduler/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """
+ Utilities for learning rate scheduling.
+ """
+
+ from .visualizer import visualize_lr_scheduler
+
+ __all__ = [
+     "visualize_lr_scheduler"
+ ]
d9d/lr_scheduler/piecewise/__init__.py ADDED
@@ -0,0 +1,18 @@
+ """
+ Implements flexible piecewise learning rate schedules via a builder pattern.
+ """
+
+ from .builder import piecewise_schedule
+ from .config import PiecewiseSchedulerConfig, piecewise_scheduler_from_config
+ from .curves import CurveBase, CurveCosine, CurveExponential, CurveLinear, CurvePoly
+
+ __all__ = [
+     "CurveBase",
+     "CurveCosine",
+     "CurveExponential",
+     "CurveLinear",
+     "CurvePoly",
+     "PiecewiseSchedulerConfig",
+     "piecewise_schedule",
+     "piecewise_scheduler_from_config"
+ ]
d9d/lr_scheduler/piecewise/builder.py ADDED
@@ -0,0 +1,152 @@
+ from typing import Self
+
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LambdaLR
+
+ from d9d.core.protocol import LRSchedulerProtocol
+
+ from .curves import CurveBase
+ from .engine import PiecewiseScheduleEngine, SchedulePhase
+
+
+ class PiecewiseScheduleBuilder:
+     """
+     Builder for constructing multiphase learning rate schedules.
+     """
+
+     def __init__(
+         self,
+         initial_multiplier: float,
+         total_steps: int | None
+     ):
+         """
+         Constructs a new PiecewiseScheduleBuilder.
+
+         Args:
+             initial_multiplier: The starting learning rate multiplier (usually 0.0 or 1.0).
+             total_steps: The total number of training steps. Required if using percentage-based methods.
+         """
+
+         self._phases: list[SchedulePhase] = []
+         self._total_steps = total_steps
+         self._last_end_step = 0
+         self._last_multiplier = initial_multiplier
+
+     def for_steps(self, steps: int, target_multiplier: float, curve: CurveBase) -> Self:
+         """
+         Adds a schedule phase lasting for a specific number of steps.
+
+         Args:
+             steps: Duration of this phase in steps.
+             target_multiplier: The value of the multiplier at the end of this phase.
+             curve: The interpolation curve to use for bridging the start and end values.
+
+         Returns:
+             The builder instance for chaining.
+         """
+
+         self._phases.append(SchedulePhase(
+             start_step=self._last_end_step,
+             end_step=self._last_end_step + steps,
+             curve=curve,
+             start_value=self._last_multiplier,
+             end_value=target_multiplier
+         ))
+
+         self._last_end_step += steps
+         self._last_multiplier = target_multiplier
+
+         return self
+
+     def until_percentage(self, p: float, target_multiplier: float, curve: CurveBase) -> Self:
+         """
+         Adds a schedule phase lasting until a specific percentage of total training steps is reached.
+
+         Args:
+             p: The target percentage (0.0 to 1.0) of total_steps where this phase ends.
+             target_multiplier: The value of the multiplier at the end of this phase.
+             curve: The interpolation curve to use.
+
+         Returns:
+             The builder instance for chaining.
+
+         Raises:
+             ValueError: If total_steps was not provided in constructor or if the target
+                 percentage implies a step count earlier than the current cursor.
+         """
+
+         if self._total_steps is None:
+             raise ValueError(
+                 "You must define 'total_steps' in the constructor to use percentage-based methods."
+             )
+
+         if not 0.0 <= p <= 1.0:
+             raise ValueError("Percentage should be in range of [0.0, 1.0]")
+
+         target_step_abs = int(self._total_steps * p)
+         duration = target_step_abs - self._last_end_step
+
+         if duration < 0:
+             raise ValueError(
+                 f"Target percentage {p} (step {target_step_abs}) is behind "
+                 f"current cursor (step {self._last_end_step})."
+             )
+
+         return self.for_steps(duration, target_multiplier, curve)
+
+     def fill_rest(self, target_multiplier: float, curve: CurveBase) -> Self:
+         """
+         Adds a schedule phase that lasts from the current cursor until the end of training.
+
+         Args:
+             target_multiplier: The value of the multiplier at the very end of training.
+             curve: The interpolation curve to use.
+
+         Returns:
+             The builder instance for chaining.
+         """
+
+         return self.until_percentage(1.0, target_multiplier, curve)
+
+     def build(self, optimizer: Optimizer) -> LRSchedulerProtocol:
+         """
+         Finalizes the schedule and returns a PyTorch LR Scheduler.
+
+         Args:
+             optimizer: The optimizer to wrap.
+
+         Returns:
+             A scheduler configured with the defined phases.
+
+         Raises:
+             ValueError: If the defined phases exceed the total_steps provided.
+         """
+
+         if self._total_steps is not None and self._last_end_step > self._total_steps:
+             raise ValueError(
+                 f"Schedule defined for {self._last_end_step} steps, but total_steps is {self._total_steps}."
+             )
+
+         engine = PiecewiseScheduleEngine(self._phases)
+         return LambdaLR(optimizer, engine.get_factor)
+
+
+ def piecewise_schedule(
+     initial_multiplier: float,
+     total_steps: int | None = None
+ ) -> PiecewiseScheduleBuilder:
+     """
+     Entry point for creating a piecewise learning rate schedule.
+
+     Args:
+         initial_multiplier: The initial learning rate multiplier.
+         total_steps: Total training steps. Required for percentage-based scheduling.
+
+     Returns:
+         A builder instance to configure phases.
+     """
+
+     return PiecewiseScheduleBuilder(
+         initial_multiplier=initial_multiplier,
+         total_steps=total_steps
+     )
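
To illustrate how the builder is meant to be chained, here is a hedged usage sketch: linear warmup over the first 5% of training followed by a cosine decay to 10% of the peak learning rate. Only the names come from the d9d.lr_scheduler.piecewise exports; the no-argument curve constructors and the toy model are assumptions, not documented API.

import torch
from torch.optim import AdamW

from d9d.lr_scheduler.piecewise import CurveCosine, CurveLinear, piecewise_schedule

model = torch.nn.Linear(8, 8)  # toy model, stand-in for a real module
optimizer = AdamW(model.parameters(), lr=3e-4)

scheduler = (
    piecewise_schedule(initial_multiplier=0.0, total_steps=10_000)
    # warm up linearly from 0.0x to 1.0x of the base LR over the first 5% of steps
    .until_percentage(0.05, target_multiplier=1.0, curve=CurveLinear())
    # then decay along a cosine curve down to 0.1x for the rest of training
    .fill_rest(target_multiplier=0.1, curve=CurveCosine())
    .build(optimizer)
)

# In the training loop the scheduler is stepped once per optimizer step,
# exactly as Trainer.train() does in d9d/loop/run/train.py.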