d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/auto/auto_optimizer.py
@@ -0,0 +1,196 @@
+ import abc
+ from abc import ABC
+ from collections.abc import Iterable
+ from typing import Annotated, Literal
+
+ import torch
+ from pydantic import BaseModel, Field
+ from torch import nn
+ from torch.optim import SGD, Adam, AdamW, Optimizer
+
+ from d9d.loop.control import InitializeOptimizerStageContext, OptimizerProvider
+ from d9d.optim.stochastic import StochasticAdamW
+
+
+ class BaseAutoOptimizerConfig(BaseModel, ABC):
+     """
+     Abstract base class for optimizer configurations.
+     """
+
+     @abc.abstractmethod
+     def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+         """
+         Creates the PyTorch optimizer instance.
+
+         Args:
+             params: An iterable of model parameters to optimize.
+
+         Returns:
+             The instantiated optimizer.
+         """
+         ...
+
+
+ class StochasticAdamWOptimizerConfig(BaseAutoOptimizerConfig):
+     """
+     Configuration for the Stochastic AdamW optimizer.
+
+     Attributes:
+         name: Discriminator tag.
+         lr: Learning rate.
+         betas: Coefficients used for computing running averages of gradient and its square.
+         eps: Term added to the denominator to improve numerical stability.
+         weight_decay: Weight decay coefficient.
+         state_dtype: Data Type to use for the optimizer states.
+     """
+     name: Literal["stochastic_adamw"] = "stochastic_adamw"
+
+     lr: float
+     betas: tuple[float, float] = (0.9, 0.999)
+     eps: float = 1e-8
+     weight_decay: float = 1e-2
+     state_dtype: str
+
+     def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+         """Builds StochasticAdamW with the configured parameters."""
+         return StochasticAdamW(
+             params=params,
+             lr=self.lr,
+             betas=self.betas,
+             eps=self.eps,
+             weight_decay=self.weight_decay,
+             state_dtype=getattr(torch, self.state_dtype)
+         )
+
+
+ class AdamWOptimizerConfig(BaseAutoOptimizerConfig):
+     """
+     Configuration for the PyTorch AdamW optimizer.
+
+     Attributes:
+         name: Discriminator tag.
+         lr: The learning rate.
+         betas: Coefficients for computing running averages of gradient and its square.
+         eps: Term added to the denominator to improve numerical stability.
+         weight_decay: Weight decay coefficient.
+         amsgrad: Whether to use the AMSGrad variant.
+         maximize: Whether to maximize the params based on the objective (as opposed to minimizing).
+     """
+     name: Literal["adamw"] = "adamw"
+
+     lr: float
+     betas: tuple[float, float] = (0.9, 0.999)
+     eps: float = 1e-8
+     weight_decay: float = 1e-2
+     amsgrad: bool = False
+     maximize: bool = False
+
+     def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+         """Builds fused AdamW with the configured parameters."""
+         return AdamW(
+             params=params,
+             lr=self.lr,
+             betas=self.betas,
+             eps=self.eps,
+             weight_decay=self.weight_decay,
+             amsgrad=self.amsgrad,
+             maximize=self.maximize,
+             fused=True
+         )
+
+
+ class AdamOptimizerConfig(BaseAutoOptimizerConfig):
+     """
+     Configuration for the PyTorch Adam optimizer.
+
+     Attributes:
+         name: Discriminator tag.
+         lr: The learning rate.
+         betas: Coefficients for computing running averages of gradient and its square.
+         eps: Term added to the denominator to improve numerical stability.
+         weight_decay: Weight decay coefficient.
+         decoupled_weight_decay: Whether to apply decoupled weight decay.
+         amsgrad: Whether to use the AMSGrad variant.
+         maximize: Whether to maximize the params based on the objective.
+     """
+     name: Literal["adam"] = "adam"
+
+     lr: float
+     betas: tuple[float, float] = (0.9, 0.999)
+     eps: float = 1e-8
+     weight_decay: float = 1e-2
+     decoupled_weight_decay: bool = False
+     amsgrad: bool = False
+     maximize: bool = False
+
+     def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+         """Builds fused Adam with the configured parameters."""
+         return Adam(
+             params=params,
+             lr=self.lr,
+             betas=self.betas,
+             eps=self.eps,
+             weight_decay=self.weight_decay,
+             decoupled_weight_decay=self.decoupled_weight_decay,
+             amsgrad=self.amsgrad,
+             maximize=self.maximize,
+             fused=True
+         )
+
+
+ class SGDOptimizerConfig(BaseAutoOptimizerConfig):
+     """
+     Configuration for the PyTorch SGD optimizer.
+
+     Attributes:
+         name: Discriminator tag.
+         lr: The learning rate.
+         momentum: Momentum factor.
+         dampening: Dampening for momentum.
+         weight_decay: Weight decay (L2 penalty).
+         nesterov: Enables Nesterov momentum.
+         maximize: Whether to maximize the params based on the objective.
+     """
+     name: Literal["sgd"] = "sgd"
+
+     lr: float
+     momentum: float = 0
+     dampening: float = 0
+     weight_decay: float = 0
+     nesterov: bool = False
+     maximize: bool = False
+
+     def build(self, params: Iterable[nn.Parameter]) -> Optimizer:
+         """Builds fused SGD with the configured parameters."""
+         return SGD(
+             params,
+             lr=self.lr,
+             momentum=self.momentum,
+             dampening=self.dampening,
+             weight_decay=self.weight_decay,
+             nesterov=self.nesterov,
+             maximize=self.maximize,
+             fused=True
+         )
+
+
+ AutoOptimizerConfig = Annotated[
+     StochasticAdamWOptimizerConfig |
+     AdamWOptimizerConfig |
+     AdamOptimizerConfig |
+     SGDOptimizerConfig,
+     Field(discriminator="name")
+ ]
+
+
+ class AutoOptimizerProvider(OptimizerProvider):
+     """
+     OptimizerProvider that builds a PyTorch optimizer based on a configuration object.
+     """
+
+     def __init__(self, config: AutoOptimizerConfig):
+         """Constructs the provider with the given configuration."""
+         self._config = config
+
+     def __call__(self, context: InitializeOptimizerStageContext) -> Optimizer:
+         return self._config.build(context.model.parameters())
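
Not part of the package diff: a minimal usage sketch for the optimizer configs above, assuming the import path d9d.loop.auto.auto_optimizer matches the file layout shown in the file list and that the installed PyTorch supports fused optimizers on the chosen device. The model and hyperparameter values are placeholders.

import torch
from torch import nn

from d9d.loop.auto.auto_optimizer import AdamWOptimizerConfig, AutoOptimizerProvider

# Throwaway model; the configs above pass fused=True, which may require CUDA
# parameters on older PyTorch releases.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 16).to(device)

# Build an optimizer directly from a concrete config...
config = AdamWOptimizerConfig(lr=3e-4, weight_decay=0.1)
optimizer = config.build(model.parameters())

# ...or wrap the config in the provider, which the training loop calls with an
# InitializeOptimizerStageContext to build the optimizer for a model stage.
provider = AutoOptimizerProvider(config)

loss = model(torch.randn(4, 16, device=device)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()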
d9d/loop/component/__init__.py
@@ -0,0 +1,35 @@
+ from .batch_maths import BatchMaths
+ from .checkpointer import StateCheckpointer
+ from .data_loader_factory import DataLoaderFactory
+ from .garbage_collector import ManualGarbageCollector
+ from .gradient_clipper import GradientClipper
+ from .gradient_manager import GradientManager
+ from .job_logger import JobLogger
+ from .job_profiler import JobProfiler
+ from .loss_computer import LossComputer
+ from .model_stage_exporter import ModelStageExporter
+ from .model_stage_factory import ModelStageFactory, TrackedModules
+ from .optimizer_factory import OptimizerFactory
+ from .stepper import Stepper
+ from .timeout_manager import TimeoutManager
+ from .train_task_operator import ForwardResult, TrainTaskOperator
+
+ __all__ = [
+     "BatchMaths",
+     "DataLoaderFactory",
+     "ForwardResult",
+     "GradientClipper",
+     "GradientManager",
+     "JobLogger",
+     "JobProfiler",
+     "LossComputer",
+     "ManualGarbageCollector",
+     "ModelStageExporter",
+     "ModelStageFactory",
+     "OptimizerFactory",
+     "StateCheckpointer",
+     "Stepper",
+     "TimeoutManager",
+     "TrackedModules",
+     "TrainTaskOperator"
+ ]
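
The re-exports above give the rest of the loop a single flat import surface; for example (illustrative only):

from d9d.loop.component import BatchMaths, StateCheckpointer, Stepper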
d9d/loop/component/batch_maths.py
@@ -0,0 +1,106 @@
+ from d9d.core.dist_context import BATCH_DOMAIN, DistributedContext
+ from d9d.loop.config import BatchingConfig, PipeliningConfig
+
+
+ class BatchMaths:
+     """
+     Calculates derived batching dimensions and iteration counts for distributed training loops.
+
+     This class bridges the gap between global configuration (Global Batch Size) and
+     local execution constraints (Microbatch Size, Data Parallel World Size).
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         config_batching: BatchingConfig,
+         config_pipelining: PipeliningConfig | None
+     ):
+         """
+         Constructs the batch mathematics calculator.
+
+         Validates that the Global Batch Size is perfectly divisible by the
+         effective parallel microbatch capacity (DP size * Microbatch size).
+
+         Args:
+             dist_context: The distributed context containing mesh layout information.
+             config_batching: Configuration detailing batch sizes.
+             config_pipelining: Optional configuration for pipeline parallelism capabilities.
+
+         Raises:
+             ValueError: If global batch size is not divisible by the product of
+                 Data Parallel size and Microbatch size.
+         """
+
+         self._dist_context = dist_context
+         self._config_batching = config_batching
+         self._config_pipelining = config_pipelining
+
+         global_batch = self._config_batching.global_batch_size
+         dp_size = self._dist_context.mesh_for(BATCH_DOMAIN)["dp"].size()
+         microbatch_size = self._config_batching.microbatch_size
+
+         global_microbatch = dp_size * microbatch_size
+
+         if global_batch % global_microbatch != 0:
+             raise ValueError("Global Batch Size must be divisible by (Data Parallel cardinality * Microbatch Size)")
+
+         self._global_microbatch_size = global_microbatch
+
+     @property
+     def global_batch_size(self) -> int:
+         """
+         Returns the global batch size across the world.
+         """
+
+         return self._config_batching.global_batch_size
+
+     @property
+     def num_microbatches_pipelining(self) -> int:
+         """
+         Returns the number of microbatches handled by the pipeline scheduler per step.
+
+         If pipeline parallelism is enabled, this is the total number of microbatches
+         processed to form one global batch. If disabled, this returns 1.
+         """
+
+         if not self._dist_context.mesh_params.has_pipeline_parallel:
+             return 1
+
+         return self._config_batching.global_batch_size // self._global_microbatch_size
+
+     @property
+     def num_microbatches_gradient_accumulation(self) -> int:
+         """
+         Returns the number of gradient accumulation iterations for non-pipelined training.
+
+         If pipeline parallelism is enabled, this returns 1 (as accumulation is handled
+         internally by the pipeline schedule). If disabled, this is the number of
+         forward/backward passes the training loop must execute before an optimizer step.
+         """
+
+         if self._dist_context.mesh_params.has_pipeline_parallel:
+             return 1
+
+         return self._config_batching.global_batch_size // self._global_microbatch_size
+
+     @property
+     def data_loader_batch_size(self) -> int:
+         """
+         Returns the quantity of samples this local rank needs to fetch for one optimizer step.
+
+         This is calculated as `microbatch_size * total_microbatches_per_step`.
+         """
+
+         return self._config_batching.microbatch_size * self.num_microbatches_pipelining
+
+     @property
+     def num_backward_calls(self) -> int:
+         """
+         Returns the total number of backward passes executed per optimizer step.
+
+         This represents the total gradient accumulation factor, regardless of whether
+         it is handled by a pipeline schedule or a simple loop.
+         """
+
+         return self.num_microbatches_pipelining * self.num_microbatches_gradient_accumulation
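
Not part of the package diff: a worked example of the arithmetic implemented by BatchMaths, using made-up sizes and assuming pipeline parallelism is enabled (so the pipelined microbatch count covers the whole global batch).

# Illustrative numbers only.
global_batch_size = 512   # BatchingConfig.global_batch_size
dp_size = 8               # size of the "dp" axis of the batch-domain mesh
microbatch_size = 4       # BatchingConfig.microbatch_size

global_microbatch = dp_size * microbatch_size             # 32 samples per microbatch "wave"
assert global_batch_size % global_microbatch == 0         # the divisibility check enforced in __init__

num_microbatches_pipelining = global_batch_size // global_microbatch    # 16 microbatches per step
data_loader_batch_size = microbatch_size * num_microbatches_pipelining  # 64 samples fetched per rank
# Sanity check: 8 data-parallel ranks * 64 samples each = 512 = the global batch size.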
d9d/loop/component/checkpointer.py
@@ -0,0 +1,172 @@
+ import re
+ import shutil
+ from pathlib import Path
+
+ import torch
+ import torch.distributed.checkpoint as dcp
+ from torch.distributed.checkpoint.stateful import Stateful
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.loop.config import CheckpointingConfig
+
+ from .garbage_collector import ManualGarbageCollector
+ from .stepper import Stepper
+
+ # TODO feat(max): async checkpointing may break everything up, but I guess we still have to support it
+
+ _SAVE_RE = re.compile(r"^save-(\d+)$")
+
+
+ def _save_iter_predicate(x: Path) -> int:
+     match = _SAVE_RE.fullmatch(x.stem)
+     if match is None:
+         raise ValueError("Malformed checkpoint name")
+     return int(match.group(1))
+
+
+ class StateCheckpointer:
+     """
+     Manages the lifecycle of distributed training checkpoints.
+
+     This class handles saving and loading the training state (JobState object)
+     using PyTorch Distributed Checkpoint (DCP). It manages checkpoint versioning,
+     storage rotation (keeping only N latest), and synchronization across distributed ranks.
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         stepper: Stepper,
+         config: CheckpointingConfig,
+         gc: ManualGarbageCollector,
+         run_name: str | None
+     ):
+         """
+         Constructs the StateCheckpointer object.
+
+         Args:
+             dist_context: The distributed context.
+             stepper: The training stepper tracking the current iteration/step.
+             config: Configuration object containing checkpointing parameters.
+             gc: Garbage collector for manual memory management during IO.
+             run_name: Optional specific run name to append to the save directory.
+         """
+         self._dist_context = dist_context
+         self._stepper = stepper
+         self._gc = gc
+
+         if run_name:
+             self._save_dir = config.save_dir / run_name
+         else:
+             self._save_dir = config.save_dir
+
+         self._config = config
+
+     def _free_memory(self):
+         self._gc.collect_forced()
+         torch.cuda.empty_cache()
+
+     def _get_sorted_checkpoint_dirs(self) -> list[Path]:
+         if not self._save_dir:
+             return []
+
+         if not self._save_dir.is_dir():
+             return []
+
+         checkpoint_dirs = [x for x in self._save_dir.iterdir() if x.is_dir() and _SAVE_RE.fullmatch(x.stem)]
+         checkpoint_dirs = sorted(checkpoint_dirs, key=_save_iter_predicate)
+         return checkpoint_dirs
+
+     def _next_checkpoint_id(self) -> Path:
+         next_name = f"save-{self._stepper.current_step}"
+         return self._save_dir / next_name
+
+     def _purge_old_checkpoints(self):
+         if not self._dist_context.is_main_process:
+             return
+         if not self._config.num_to_keep:
+             return
+
+         to_delete = self._get_sorted_checkpoint_dirs()[:-self._config.num_to_keep]
+
+         for delete_dir in to_delete:
+             self._dist_context.logger.info(f"Purging checkpoint {delete_dir}")
+             shutil.rmtree(delete_dir)
+
+     def _checkpoint(self, state: Stateful):
+         next_checkpoint_id = self._next_checkpoint_id()
+
+         self._dist_context.logger.info("Freeing up memory before checkpointing")
+         self._free_memory()
+         self._dist_context.logger.info("Waiting for world before saving checkpoint")
+         self._dist_context.wait_world()
+         self._dist_context.logger.info(f"Saving checkpoint {next_checkpoint_id}")
+
+         save_from = {"state": state}
+         dcp.save(
+             state_dict=save_from,
+             checkpoint_id=next_checkpoint_id
+         )
+
+         self._purge_old_checkpoints()
+         self._free_memory()
+
+         self._dist_context.logger.info("Waiting for world after saving checkpoint")
+         self._dist_context.wait_world()
+         self._dist_context.logger.info("Checkpoint successfully saved across the world")
+
+     def checkpoint_if_needed(self, state: Stateful):
+         """
+         Checks if a checkpoint is due based on the configuration and saves if necessary.
+
+         This checks the stepper to see if the current step matches the configured
+         saving period (or if it is the final step).
+
+         Args:
+             state: The Stateful object to save.
+         """
+
+         if self._stepper.should_do_action(self._config.period_steps, enable_on_last_step_if_periodic=True):
+             self._checkpoint(state)
+
+     def _last_checkpoint_id(self) -> Path | None:
+         checkpoints = self._get_sorted_checkpoint_dirs()
+         if len(checkpoints) == 0:
+             return None
+         return checkpoints[-1]
+
+     def _load(self, state: Stateful):
+         last_checkpoint = self._last_checkpoint_id()
+
+         if last_checkpoint is None:
+             self._dist_context.logger.info("Starting job from scratch")
+             return
+
+         self._dist_context.logger.info("Waiting for world before loading checkpoint")
+         self._dist_context.wait_world()
+         self._dist_context.logger.info(f"Loading checkpoint {last_checkpoint}")
+
+         load_into = {
+             "state": state
+         }
+         dcp.load(
+             state_dict=load_into,
+             checkpoint_id=last_checkpoint
+         )
+         self._free_memory()
+
+         self._dist_context.logger.info("Waiting for world after loading checkpoint")
+         self._dist_context.wait_world()
+         self._dist_context.logger.info("Checkpoint successfully loaded across the world")
+
+     def load_last_checkpoint(self, state: Stateful):
+         """
+         Attempts to load the most recent checkpoint available in the save directory.
+
+         If no checkpoint is found, the state remains unchanged (starting from scratch).
+
+         Args:
+             state: The stateful object to which loaded parameters will be applied.
+         """
+
+         self._load(state)
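
Not part of the package diff: a standalone sketch of the save-<step> naming and rotation scheme that StateCheckpointer implements, using hypothetical directory names.

import re
from pathlib import Path

_SAVE_RE = re.compile(r"^save-(\d+)$")

# Checkpoint directories are named after the optimizer step that produced them...
dirs = [Path("save-100"), Path("save-20"), Path("save-3")]

# ...and ordered numerically (not lexicographically) by that step number,
# mirroring _save_iter_predicate and _get_sorted_checkpoint_dirs.
ordered = sorted(dirs, key=lambda p: int(_SAVE_RE.fullmatch(p.stem).group(1)))
# ordered == [save-3, save-20, save-100]

# With num_to_keep = 2, everything except the two newest checkpoints is purged.
num_to_keep = 2
to_delete = ordered[:-num_to_keep]
# to_delete == [save-3]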