d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/pipelining/infra/stage/computations.py
@@ -0,0 +1,317 @@
+ import dataclasses
+ from collections.abc import Iterator, Mapping, Sequence
+ from typing import Any, cast
+
+ import torch
+ from torch import nn
+ from torch.autograd.graph import Node
+
+ from .splitgrad import (
+     ParamGroup,
+     stage_backward_full,
+     stage_backward_input,
+     stage_backward_weight,
+ )
+ from .struct_helper import DictFlattener
+
+ # TODO/NOTICE: We WILL NOT disable FSDP's resharding for microbatches since it will modify
+ # TODO/NOTICE: its behavior in an unexpected way. Perhaps we need better FSDP resharding policy handler?
+
+
+ @dataclasses.dataclass(slots=True)
+ class ForwardCache:
+     """
+     Stores the inputs and outputs of a forward pass to be used later in the backward pass.
+     """
+
+     inputs: dict[str, torch.Tensor]
+     outputs: dict[str, torch.Tensor]
+
+
+ class ForwardComputeHandler:
+     """
+     Handles the execution of the forward pass for a pipeline stage module.
+
+     Maintains a cache of inputs and outputs indexed by microbatch ID.
+     """
+
+     def __init__(
+         self,
+         stage_index: int,
+         module: nn.Module
+     ):
+         """
+         Constructs a ForwardComputeHandler object.
+
+         Args:
+             stage_index: Logical index of the stage.
+             module: The PyTorch module representing this stage computation.
+         """
+
+         self._stage_idx = stage_index
+         self._module = module
+
+         self._cache: dict[int, ForwardCache] = {}
+
+     def run(
+         self,
+         microbatch_index: int,
+         inputs: dict[str, torch.Tensor],
+         kwargs: dict[str, Any]
+     ):
+         """
+         Executes the module's forward pass.
+
+         Args:
+             microbatch_index: Identifier for the current microbatch.
+             inputs: Dictionary of input tensors.
+             kwargs: Additional keyword arguments for the module.
+
+         Returns:
+             None. The filtered outputs are cached and can be retrieved via `get_outputs`.
+
+         Raises:
+             RuntimeError: If the forward pass implementation fails.
+         """
+
+         # Compute forward
+         try:
+             output = self._module(**inputs, **kwargs)
+         except Exception as e:
+             raise RuntimeError(f"S{self._stage_idx}B{microbatch_index} failed to run forward") from e
+
+         if not isinstance(output, Mapping):
+             raise ValueError("Currently, pipelined models should output dict[str, torch.Tensor | None]")
+
+         output = {k: v for k, v in output.items() if v is not None}
+
+         self._cache[microbatch_index] = ForwardCache(
+             inputs=inputs,
+             outputs=output
+         )
+
+     def get_outputs(self, microbatch_index: int) -> dict[str, torch.Tensor]:
+         """
+         Retrieves cached outputs for a specific microbatch without removing them.
+
+         Args:
+             microbatch_index: Identifier for the microbatch.
+
+         Returns:
+             Dictionary of output tensors.
+         """
+
+         return self._cache[microbatch_index].outputs
+
+     def pop_inputs_outputs(
+         self, microbatch_index: int
+     ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+         """
+         Retrieves and removes the cached inputs and outputs for a specific microbatch.
+
+         Typically called when initiating the backward pass.
+
+         Args:
+             microbatch_index: Identifier for the microbatch.
+
+         Returns:
+             A tuple containing (inputs, outputs).
+         """
+
+         cache = self._cache.pop(microbatch_index)
+         return cache.inputs, cache.outputs
+
+
+ @dataclasses.dataclass(kw_only=True, slots=True)
+ class BackwardCacheInputForWeight:
+     """
+     State preserved after calculating input gradients, pending weight gradient calculation.
+     """
+
+     inputs_grad: dict[str, torch.Tensor]
+     param_groups: list[ParamGroup]
+     ownership_tokens: list[Node]
+
+
+ @dataclasses.dataclass(kw_only=True, slots=True)
+ class BackwardCacheInputForFull:
+     stage_outputs_or_loss: list[torch.Tensor]
+     output_grads: list[torch.Tensor] | None
+     input_values: list[torch.Tensor]
+
+
+ @dataclasses.dataclass(kw_only=True, slots=True)
+ class BackwardCacheFull:
+     """
+     State preserved after calculating weight gradients.
+     """
+
+     inputs_grad: dict[str, torch.Tensor | None]
+
+
+ class BackwardComputeHandler:
+     """
+     Handles the execution of backward passes for a pipeline stage.
+
+     Supports splitting the backward pass into input-gradients and weight-gradients
+     phases, which is necessary for schedules like ZB.
+     """
+
+     def __init__(
+         self,
+         stage_index: int,
+         module: nn.Module
+     ):
+         """
+         Constructs a BackwardComputeHandler object.
+
+         Args:
+             stage_index: Logical index of the stage.
+             module: The PyTorch module to compute gradients for.
+         """
+
+         self._stage_idx = stage_index
+         self._module = module
+
+         self._cache: dict[int, BackwardCacheInputForWeight | BackwardCacheInputForFull | BackwardCacheFull] = {}
+
+     def _parameters_with_grad(self) -> Iterator[nn.Parameter]:
+         return (param for param in self._module.parameters() if param.requires_grad)
+
+     def backward_full(
+         self,
+         microbatch_index: int,
+         inputs: dict[str, torch.Tensor],
+         outputs: dict[str, torch.Tensor],
+         outputs_grad: dict[str, torch.Tensor] | None,
+     ):
+         """
+         Performs a full backward pass (both inputs and weights).
+
+         Args:
+             microbatch_index: Identifier for the microbatch.
+             inputs: The inputs used in the forward pass.
+             outputs: The outputs produced by the forward pass.
+             outputs_grad: Gradients of the loss with respect to the outputs.
+         """
+
+         if microbatch_index in self._cache:
+             raise ValueError(f"S{self._stage_idx}B{microbatch_index} double backward")
+
+         inputs_flattener = DictFlattener(inputs.keys())
+         outputs_flattener = DictFlattener(outputs.keys())
+
+         inputs_grad_linear = stage_backward_full(
+             outputs=outputs_flattener.flatten(outputs),
+             output_grads=outputs_flattener.flatten(outputs_grad) if outputs_grad is not None else None,
+             inputs=inputs_flattener.flatten(inputs)
+         )
+
+         if self._stage_idx != 0:
+             self._cache[microbatch_index] = BackwardCacheFull(
+                 inputs_grad=inputs_flattener.unflatten(inputs_grad_linear)
+             )
+
+     def backward_input(
+         self,
+         microbatch_index: int,
+         inputs: dict[str, torch.Tensor],
+         outputs: dict[str, torch.Tensor],
+         outputs_grad: dict[str, torch.Tensor] | None
+     ):
+         """
+         Performs a partial backward pass to compute gradients with respect to inputs only.
+
+         This prepares the computation state for a subsequent `backward_weight` call.
+
+         Args:
+             microbatch_index: Identifier for the microbatch.
+             inputs: The inputs used in the forward pass.
+             outputs: The outputs produced by the forward pass.
+             outputs_grad: Gradients of the loss with respect to the outputs.
+         """
+
+         if microbatch_index in self._cache:
+             raise ValueError("Double backward pass")
+
+         inputs_flattener = DictFlattener(inputs.keys())
+         outputs_flattener = DictFlattener(outputs.keys())
+
+         if self._stage_idx == 0:
+             self._cache[microbatch_index] = BackwardCacheInputForFull(
+                 stage_outputs_or_loss=outputs_flattener.flatten(outputs),
+                 output_grads=outputs_flattener.flatten(outputs_grad) if outputs_grad is not None else None,
+                 input_values=inputs_flattener.flatten(inputs)
+             )
+         else:
+             results = stage_backward_input(
+                 outputs=outputs_flattener.flatten(outputs),
+                 output_grads=outputs_flattener.flatten(outputs_grad) if outputs_grad is not None else None,
+                 inputs=inputs_flattener.flatten(inputs),
+                 weights=self._parameters_with_grad()
+             )
+
+             self._cache[microbatch_index] = BackwardCacheInputForWeight(
+                 inputs_grad=inputs_flattener.unflatten(cast(Sequence[torch.Tensor], results.input_grads)),
+                 param_groups=results.param_groups,
+                 ownership_tokens=results.grad_ownership_tokens
+             )
+
+     def backward_weight(
+         self,
+         microbatch_index: int
+     ):
+         """
+         Performs a partial backward pass to accumulate gradients into weights.
+
+         Must be preceded by `backward_input` for the same microbatch index.
+
+         Args:
+             microbatch_index: Identifier for the microbatch.
+         """
+
+         if microbatch_index not in self._cache:
+             raise ValueError(f"S{self._stage_idx}BW{microbatch_index} - weight backward with no input backward before")
+
+         prev_cache = self._cache.pop(microbatch_index)
+
+         match prev_cache:
+             case BackwardCacheInputForFull():
+                 stage_backward_full(
+                     outputs=prev_cache.stage_outputs_or_loss,
+                     output_grads=prev_cache.output_grads,
+                     inputs=prev_cache.input_values
+                 )
+             case BackwardCacheInputForWeight():
+                 stage_backward_weight(
+                     weights=self._parameters_with_grad(),
+                     param_groups=prev_cache.param_groups
+                 )
+             case _:
+                 raise ValueError("Previous backward was not input backward")
+
+     def pop_for_sending(self, microbatch_index: int) -> dict[str, torch.Tensor]:
+         """
+         Retrieves the calculated input gradients for a microbatch.
+
+         Args:
+             microbatch_index: Identifier for the microbatch.
+
+         Returns:
+             Dictionary of gradient tensors.
+         """
+         cached = self._cache[microbatch_index]
+
+         match cached:
+             case BackwardCacheFull():
+                 del self._cache[microbatch_index]
+             case BackwardCacheInputForWeight():
+                 pass
+             case _:
+                 raise ValueError("You should call either backward_full or backward_input before popping cached grad")
+
+         for grad_value in cached.inputs_grad.values():
+             if grad_value is None:
+                 raise ValueError("Cannot pop null gradient for sending! Perhaps malformed schedule?")
+
+         return cast(dict[str, torch.Tensor], cached.inputs_grad)
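
The two handlers above give a schedule executor a per-microbatch cache: `run` stores the forward inputs and outputs, `pop_inputs_outputs` hands them back when that microbatch's backward is issued, `pop_for_sending` yields the input gradients to ship to the previous stage, and `backward_weight` finishes the deferred weight gradients. The following sketch is not part of the packaged sources; it drives the handlers by hand for one microbatch on a non-first stage, using a made-up `ToyStage` module and shapes, roughly the way a schedule runtime would.

# Illustrative sketch only -- not shipped in the d9d 0.1.0 wheel. Assumes the
# package is installed and the handlers can be driven outside a real schedule.
import torch
from torch import nn

from d9d.pipelining.infra.stage.computations import (
    BackwardComputeHandler,
    ForwardComputeHandler,
)


class ToyStage(nn.Module):  # hypothetical stage module for the example
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden: torch.Tensor) -> dict[str, torch.Tensor]:
        # Pipelined modules must return a dict[str, Tensor | None].
        return {"hidden": self.proj(hidden)}


stage_idx, mb_idx = 1, 0  # non-first stage, so input grads are cached for sending
module = ToyStage()
fwd = ForwardComputeHandler(stage_index=stage_idx, module=module)
bwd = BackwardComputeHandler(stage_index=stage_idx, module=module)

# Forward: activations arriving from the previous stage must require grad.
inputs = {"hidden": torch.randn(4, 8, requires_grad=True)}
fwd.run(mb_idx, inputs=inputs, kwargs={})

# Backward, split into the two phases used by zero-bubble style schedules.
cached_inputs, cached_outputs = fwd.pop_inputs_outputs(mb_idx)
grads_from_next_stage = {"hidden": torch.ones(4, 8)}

bwd.backward_input(mb_idx, cached_inputs, cached_outputs, grads_from_next_stage)
send_grads = bwd.pop_for_sending(mb_idx)  # dL/d(hidden) for the previous stage
bwd.backward_weight(mb_idx)               # accumulates .grad on module parameters

print(send_grads["hidden"].shape, module.proj.weight.grad is not None)

Note the ordering: `pop_for_sending` is called before `backward_weight`, mirroring how a schedule sends input gradients upstream as early as possible while the weight-gradient work is deferred.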
d9d/pipelining/infra/stage/splitgrad.py
@@ -0,0 +1,377 @@
+ from collections import defaultdict, deque
+ from collections.abc import Callable, Iterator
+ from dataclasses import dataclass
+ from typing import Any, cast
+
+ import torch
+ from torch import nn
+ from torch.autograd.graph import GradientEdge, Node
+
+ from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection
+
+
+ def stage_backward_full(
+     outputs: list[torch.Tensor],
+     output_grads: list[torch.Tensor] | None,
+     inputs: list[torch.Tensor]
+ ) -> list[torch.Tensor | None]:
+     """
+     Performs a standard, full backward pass for a pipeline stage.
+
+     This function computes gradients for the inputs based on the gradients
+     received for the outputs.
+
+     Args:
+         outputs: The output tensors of the forward pass.
+         output_grads: The gradients arriving from the next pipeline stage corresponding
+             to `outputs`. If None, an implicit gradient of ones is assumed (scalar loss case).
+         inputs: The input tensors to the forward pass for which gradients are required.
+
+     Returns:
+         A list of gradients corresponding to `inputs`. If an input does not require
+         gradient, its entry will be None.
+     """
+
+     with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.inputs, GradDirection.weight):
+         torch.autograd.backward(
+             tensors=outputs,
+             grad_tensors=output_grads
+         )
+
+     input_grads = []
+     for input_item in inputs:
+         input_grads.append(input_item.grad)
+         input_item.grad = None
+     return input_grads
+
+
+ @dataclass
+ class ParamGroup:
+     """
+     Represents a group of parameters and their dependency intermediates in the autograd graph.
+
+     This structure is used to manage the split backward pass, identifying which
+     intermediate nodes in the graph allow gradients to flow to specific sets of parameters.
+
+     Attributes:
+         params: Set of autograd Nodes representing the parameters.
+         intermediates: List of autograd Nodes serving as entry points for gradients
+             flowing to these parameters.
+         grads: Storage for captured gradients at the intermediate nodes during
+             the input backward phase.
+     """
+
+     params: set[Node]
+     intermediates: list[Node] | None
+     grads: list[torch.Tensor | None] | None = None
+
+
+ def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
+     if t.requires_grad and t.grad_fn is None:
+         # hack from pytorch codebase to create accumulation op
+         viewed_t = t.view_as(t)
+         grad_fn = viewed_t.grad_fn
+         grad_fn = cast(Node, grad_fn)
+         return grad_fn.next_functions[0][0]
+     else:
+         return t.grad_fn
+
+
+ def _construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]:
+     """
+     Builds a reverse adjacency list (Input -> Output) via BFS from the roots.
+
+     Standard autograd graphs point from Output -> Input (next_functions).
+     This helper provides the reverse mapping to assist in dependency analysis.
+
+     Args:
+         roots: The starting nodes for the graph traversal.
+
+     Returns:
+         A dictionary mapping a node to a list of its dependent (child) nodes.
+     """
+     reverse_graph = defaultdict(list)
+     valid_roots = {x for x in roots if x is not None}
+     to_visit = deque(valid_roots)
+     visited = set(valid_roots)
+
+     while to_visit:
+         current_node = to_visit.popleft()
+         for parent_node, _ in current_node.next_functions:
+             if parent_node is None:
+                 continue
+             reverse_graph[parent_node].append(current_node)
+             if parent_node not in visited:
+                 visited.add(parent_node)
+                 to_visit.append(parent_node)
+
+     return reverse_graph
+
+
+ def _reverse_closure(
+     roots: list[Node], target_nodes: set[Node], reverse_edges_dict: dict[Node, list[Node]]
+ ) -> tuple[set[Node], set[Node]]:
+     """
+     Computes a closure of nodes reachable from roots in the reverse graph.
+
+     Args:
+         roots: Starting nodes.
+         target_nodes: Nodes that act as boundaries/targets for the search.
+         reverse_edges_dict: The reverse graph adjacency list.
+
+     Returns:
+         A tuple containing the set of all closure nodes and the set of visited target nodes.
+     """
+
+     closure: set[Node] = set()
+     visited_target_nodes = set()
+     to_visit: deque[Node] = deque()
+
+     for node in roots:
+         if node is not None and node not in closure:
+             closure.add(node)
+             to_visit.append(node)
+
+     while to_visit:
+         node = to_visit.popleft()
+         reverse_edges = reverse_edges_dict[node]
+         for fn in reverse_edges:
+             if fn in closure or fn is None:
+                 continue
+             if fn in target_nodes:
+                 visited_target_nodes.add(fn)
+                 continue
+             closure.add(fn)
+             to_visit.append(fn)
+
+     return closure, visited_target_nodes
+
+
+ def _get_param_groups(
+     inputs: list[Node], params: list[Node], reverse_edges_dict: dict[Node, list[Node]]
+ ) -> list[ParamGroup]:
+     """
+     Clusters parameters based on their dependencies on inputs.
+
+     This function identifies how gradients propagate from inputs through intermediates
+     to parameters, grouping them to facilitate split backward execution.
+
+     Args:
+         inputs: Gradient functions of the input tensors.
+         params: Gradient functions of the parameter tensors.
+         reverse_edges_dict: The reverse autograd graph.
+
+     Returns:
+         A list of distinct parameter groups.
+     """
+
+     inputs_closure, _ = _reverse_closure(inputs, set(), reverse_edges_dict)
+
+     node_to_group_map: dict[Node, dict[str, set[Node]]] = {}
+
+     for param in params:
+         _, intersected_inputs = _reverse_closure(
+             [param], inputs_closure, reverse_edges_dict
+         )
+
+         current_dict = {
+             "params": {param},
+             "intermediates": intersected_inputs
+         }
+
+         target_dict = None
+         for intermediate_node in intersected_inputs:
+             if intermediate_node in node_to_group_map:
+                 target_dict = node_to_group_map[intermediate_node]
+                 break
+
+         if target_dict is not None:
+             target_dict["params"].update(current_dict["params"])
+             target_dict["intermediates"].update(current_dict["intermediates"])
+             current_dict = target_dict
+
+         for intermediate_node in current_dict["intermediates"]:
+             node_to_group_map[intermediate_node] = current_dict
+
+     # Deduplicate and Convert to Dataclass
+     unique_groups = []
+     seen_ids = set()
+     for group_dict in node_to_group_map.values():
+         if id(group_dict) not in seen_ids:
+             seen_ids.add(id(group_dict))
+             unique_groups.append(ParamGroup(
+                 params=group_dict["params"],
+                 intermediates=list(group_dict["intermediates"])
+             ))
+
+     return unique_groups
+
+
+ def _make_capture_hook(group: ParamGroup, idx: int) -> Callable[[torch.Tensor], None]:
+     def _hook(grad_in: torch.Tensor):
+         # Lazy init gradients list
+         if group.grads is None and group.intermediates is not None:
+             group.grads = [None] * len(group.intermediates)
+
+         if group.grads is not None:
+             group.grads[idx] = grad_in
+
+     return _hook
+
+
+ @dataclass
+ class BackwardInputResult:
+     """
+     Container for the results of the input backward phase.
+
+     Attributes:
+         input_grads: The gradients computed for the input tensors.
+         param_groups: The parameter groups with hooks established to capture
+             weight gradients in the subsequent phase.
+         grad_ownership_tokens: References to tensors keeping the computation
+             graph alive for the weight backward phase.
+     """
+
+     input_grads: list[torch.Tensor | None]
+     param_groups: list[ParamGroup]
+     grad_ownership_tokens: list[Any]
+
+
+ def stage_backward_input(
+     outputs: list[torch.Tensor],
+     output_grads: list[torch.Tensor] | None,
+     inputs: list[torch.Tensor],
+     weights: Iterator[nn.Parameter],
+ ) -> BackwardInputResult:
+     """
+     Performs the first phase of a split backward pass: Input Gradients.
+
+     This function computes the gradients with respect to `inputs` while postponing
+     the computation of gradients with respect to `weights`. It analyzes the
+     autograd graph to identify intermediate nodes where gradients destined for
+     weights split off from the main flow. Hooks are registered at these
+     intermediates to capture gradients for the second phase (`stage_backward_weight`).
+
+     Args:
+         outputs: The output tensors of the forward pass.
+         output_grads: The gradients arriving for the outputs.
+         inputs: The input tensors from the forward pass.
+         weights: An iterator over the model parameters (weights).
+
+     Returns:
+         A result object containing input gradients, prepared parameter groups,
+         and ownership tokens to maintain graph validity.
+     """
+
+     outputs_grad_fn = [grad_fn for x in outputs if (grad_fn := _get_grad_fn_or_grad_acc(x)) is not None]
+     inputs_grad_fn = [grad_fn for x in inputs if (grad_fn := _get_grad_fn_or_grad_acc(x)) is not None]
+     weights_grad_fn = [grad_fn for x in weights if (grad_fn := _get_grad_fn_or_grad_acc(x)) is not None]
+
+     reverse_edges = _construct_reverse_graph(outputs_grad_fn)
+     param_groups = _get_param_groups(inputs_grad_fn, weights_grad_fn, reverse_edges)
+
+     hook_handles = []
+
+     for group in param_groups:
+         if group.intermediates:
+             for i, node in enumerate(group.intermediates):
+                 hook_handles.append(node.register_prehook(_make_capture_hook(group, i)))
+
+     if output_grads is None:
+         output_grads = [torch.ones_like(o) for o in outputs]
+
+     inputs_requiring_grad = [inp for inp in inputs if inp.requires_grad]
+
+     with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.inputs):
+         torch.autograd.backward(
+             tensors=outputs,
+             grad_tensors=output_grads,
+             inputs=inputs_requiring_grad,
+             retain_graph=True,
+         )
+
+     final_input_grads = []
+
+     # 6. Cleanup
+     for input_item in inputs:
+         final_input_grads.append(input_item.grad)
+         input_item.grad = None
+
+     for handle in hook_handles:
+         handle.remove()
+
+     return BackwardInputResult(
+         input_grads=final_input_grads,
+         param_groups=param_groups,
+         # TODO(max): we can keep only intermediate ownership tokens to both truncate the
+         # TODO(max): graph and do not deallocate C++ stuff
+         grad_ownership_tokens=outputs  # Keep the tensors alive!
+     )
+
+
+ def stage_backward_weight(  # noqa: C901
+     weights: Iterator[nn.Parameter],
+     param_groups: list[ParamGroup],
+     retain_graph: bool = False
+ ) -> tuple[torch.Tensor | None, ...]:
+     """
+     Performs the second phase of a split backward pass: Weight Gradients.
+
+     This function consumes the gradients captured in the `ParamGroup`s during
+     `stage_backward_input` to compute the final gradients for the model weights.
+     It triggers backward passes starting from the intermediate nodes identified previously.
+
+     Args:
+         weights: An iterator over the model parameters to extract gradients for.
+         param_groups: The list of groups containing captured intermediate gradients.
+         retain_graph: Whether to retain the graph after this backward pass.
+
+     Returns:
+         A tuple of gradients corresponding to the provided `weights`.
+     """
+
+     grad_acc_to_weight = {}
+     all_weights = []  # Keep order
+
+     for weight in weights:
+         all_weights.append(weight)
+         grad_acc = _get_grad_fn_or_grad_acc(weight)
+         if grad_acc is not None:
+             grad_acc_to_weight[grad_acc] = weight
+
+     for group in param_groups:
+         valid_edges = []
+         valid_grad_outputs: list[torch.Tensor] = []
+
+         # Ensure we have data
+         if group.grads and group.intermediates:
+             for grads_tuple, intermediate in zip(group.grads, group.intermediates, strict=True):
+                 if grads_tuple is None:
+                     raise ValueError("Trying to do backward_weight with no intermediate grads")
+                 non_none = [g for g in grads_tuple if g is not None]
+                 if len(non_none) > 0:
+                     valid_edges.append(GradientEdge(intermediate, 0))
+                     valid_grad_outputs.append(cast(torch.Tensor, sum(non_none)))
+
+         # Break Cycle: Intermediates
+         group.intermediates = None
+
+         if valid_edges:
+             inputs_for_backward = []
+             for node in group.params:
+                 if node in grad_acc_to_weight:
+                     inputs_for_backward.append(grad_acc_to_weight[node])
+
+             if inputs_for_backward:
+                 with GLOBAL_GRAD_CONTEXT.with_directions(GradDirection.weight):
+                     torch.autograd.backward(
+                         tensors=valid_edges,
+                         grad_tensors=valid_grad_outputs,
+                         retain_graph=retain_graph,
+                         inputs=inputs_for_backward
+                     )
+
+         # Break Cycle: Grads
+         group.grads = None
+
+     return tuple(w.grad for w in all_weights)
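
`stage_backward_input` and `stage_backward_weight` split one logical backward into two calls: the first produces only the input gradients (what the previous stage is waiting for) and captures the gradients arriving at the split-off intermediate nodes inside `ParamGroup`s; the second replays from those intermediates to accumulate parameter gradients, which is what lets zero-bubble style schedules reorder the weight work. Below is a minimal sketch of that contract, not taken from the package; it assumes `d9d` is importable outside a configured training loop and uses a single `nn.Linear` with made-up shapes standing in for a stage.

# Illustrative sketch only -- not shipped in the d9d 0.1.0 wheel.
import torch
from torch import nn

from d9d.pipelining.infra.stage.splitgrad import (
    stage_backward_input,
    stage_backward_weight,
)

layer = nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)  # activation received from the previous stage
y = layer(x)
dy = torch.ones_like(y)                     # gradient arriving from the next stage

# Phase 1: input gradients only. Parameter grads are not produced yet;
# gradients at the split points are captured inside the returned param groups.
result = stage_backward_input(
    outputs=[y],
    output_grads=[dy],
    inputs=[x],
    weights=iter([layer.weight, layer.bias]),
)
assert result.input_grads[0] is not None
assert layer.weight.grad is None            # weight grads are deferred

# Phase 2 (possibly much later in the schedule): weight gradients.
stage_backward_weight(
    weights=iter([layer.weight, layer.bias]),
    param_groups=result.param_groups,
)
assert layer.weight.grad is not None and layer.bias.grad is not None

Keeping `result` (and with it the `grad_ownership_tokens`) alive between the two calls is what preserves the autograd graph for the deferred weight phase.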