d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/pipelining/infra/stage/stage.py ADDED
@@ -0,0 +1,321 @@
+ from typing import Any
+
+ import torch
+ import torch.distributed as dist
+ from torch import nn
+
+ from d9d.pipelining.api import ModuleSupportsPipelining, PipelineStageInfo
+
+ from .communications import StageCommunicationHandler
+ from .computations import BackwardComputeHandler, ForwardComputeHandler
+
+
+ class PipelineStage:
+     """
+     Represents a single structural stage in a pipelined model.
+
+     This class acts as an orchestrator that combines `StageCommunicationHandler` (for I/O)
+     and `Forward/BackwardComputeHandler` (for execution). It abstracts away the complexity
+     of buffer management, distributed communication, and gradient calculation from the scheduler.
+     """
+
+     def __init__(
+         self,
+         info: PipelineStageInfo,
+         module: nn.Module,
+         group: dist.ProcessGroup,
+         stage_to_host_topology: dict[int, int]
+     ):
+         """
+         Constructs a PipelineStage object.
+
+         Args:
+             info: Metadata about the stage (index, total stages).
+             module: The PyTorch module executed by this stage.
+             group: The distributed process group for pipeline communications.
+             stage_to_host_topology: Dict mapping stage ID to the PP rank hosting it.
+         """
+
+         self._info = info
+         self._module = module
+         self._group = group
+         self._stage_to_host_topology = stage_to_host_topology
+
+         self._has_backward = False
+
+         self._forward_comm: StageCommunicationHandler | None = None
+         self._backward_comm: StageCommunicationHandler | None = None
+
+         self._forward_comp = ForwardComputeHandler(
+             stage_index=info.current_stage,
+             module=module
+         )
+         self._backward_comp = BackwardComputeHandler(
+             stage_index=info.current_stage,
+             module=module
+         )
+
+     @property
+     def info(self) -> PipelineStageInfo:
+         return self._info
+
+     def configure_buffers(
+         self,
+         num_microbatches: int,
+         has_backward: bool,
+         pipeline_inputs: dict[str, torch.Tensor]
+     ):
+         """
+         Initializes the communication handlers and buffers for the stage.
+
+         This must be called before execution to establish P2P buffer sizes and directions.
+
+         Args:
+             num_microbatches: Total number of microbatches to process.
+             has_backward: Whether this pipeline stage should retain state for a backward pass.
+             pipeline_inputs: Pipeline input data.
+         """
+
+         self._has_backward = has_backward
+
+         prev_stage_idx = None if self._info.is_current_stage_first else self._info.current_stage - 1
+         next_stage_idx = None if self._info.is_current_stage_last else self._info.current_stage + 1
+
+         with torch.device("meta"):
+             if not isinstance(self._module, ModuleSupportsPipelining):
+                 raise TypeError("Module does not implement ModuleSupportsPipelining protocol")
+             inputs_meta = self._module.infer_stage_inputs_from_pipeline_inputs(
+                 inputs=pipeline_inputs,
+                 n_microbatches=num_microbatches
+             )
+             outputs_meta = self._module.infer_stage_outputs_from_pipeline_inputs(
+                 inputs=pipeline_inputs,
+                 n_microbatches=num_microbatches
+             )
+
+         self._forward_comm = StageCommunicationHandler(
+             name="fwd",
+             stage_index=self._info.current_stage,
+             num_microbatches=num_microbatches,
+             input_stage_index=prev_stage_idx,
+             input_args=inputs_meta,
+             output_stage_index=next_stage_idx,
+             output_args=outputs_meta,
+             group=self._group,
+             stage_idx_to_host_rank=self._stage_to_host_topology
+         )
+         self._forward_comm.set_input_requires_grad_(requires_grad=has_backward)
+
+         if has_backward:
+             # Gradient flow reverses the forward direction: this stage receives grads
+             # for its OUTPUTS as inputs and sends grads for its INPUTS as outputs.
+             self._backward_comm = StageCommunicationHandler(
+                 name="bwd",
+                 stage_index=self._info.current_stage,
+                 num_microbatches=num_microbatches,
+                 input_stage_index=next_stage_idx,
+                 input_args=outputs_meta,
+                 output_stage_index=prev_stage_idx,
+                 output_args=inputs_meta,
+                 group=self._group,
+                 stage_idx_to_host_rank=self._stage_to_host_topology
+             )
+         else:
+             self._backward_comm = None
+
+     def set_local_fwd_input(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
+         """
+         Sets local forward inputs manually.
+
+         Used by the V-shaped schedules.
+         """
+
+         if self._forward_comm is None:
+             raise ValueError("You must configure stage buffers first")
+
+         self._forward_comm.set_inputs_local(inputs, microbatch_index)
+
+     def get_local_fwd_output(self, microbatch_index: int) -> dict[str, torch.Tensor]:
+         return self._forward_comp.get_outputs(microbatch_index)
+
+     def pop_local_bwd_output(self, microbatch_index: int) -> dict[str, torch.Tensor]:
+         """
+         Retrieves local backward outputs (gradients).
+         """
+
+         if not self._has_backward:
+             raise ValueError("This stage was configured without a backward pass")
+
+         return self._backward_comp.pop_for_sending(microbatch_index)
+
+     def set_local_bwd_input(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
+         """
+         Sets local backward inputs (output gradients) manually.
+         """
+
+         if not self._has_backward:
+             raise ValueError("This stage was configured without a backward pass")
+
+         if self._backward_comm is None:
+             raise ValueError("You must configure stage buffers first")
+
+         self._backward_comm.set_inputs_local(inputs, microbatch_index)
+
+     def get_fwd_recv_ops(self, microbatch_index: int) -> list[dist.P2POp]:
+         """Returns P2P ops to receive forward inputs for the given microbatch."""
+
+         if self._forward_comm is None:
+             raise ValueError("You must configure stage buffers first")
+
+         return self._forward_comm.create_receive_ops(microbatch_index)
+
+     def get_fwd_send_ops(self, microbatch_index: int) -> list[dist.P2POp]:
+         """Returns P2P ops to send forward outputs for the given microbatch."""
+
+         if self._forward_comm is None:
+             raise ValueError("You must configure stage buffers first")
+
+         fwd_result = self._forward_comp.get_outputs(microbatch_index)
+         return self._forward_comm.create_send_ops(fwd_result)
+
+     def get_bwd_recv_ops(self, microbatch_index: int) -> list[dist.P2POp]:
+         """Returns P2P ops to receive backward gradients for the given microbatch."""
+
+         if not self._has_backward:
+             return []
+
+         if self._backward_comm is None:
+             raise ValueError("You must configure stage buffers first")
+
+         return self._backward_comm.create_receive_ops(microbatch_index)
+
+     def get_bwd_send_ops(self, microbatch_index: int) -> list[dist.P2POp]:
+         """Returns P2P ops to send backward gradients for the given microbatch."""
+
+         if not self._has_backward:
+             return []
+
+         if self._backward_comm is None:
+             raise ValueError("You must configure stage buffers first")
+
+         bwd_result = self._backward_comp.pop_for_sending(microbatch_index)
+         return self._backward_comm.create_send_ops(bwd_result)
+
+     def forward_one_chunk(
+         self,
+         microbatch_index: int,
+         pipeline_inputs: dict[str, torch.Tensor],
+         pipeline_kwargs: dict[str, Any] | None = None,
+     ):
+         """
+         Executes a forward pass for a single microbatch chunk.
+
+         Fetches inputs from the communication buffer (or `pipeline_inputs` if this is the
+         first stage), runs the computation, and caches the result for later retrieval via
+         `get_local_fwd_output`.
+
+         Args:
+             microbatch_index: The microbatch index.
+             pipeline_inputs: Inputs provided locally (only used if this is the first stage).
+             pipeline_kwargs: Additional arguments for the module.
+         """
+
+         if self._forward_comm is None:
+             raise ValueError("You must configure stage buffers first")
+
+         if self._info.is_current_stage_first:
+             inputs = pipeline_inputs
+         else:
+             inputs = self._forward_comm.get_inputs(microbatch_index)
+
+         kwargs = pipeline_kwargs or {}
+
+         self._forward_comp.run(
+             microbatch_index=microbatch_index,
+             inputs=inputs,
+             kwargs=kwargs
+         )
+
+     def backward_one_chunk(
+         self,
+         microbatch_index: int,
+         loss: torch.Tensor | None = None,
+         full_backward: bool = True
+     ):
+         """
+         Executes a backward pass for a single microbatch chunk.
+
+         Can perform either a full backward pass or only the input-gradient part
+         (if `full_backward=False`). It fetches the required data from the forward
+         cache and the communication buffers.
+
+         Args:
+             microbatch_index: The microbatch index.
+             loss: The loss tensor (only used if this is the last stage).
+             full_backward: If True, computes grads for inputs and weights. If False, only for inputs.
+         """
+
+         if not self._has_backward:
+             raise ValueError("This stage was configured without a backward pass")
+
+         if self._backward_comm is None:
+             raise ValueError("You must configure stage buffers first")
+
+         inputs, fwd_outputs = self._forward_comp.pop_inputs_outputs(microbatch_index)
+
+         outputs: dict[str, torch.Tensor]
+         outputs_grad: dict[str, torch.Tensor] | None
+
+         if self._info.is_current_stage_last:
+             if loss is None:
+                 raise ValueError("Cannot perform backward on last stage without loss specified")
+             outputs = {"loss": loss}
+             outputs_grad = None
+         else:
+             outputs = fwd_outputs
+             outputs_grad = self._backward_comm.get_inputs(microbatch_index)
+
+         if full_backward:
+             self._backward_comp.backward_full(
+                 microbatch_index=microbatch_index,
+                 inputs=inputs,
+                 outputs=outputs,
+                 outputs_grad=outputs_grad
+             )
+         else:
+             self._backward_comp.backward_input(
+                 microbatch_index=microbatch_index,
+                 inputs=inputs,
+                 outputs=outputs,
+                 outputs_grad=outputs_grad
+             )
+
+         if self._info.is_current_stage_last and not self._info.is_current_stage_first:
+             for t in fwd_outputs.values():
+                 if not t._is_view():  # noqa: SLF001
+                     t.detach_()
+
+     def backward_weight_one_chunk(self, microbatch_index: int):
+         """
+         Executes the weight gradient accumulation part of the backward pass.
+
+         This assumes `backward_one_chunk(..., full_backward=False)` was already called
+         for this microbatch.
+
+         Args:
+             microbatch_index: The microbatch index.
+         """
+
+         if not self._has_backward:
+             raise ValueError("This stage was configured without a backward pass")
+
+         self._backward_comp.backward_weight(microbatch_index=microbatch_index)
+
+     def reset(self):
+         """Resets the internal state of communication handlers, clearing gradients on buffers."""
+
+         if self._forward_comm is not None:
+             self._forward_comm.reset()
+         if self._backward_comm is not None:
+             self._backward_comm.reset()
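
As a rough illustration of how a schedule might drive `PipelineStage` for a forward-only pass, the sketch below strings together `configure_buffers`, the P2P op getters, and `forward_one_chunk`. The driver function `run_forward_only`, the handling of `pipeline_inputs`, and the use of `dist.batch_isend_irecv` are assumptions for illustration only; d9d's real schedules (`bfs`, `interleaved`, `dualpipev`, `zerobubblev`) live under `d9d/pipelining/infra/schedule/` and also overlap communication, compute loss, and issue backward actions.

import torch.distributed as dist


def run_forward_only(stage, pipeline_inputs, num_microbatches):
    # Allocate communication buffers before any compute or P2P traffic.
    stage.configure_buffers(
        num_microbatches=num_microbatches,
        has_backward=False,
        pipeline_inputs=pipeline_inputs,
    )
    for mb in range(num_microbatches):
        # Post receives for activations coming from the previous stage (empty on the first stage).
        recv_ops = stage.get_fwd_recv_ops(mb)
        if recv_ops:
            for work in dist.batch_isend_irecv(recv_ops):
                work.wait()
        # Compute; the first stage reads `pipeline_inputs`, later stages read the received buffers.
        stage.forward_one_chunk(microbatch_index=mb, pipeline_inputs=pipeline_inputs)
        # Post sends of this stage's outputs towards the next stage (empty on the last stage).
        send_ops = stage.get_fwd_send_ops(mb)
        if send_ops:
            for work in dist.batch_isend_irecv(send_ops):
                work.wait()
    outputs = None
    if stage.info.is_current_stage_last:
        outputs = [stage.get_local_fwd_output(mb) for mb in range(num_microbatches)]
    stage.reset()
    return outputs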
d9d/pipelining/infra/stage/struct_helper.py ADDED
@@ -0,0 +1,46 @@
+ from collections.abc import Iterable, Sequence
+ from typing import TypeVar
+
+ T = TypeVar("T")
+
+
+ class DictFlattener:
+     """
+     Helper class to flatten and unflatten dictionaries into sequences deterministically.
+     """
+
+     def __init__(self, keys: Iterable[str]):
+         """
+         Constructs a DictFlattener object.
+
+         Args:
+             keys: The collection of dictionary keys to manage. They will be sorted internally.
+         """
+
+         self._order_to_key = {i: x for i, x in enumerate(sorted(keys))}
+
+     def flatten(self, inputs: dict[str, T]) -> list[T]:
+         """
+         Converts a dictionary into a list based on the sorted internal key order.
+
+         Args:
+             inputs: The dictionary to flatten. Must contain all keys provided at init.
+
+         Returns:
+             A list of values sorted by their corresponding keys.
+         """
+
+         return [inputs[self._order_to_key[i]] for i in range(len(inputs))]
+
+     def unflatten(self, outputs: Sequence[T]) -> dict[str, T]:
+         """
+         Reconstructs a dictionary from a sequence of values.
+
+         Args:
+             outputs: A sequence of values corresponding to the sorted internal key order.
+
+         Returns:
+             A dictionary mapping original keys to the provided values.
+         """
+
+         return {self._order_to_key[i]: out for i, out in enumerate(outputs)}
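
A quick round-trip sketch of `DictFlattener` (assuming it is imported from the module above); the key names are arbitrary placeholders.

from d9d.pipelining.infra.stage.struct_helper import DictFlattener

flattener = DictFlattener(keys=["hidden_states", "attention_mask"])

# Keys are ordered alphabetically: ["attention_mask", "hidden_states"].
packed = flattener.flatten({"hidden_states": 1, "attention_mask": 2})
assert packed == [2, 1]

# Unflattening restores the original key -> value mapping.
restored = flattener.unflatten(packed)
assert restored == {"attention_mask": 2, "hidden_states": 1}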
d9d/pipelining/training/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .optimizer import PipelinedOptimizer
+ from .scheduler import PipelinedLRScheduler
+
+ __all__ = [
+     "PipelinedLRScheduler",
+     "PipelinedOptimizer"
+ ]
d9d/pipelining/training/optimizer.py ADDED
@@ -0,0 +1,41 @@
+ from typing import Any
+
+ from torch.distributed import DeviceMesh
+
+ from d9d.core.protocol import OptimizerProtocol
+
+
+ class PipelinedOptimizer(OptimizerProtocol):
+     """
+     Wrapper that manages multiple optimizers for a pipeline parallel rank.
+
+     In a pipeline parallel setup, a single rank might host multiple stages, each with its
+     own parameters and optimizer.
+     This class aggregates them into a single interface.
+     """
+
+     def __init__(self, mesh_pp: DeviceMesh, optimizers: list[OptimizerProtocol]):
+         super().__init__()
+
+         self._mesh_pp = mesh_pp
+         self._optimizers = optimizers
+
+     def state_dict(self) -> dict[str, Any]:
+         pp_rank = self._mesh_pp.get_local_rank()
+         return {
+             f"pp_{pp_rank}_stage_{i}": optimizer.state_dict()
+             for i, optimizer in enumerate(self._optimizers)
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         pp_rank = self._mesh_pp.get_local_rank()
+         for i, optimizer in enumerate(self._optimizers):
+             optimizer.load_state_dict(state_dict[f"pp_{pp_rank}_stage_{i}"])
+
+     def step(self) -> None:
+         for optimizer in self._optimizers:
+             optimizer.step()
+
+     def zero_grad(self) -> None:
+         for optimizer in self._optimizers:
+             optimizer.zero_grad()
d9d/pipelining/training/scheduler.py ADDED
@@ -0,0 +1,34 @@
+ from typing import Any
+
+ from torch.distributed import DeviceMesh
+
+ from d9d.core.protocol import LRSchedulerProtocol
+
+
+ class PipelinedLRScheduler(LRSchedulerProtocol):
+     """
+     Wrapper that manages multiple LR schedulers for a pipeline parallel rank.
+
+     Similar to `PipelinedOptimizer`, this aggregates schedulers corresponding to
+     multiple model stages hosted on the current rank.
+     """
+
+     def __init__(self, mesh_pp: DeviceMesh, schedulers: list[LRSchedulerProtocol]):
+         self._mesh_pp = mesh_pp
+         self._schedulers = schedulers
+
+     def state_dict(self) -> dict[str, Any]:
+         pp_rank = self._mesh_pp.get_local_rank()
+         return {
+             f"pp_{pp_rank}_stage_{i}": scheduler.state_dict()
+             for i, scheduler in enumerate(self._schedulers)
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         pp_rank = self._mesh_pp.get_local_rank()
+         for i, scheduler in enumerate(self._schedulers):
+             scheduler.load_state_dict(state_dict[f"pp_{pp_rank}_stage_{i}"])
+
+     def step(self) -> None:
+         for scheduler in self._schedulers:
+             scheduler.step()
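
A hedged sketch of how these wrappers might be assembled for the stages hosted on one rank. `mesh_pp`, `stage_modules`, and the plain `torch.optim` choices are placeholders standing in for what d9d's optimizer and model-stage factories normally build.

import torch

from d9d.pipelining.training import PipelinedLRScheduler, PipelinedOptimizer

# One optimizer/scheduler per locally hosted pipeline stage (placeholders).
optimizers = [torch.optim.AdamW(m.parameters(), lr=1e-4) for m in stage_modules]
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0) for opt in optimizers]

optimizer = PipelinedOptimizer(mesh_pp=mesh_pp, optimizers=optimizers)
lr_scheduler = PipelinedLRScheduler(mesh_pp=mesh_pp, schedulers=schedulers)

optimizer.step()        # steps every per-stage optimizer on this rank
lr_scheduler.step()
optimizer.zero_grad()

# State-dict keys are namespaced by PP rank and local stage index,
# e.g. {"pp_0_stage_0": {...}, "pp_0_stage_1": {...}}.
state = optimizer.state_dict()
optimizer.load_state_dict(state)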
d9d/tracker/__init__.py ADDED
@@ -0,0 +1,14 @@
+ """
+ Package providing a unified interface for experiment tracking and logging.
+ """
+
+ from .base import BaseTracker, BaseTrackerRun, RunConfig
+ from .factory import AnyTrackerConfig, tracker_from_config
+
+ __all__ = [
+     "AnyTrackerConfig",
+     "BaseTracker",
+     "BaseTrackerRun",
+     "RunConfig",
+     "tracker_from_config"
+ ]
d9d/tracker/base.py ADDED
@@ -0,0 +1,124 @@
+ import abc
+ from collections.abc import Generator
+ from contextlib import contextmanager
+ from typing import Any, Generic, Self, TypeVar
+
+ import torch
+ from pydantic import BaseModel, Field
+ from torch.distributed.checkpoint.stateful import Stateful
+
+
+ class BaseTrackerRun(abc.ABC):
+     """
+     Abstract base class representing an active tracking session (run).
+
+     This object is responsible for the actual logging of metrics and parameters
+     during a training or inference run.
+     """
+
+     @abc.abstractmethod
+     def set_step(self, step: int):
+         """
+         Updates the global step counter for subsequent logs.
+
+         Args:
+             step: The current step index (e.g., iteration number).
+         """
+         ...
+
+     @abc.abstractmethod
+     def set_context(self, context: dict[str, str]):
+         """
+         Sets a persistent context dictionary for subsequent logs.
+
+         These context values (tags) will be attached to every metric logged
+         until changed.
+
+         Args:
+             context: A dictionary of tag names and values.
+         """
+         ...
+
+     @abc.abstractmethod
+     def scalar(self, name: str, value: float, context: dict[str, str] | None = None):
+         """
+         Logs a scalar value.
+
+         Args:
+             name: The name of the metric.
+             value: The scalar value to log.
+             context: Optional ephemeral context specific to this metric event.
+                 Merged with the global context if present.
+         """
+         ...
+
+     @abc.abstractmethod
+     def bins(self, name: str, values: torch.Tensor, context: dict[str, str] | None = None):
+         """
+         Logs a distribution/histogram of values.
+
+         Args:
+             name: The name of the metric.
+             values: A tensor containing the population of values to bin.
+             context: Optional ephemeral context specific to this metric event.
+                 Merged with the global context if present.
+         """
+         ...
+
+
+ class RunConfig(BaseModel):
+     """
+     Configuration for initializing a specific logged run.
+
+     Attributes:
+         name: The display name of the experiment.
+         description: An optional description of the experiment.
+         hparams: A dictionary of hyperparameters to log at the start of the run.
+     """
+
+     name: str
+     description: str | None
+     hparams: dict[str, Any] = Field(default_factory=dict)
+
+
+ TConfig = TypeVar("TConfig", bound=BaseModel)
+
+
+ class BaseTracker(abc.ABC, Stateful, Generic[TConfig]):
+     """
+     Abstract base class for a tracker backend factory.
+
+     This class manages the lifecycle of runs and integration with the
+     distributed checkpointing system to ensure experiment continuity
+     (e.g., resuming the same run hash after a restart).
+     """
+
+     @contextmanager
+     @abc.abstractmethod
+     def open(self, properties: RunConfig) -> Generator[BaseTrackerRun, None, None]:
+         """
+         Context manager that initiates and manages an experiment run.
+
+         Args:
+             properties: Configuration metadata for the run.
+
+         Yields:
+             An active BaseTrackerRun instance for logging metrics.
+         """
+
+         ...
+
+     @classmethod
+     @abc.abstractmethod
+     def from_config(cls, config: TConfig) -> Self:
+         """
+         Factory method to create a tracker instance from a configuration object.
+
+         Args:
+             config: The backend-specific configuration object.
+
+         Returns:
+             An initialized instance of the tracker.
+         """
+
+         ...
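
For concreteness, a minimal backend satisfying this interface might look like the sketch below. `StdoutRun`, `StdoutTrackerConfig`, and `StdoutTracker` are hypothetical names, not part of d9d; they only show where the abstract methods and the `Stateful` hooks plug in.

from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Self

import torch
from pydantic import BaseModel

from d9d.tracker import BaseTracker, BaseTrackerRun, RunConfig


class StdoutRun(BaseTrackerRun):
    def __init__(self) -> None:
        self._step = 0
        self._context: dict[str, str] = {}

    def set_step(self, step: int):
        self._step = step

    def set_context(self, context: dict[str, str]):
        self._context = dict(context)

    def scalar(self, name: str, value: float, context: dict[str, str] | None = None):
        # Merge the persistent context with the per-event context, then print.
        merged = self._context | (context or {})
        print(f"[step {self._step}] {name}={value} ctx={merged}")

    def bins(self, name: str, values: torch.Tensor, context: dict[str, str] | None = None):
        # Reduce the population to a single summary scalar for this toy backend.
        self.scalar(f"{name}/mean", values.float().mean().item(), context)


class StdoutTrackerConfig(BaseModel):
    provider: str = "stdout"  # hypothetical discriminator value


class StdoutTracker(BaseTracker[StdoutTrackerConfig]):
    @contextmanager
    def open(self, properties: RunConfig) -> Generator[BaseTrackerRun, None, None]:
        print(f"starting run {properties.name!r} with hparams {properties.hparams}")
        yield StdoutRun()

    @classmethod
    def from_config(cls, config: StdoutTrackerConfig) -> Self:
        return cls()

    # Stateful hooks: nothing to persist for a stdout backend.
    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pass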
d9d/tracker/factory.py ADDED
@@ -0,0 +1,57 @@
+ import dataclasses
+ from typing import Annotated
+
+ from pydantic import Field
+
+ from .base import BaseTracker
+ from .provider.aim.config import AimConfig
+ from .provider.null import NullTracker, NullTrackerConfig
+
+ AnyTrackerConfig = Annotated[AimConfig | NullTrackerConfig, Field(discriminator="provider")]
+
+
+ @dataclasses.dataclass
+ class _TrackerImportFailed:
+     dependency: str
+     exception: ImportError
+
+
+ _MAP: dict[type[AnyTrackerConfig], type[BaseTracker] | _TrackerImportFailed] = {
+     NullTrackerConfig: NullTracker
+ }
+
+ try:
+     from .provider.aim.tracker import AimTracker
+
+     _MAP[AimConfig] = AimTracker
+ except ImportError as e:
+     _MAP[AimConfig] = _TrackerImportFailed(dependency="aim", exception=e)
+
+
+ def tracker_from_config(config: AnyTrackerConfig) -> BaseTracker:
+     """
+     Instantiates a specific tracker implementation based on the configuration.
+
+     Based on the 'provider' field in the config, this function selects the
+     appropriate backend (e.g., Aim, Null). It handles checking for missing
+     dependencies for optional backends.
+
+     Args:
+         config: A specific tracker configuration object.
+
+     Returns:
+         An initialized BaseTracker instance.
+
+     Raises:
+         ImportError: If the dependencies for the requested provider are not installed.
+     """
+
+     tracker_type = _MAP[type(config)]
+
+     if isinstance(tracker_type, _TrackerImportFailed):
+         raise ImportError(
+             f"The tracker configuration {config.provider} could not be loaded - "
+             f"ensure these dependencies are installed: {tracker_type.dependency}"
+         ) from tracker_type.exception
+
+     return tracker_type.from_config(config)
d9d/tracker/provider/__init__.py ADDED (empty file)
d9d/tracker/provider/aim/__init__.py ADDED (empty file)
d9d/tracker/provider/aim/config.py ADDED
@@ -0,0 +1,23 @@
+ from typing import Literal
+
+ from pydantic import BaseModel
+
+
+ class AimConfig(BaseModel):
+     """
+     Configuration for the Aim tracker backend.
+
+     Attributes:
+         provider: Discriminator field, must be 'aim'.
+         repo: Path to the Aim repository directory or URL.
+         log_system_params: Whether to log system resource usage (CPU/GPU/Memory).
+         capture_terminal_logs: Whether to capture stdout/stderr.
+         system_tracking_interval: Interval in seconds for system monitoring.
+     """
+
+     provider: Literal["aim"] = "aim"
+
+     repo: str
+     log_system_params: bool = True
+     capture_terminal_logs: bool = True
+     system_tracking_interval: int = 10
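
Putting the tracker pieces together, an end-to-end usage might look like the hedged sketch below; the repo path, run name, and metric names are illustrative, and the `aim` dependency must be installed for the Aim backend to resolve in `tracker_from_config`.

from d9d.tracker import RunConfig, tracker_from_config
from d9d.tracker.provider.aim.config import AimConfig

# provider="aim" is the discriminator default, so only the repo is required.
config = AimConfig(repo="/tmp/aim-repo")
tracker = tracker_from_config(config)

run_config = RunConfig(name="qwen3-moe-run", description=None, hparams={"lr": 1e-4})
with tracker.open(run_config) as run:
    run.set_step(0)
    run.set_context({"split": "train"})
    run.scalar("loss", 2.31)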