d9d-0.1.0-py3-none-any.whl

This diff represents the content of a publicly available package version that has been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/component/data_loader_factory.py
@@ -0,0 +1,258 @@
+ from collections.abc import Callable, Iterable, Iterator
+ from typing import Any, Self, TypedDict, Unpack
+
+ import torch
+ import torch.utils._pytree as pytree  # noqa: PLC2701
+ from torch.utils.data import Dataset, Sampler
+ from torchdata.stateful_dataloader import StatefulDataLoader
+
+ from d9d.core.dist_context import BATCH_DOMAIN, DistributedContext
+ from d9d.core.types import CollateFn, PyTree
+ from d9d.loop.config import DataLoadingConfig
+ from d9d.loop.control import DatasetProvider, InitializeDatasetContext
+
+ from .batch_maths import BatchMaths
+
+
+ class DataLoaderKwargs(TypedDict, total=False):
+     """
+     Type definition for arguments accepted by the PyTorch DataLoader.
+     """
+
+     batch_size: int | None
+     shuffle: bool | None
+     sampler: Sampler | Iterable | None
+     batch_sampler: Sampler[list] | Iterable[list] | None
+     num_workers: int
+     collate_fn: CollateFn
+     pin_memory: bool
+     drop_last: bool
+     timeout: float
+     worker_init_fn: Callable | None
+     multiprocessing_context: Any
+     generator: Any
+     prefetch_factor: int | None
+     persistent_workers: bool
+     pin_memory_device: str
+
+
+ def _move_to_device(data: PyTree, device: torch.types.Device) -> PyTree:
+     return pytree.tree_map(lambda x: x.to(device), data)
+
+
+ class IteratorBatchGroup(Iterator):
+     """
+     An iterator that groups items from a base iterator into sub-streams.
+
+     This class is used for gradient accumulation, where
+     a single optimizer step consumes multiple micro-batches (the group).
+
+     It also moves the data to the specified device immediately upon access.
+     """
+
+     def __init__(
+         self,
+         base: Iterator,
+         device: torch.types.Device,
+         batch_group_size: int
+     ):
+         """
+         Constructs an IteratorBatchGroup object.
+
+         Args:
+             base: The underlying data iterator (usually from a DataLoader).
+             device: The target device to move tensors to.
+             batch_group_size: The number of micro-batches to yield within one group.
+         """
+
+         self._base = base
+         self._device = device
+
+         self._batch_group_size = batch_group_size
+
+         self._is_end = False
+
+     def __next__(self) -> PyTree:
+         """
+         Advances the iterator.
+
+         Returns:
+             A generator that yields `batch_group_size` items (micro-batches),
+             with each item already moved to the configured device.
+
+         Raises:
+             StopIteration: If the underlying iterator is exhausted.
+         """
+
+         if self._is_end:
+             raise StopIteration()
+
+         try:
+             sample_item = next(self._base)
+         except StopIteration:
+             self._is_end = True
+             raise StopIteration() from None
+
+         def _iter_inside_group():
+             yield _move_to_device(sample_item, self._device)
+
+             for _ in range(self._batch_group_size - 1):
+                 try:
+                     item = next(self._base)
+                     yield _move_to_device(item, self._device)
+                 except StopIteration:
+                     self._is_end = True
+                     break
+
+         return _iter_inside_group()
+
+     def __iter__(self) -> Self:
+         """Returns self."""
+
+         return self
+
+
+ class StatefulDataLoaderDataParallelAware(StatefulDataLoader):
+     """
+     A stateful data loader that is aware of data parallel ranks.
+
+     This loader extends the standard torchdata StatefulDataLoader to ensure
+     that checkpoints are saved with rank-specific keys.
+
+     It also wraps the iterator to support batch grouping for gradient accumulation and
+     automatically transfers data to the bound device.
+     """
+
+     def __init__(
+         self,
+         dataset: Dataset,
+         dp_rank: int,
+         device: torch.types.Device,
+         group_size: int,
+         **kwargs: Unpack[DataLoaderKwargs]
+     ):
+         """
+         Constructs a StatefulDataLoaderDataParallelAware object.
+
+         Args:
+             dataset: The dataset to load from.
+             dp_rank: The Data Parallel rank of the current process (used for state checkpointing).
+             device: The device to move data to.
+             group_size: The number of batches to group together (e.g., for gradient accumulation).
+             **kwargs: Standard arguments passed to the parent DataLoader.
+         """
+
+         super().__init__(dataset, **kwargs)
+         self._dp_rank = dp_rank
+         self._device = device
+         self._group_size = group_size
+
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             f"dp_{self._dp_rank}": super().state_dict()
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         super().load_state_dict(state_dict[f"dp_{self._dp_rank}"])
+
+     def __iter__(self) -> Iterator:
+         return IteratorBatchGroup(
+             super().__iter__(),
+             device=self._device,
+             batch_group_size=self._group_size
+         )
+
+
+ class DataLoaderFactory:
+     """
+     Factory class for creating configured DataLoaders.
+
+     This class centralizes the creation logic for training and inference
+     data loaders, applying the configured data loading settings.
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         provider: DatasetProvider,
+         config_data_loading: DataLoadingConfig,
+         batch_maths: BatchMaths
+     ):
+         """
+         Constructs a DataLoaderFactory object.
+
+         Args:
+             dist_context: The distributed context containing mesh and rank information.
+             provider: The provider callable that initializes the dataset and collator.
+             config_data_loading: Specific configuration for data loading.
+             batch_maths: Calculation utility for batch sizes and accumulation steps.
+         """
+
+         self._dist_context = dist_context
+         self._provider = provider
+
+         self._config_data_loading = config_data_loading
+
+         self._batch_maths = batch_maths
+
+     def _build_dataloader(
+         self,
+         provider: DatasetProvider,
+         batch_size: int,
+         group_size: int,
+         drop_last: bool
+     ) -> StatefulDataLoader:
+         result = provider(InitializeDatasetContext(
+             dist_context=self._dist_context,
+             batch_maths=self._batch_maths
+         ))
+
+         return StatefulDataLoaderDataParallelAware(
+             result.dataset,
+             collate_fn=result.collator,
+             group_size=group_size,
+             num_workers=self._config_data_loading.num_workers,
+             persistent_workers=self._config_data_loading.persistent_workers,
+             pin_memory=self._config_data_loading.pin_memory,
+             batch_size=batch_size,
+             dp_rank=self._dist_context.mesh_for(BATCH_DOMAIN)["dp"].size(),
+             device="cuda",
+             drop_last=drop_last
+         )
+
+     def build_dataloader_for_train_job(self) -> StatefulDataLoader:
+         """
+         Builds and returns a StatefulDataLoader configured for training.
+
+         This loader is configured to drop the last incomplete batch and group
+         batches according to the gradient accumulation settings defined in
+         BatchMaths.
+
+         Returns:
+             A configured StatefulDataLoader instance.
+         """
+
+         return self._build_dataloader(
+             self._provider,
+             batch_size=self._batch_maths.data_loader_batch_size,
+             group_size=self._batch_maths.num_microbatches_gradient_accumulation,
+             drop_last=True
+         )
+
+     def build_dataloader_for_infer_job(self) -> StatefulDataLoader:
+         """
+         Builds and returns a StatefulDataLoader configured for inference.
+
+         This loader processes batches one by one (group size of 1) and does
+         not drop the last batch.
+
+         Returns:
+             A configured StatefulDataLoader instance.
+         """
+
+         return self._build_dataloader(
+             self._provider,
+             batch_size=self._batch_maths.data_loader_batch_size,
+             group_size=1,
+             drop_last=False
+         )
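For context on how the grouping above behaves at runtime, here is a minimal sketch (not part of the package) that drives IteratorBatchGroup with a toy tensor stream. It assumes the d9d wheel and its torch/torchdata dependencies are importable; the CPU device, toy tensors, and group size of 2 are illustrative only.

import torch

from d9d.loop.component.data_loader_factory import IteratorBatchGroup

# Five fake micro-batches standing in for what a DataLoader would yield.
micro_batches = iter([torch.full((2,), fill_value=i) for i in range(5)])
grouped = IteratorBatchGroup(micro_batches, device="cpu", batch_group_size=2)

for step, group in enumerate(grouped):
    # Each `group` is a generator yielding up to `batch_group_size` micro-batches,
    # each already moved to the target device.
    sizes = [mb.tolist() for mb in group]
    print(f"optimizer step {step}: {sizes}")
# Expected grouping: step 0 sees micro-batches 0-1, step 1 sees 2-3, step 2 sees the leftover 4.

Each outer iteration corresponds to one optimizer step, matching the gradient-accumulation grouping described in the docstrings above.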
d9d/loop/component/garbage_collector.py
@@ -0,0 +1,94 @@
+ import gc
+ import time
+ from contextlib import AbstractContextManager
+ from types import TracebackType
+ from typing import Self
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.loop.config import GarbageCollectionConfig
+
+ from .stepper import Stepper
+
+
+ class ManualGarbageCollector(AbstractContextManager):
+     """
+     Manages efficient Python garbage collection during the training loop.
+
+     This context manager disables automatic garbage collection upon entry to prevent
+     unpredictable latency spikes during training steps. It allows performing
+     manual collection at specific intervals (periodic) or specific points (forced).
+     """
+
+     def __init__(
+         self,
+         dist_ctx: DistributedContext,
+         config: GarbageCollectionConfig,
+         step: Stepper
+     ):
+         """
+         Constructs the garbage collector manager.
+
+         Args:
+             dist_ctx: The distributed context.
+             config: Configuration determining how often GC should run.
+             step: Stepper instance used to track the current training step.
+         """
+         self._dist_ctx = dist_ctx
+         self._config = config
+         self._step = step
+
+     def __enter__(self) -> Self:
+         """
+         Disables automatic garbage collection and performs an initial full collection.
+
+         Returns:
+             The calling instance.
+         """
+
+         gc.disable()
+         self._collect(generation=2)
+
+         return self
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_value: BaseException | None,
+         traceback: TracebackType | None, /
+     ) -> None:
+         """
+         Re-enables automatic garbage collection and performs a final full collection.
+
+         Args:
+             exc_type: The type of the exception raised (if any).
+             exc_value: The exception instance raised (if any).
+             traceback: The traceback object (if any).
+         """
+
+         gc.enable()
+         self._collect(generation=2)
+
+     def collect_periodic(self):
+         """
+         Triggers garbage collection if the current step matches the configured period.
+
+         This typically performs a faster (generation 1) collection rather than a full sweep.
+         """
+
+         if self._step.should_do_action(self._config.period_steps, enable_on_last_step_if_periodic=False):
+             self._collect(generation=1)
+
+     def collect_forced(self):
+         """
+         Forces a full garbage collection run regardless of the step count.
+
+         This performs a generation 2 collection.
+         """
+
+         self._collect(generation=2)
+
+     def _collect(self, generation: int):
+         begin = time.monotonic()
+         gc.collect(generation)
+         end = time.monotonic()
+         self._dist_ctx.logger.info(f"[GC] Garbage collection for generation {generation} took {end - begin}s")
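The class above wraps a standard pattern around the gc module; the following standalone sketch (not part of the package) shows that pattern directly, with a made-up period of 50 steps standing in for the configured period and a placeholder training body.

import gc

gc.disable()
gc.collect(2)  # initial full sweep, mirroring __enter__
try:
    for step in range(1, 201):
        ...  # forward/backward/optimizer step would go here
        if step % 50 == 0:
            gc.collect(1)  # periodic generation-1 sweep, cheaper than a full collection
finally:
    gc.enable()
    gc.collect(2)  # final full sweep, mirroring __exit__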
d9d/loop/component/gradient_clipper.py
@@ -0,0 +1,89 @@
+ from contextlib import contextmanager
+
+ from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
+ from d9d.internals.grad_norm import ParametersForNorm, clip_grad_norm_distributed_, group_parameters_for_norm
+ from d9d.loop.config import GradientClippingConfig
+ from d9d.tracker import BaseTrackerRun
+
+ from .model_stage_factory import TrackedModules
+ from .stepper import Stepper
+
+
+ class GradientClipper:
+     """
+     Manages gradient clipping and logging of gradient norms in a distributed execution environment.
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         tracked_modules: TrackedModules,
+         config: GradientClippingConfig,
+         stepper: Stepper
+     ):
+         """
+         Constructs the gradient clipper.
+
+         Args:
+             dist_context: The distributed context.
+             tracked_modules: Container of model modules whose parameters need clipping.
+             config: Configuration defining max norm and logging frequency.
+             stepper: Stepper instance used to track the current training step.
+         """
+
+         self._dist_context = dist_context
+         self._tracked_modules = tracked_modules
+         self._config = config
+         self._stepper = stepper
+
+         self._parameter_groups: ParametersForNorm | None = None
+
+     def _all_parameters(self):
+         for model in self._tracked_modules.modules:
+             yield from model.parameters()
+
+     @contextmanager
+     def install(self):
+         """
+         Context manager that prepares and groups parameters for efficient norm calculation.
+
+         It calculates necessary metadata (such as segregating shared parameters) to ensure
+         correct global norm calculation across the pipeline parallel mesh.
+         """
+
+         self._parameter_groups = group_parameters_for_norm(self._all_parameters())
+         yield
+         self._parameter_groups = None
+
+     def clip_and_log(self, run: BaseTrackerRun):
+         """
+         Clips gradients to the configured maximum norm and logs the total L2 norm.
+
+         This method performs an in-place modification of parameter gradients if a
+         maximum norm is configured. It calculates the global gradient norm across
+         distributed ranks.
+
+         Args:
+             run: The tracker run instance used for logging the norm scalar.
+
+         Raises:
+             ValueError: If called outside the ``install`` context manager scope.
+         """
+
+         should_log = self._stepper.should_do_action(self._config.log_total_steps)
+
+         if not self._config.max_norm and not should_log:
+             return
+
+         if self._parameter_groups is None:
+             raise ValueError("Parameter groups are not configured")
+
+         grad_norm = clip_grad_norm_distributed_(
+             parameter_groups=self._parameter_groups,
+             max_norm=self._config.max_norm,
+             norm_type=2.0,
+             pp_mesh=self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"],
+         )
+
+         if should_log:
+             run.scalar(name="l2_grad_norm_total", value=grad_norm.item())
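As a point of reference, the following standalone sketch (not part of the package) shows the single-process analogue of clip_and_log using the standard torch.nn.utils.clip_grad_norm_; the distributed helper above additionally reduces the norm over the pipeline-parallel mesh. The toy model, max_norm, logging period, and print-based logging are illustrative only.

import torch
from torch import nn

model = nn.Linear(16, 4)
loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()

max_norm = 1.0    # placeholder for the configured max norm
log_every = 10    # placeholder for the configured logging period
step = 10

# Clip in place and get the pre-clipping total L2 norm back.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm, norm_type=2.0)
if step % log_every == 0:
    print(f"l2_grad_norm_total={total_norm.item():.4f}")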
d9d/loop/component/gradient_manager.py
@@ -0,0 +1,149 @@
+ from contextlib import contextmanager
+
+ import torch
+ from torch.distributed.tensor import DTensor
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.internals.grad_sync import GradientSynchronizer
+ from d9d.loop.config import GradientManagerConfig
+ from d9d.metric.impl import WeightedMeanMetric
+
+ from .batch_maths import BatchMaths
+ from .model_stage_factory import TrackedModules
+
+
+ class GradientManager:
+     """
+     Manages the lifecycle of gradients during the training loop.
+
+     This class handles gradient synchronization across ranks,
+     gradient data type configuration, and loss scaling based on accumulated weights.
+     It orchestrates the `GradientSynchronizer` and ensures gradients are correctly
+     prepared before the optimizer step.
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         tracked_modules: TrackedModules,
+         batch_maths: BatchMaths,
+         config: GradientManagerConfig
+     ):
+         """
+         Constructs the GradientManager and initializes the internal synchronizer.
+
+         Args:
+             dist_context: The distributed context.
+             tracked_modules: Container of model modules to manage gradients for.
+             batch_maths: Calculation utility for batch sizes and accumulation steps.
+             config: Configuration for gradient handling.
+         """
+
+         self._dist_context = dist_context
+         self._tracked_modules = tracked_modules
+         self._batch_maths = batch_maths
+         self._config = config
+         self._loss = WeightedMeanMetric()
+         self._loss.to("cuda")
+
+         self._grad_sync = GradientSynchronizer(
+             [list(module.parameters()) for module in self._tracked_modules.modules],
+             bucket_size_mb=self._config.bucket_size_mb,
+             require_accumulations=self._batch_maths.num_backward_calls
+         )
+         self._grads_to_scale: list[torch.Tensor] | None = None
+
+     def _setup_grad_dtype(self):
+         if self._config.grad_dtype is None:
+             return
+
+         for mod in self._tracked_modules.modules:
+             for param in mod.parameters():
+                 if param.requires_grad:
+                     param.grad_dtype = getattr(torch, self._config.grad_dtype)
+
+     def _bind_grads_to_scale(self):
+         grads_to_scale: list[torch.Tensor] = []
+
+         for mod in self._tracked_modules.modules:
+             for param in mod.parameters():
+                 if param.grad is None:
+                     continue
+                 grad = param.grad.to_local() if isinstance(param.grad, DTensor) else param.grad
+                 grads_to_scale.append(grad)
+
+         self._grads_to_scale = grads_to_scale
+
+     def _unbind_grads_to_scale(self):
+         self._grads_to_scale = None
+
+     def _scale_grads(self):
+         scale_factor = 1.0 / self._loss.accumulated_weight
+         torch._foreach_mul_(self._grads_to_scale, scale_factor)
+
+     @contextmanager
+     def install(self):
+         """
+         Context manager to activate gradient handling for a forward/backward pass.
+
+         This sets up gradient dtypes, installs backward hooks for synchronization via
+         the `GradientSynchronizer`, and binds gradients for later scaling. It acts
+         as the boundary for the accumulation phase.
+         """
+
+         self._setup_grad_dtype()
+         self._grad_sync.bind()
+         self._bind_grads_to_scale()
+         yield
+         self._unbind_grads_to_scale()
+         self._grad_sync.unbind()
+
+     def add_loss_with_weight(self, loss: torch.Tensor, loss_weight: torch.Tensor):
+         """
+         Accumulates a loss value and its corresponding weight into the internal metric.
+
+         Args:
+             loss: The computed loss scalar.
+             loss_weight: The weight associated with this loss.
+         """
+
+         self._loss.update(loss, loss_weight)
+
+     def sync_and_scale(self):
+         """
+         Finalizes gradients to be ready for the optimizer step.
+
+         This method performs the following operations:
+
+         1. Waits for all gradient synchronization hooks to complete.
+         2. Synchronizes the accumulated loss/weights across the distributed context.
+         3. Scales the gradients by the inverse of the total accumulated weight to
+            normalize them.
+         """
+
+         self._grad_sync.wait()
+
+         self._loss.trigger_sync(self._dist_context)
+         self._loss.wait_sync(self._dist_context)
+         self._scale_grads()
+
+     def compute_global_loss(self) -> torch.Tensor:
+         """
+         Calculates the final weighted mean loss.
+
+         Returns:
+             The averaged loss scalar across all accumulation steps and ranks.
+         """
+
+         return self._loss.compute()
+
+     def zero_grad(self):
+         """
+         Resets the internal state for the next training step.
+
+         This clears the accumulated gradients in the synchronizer and resets the
+         loss metrics.
+         """
+
+         self._grad_sync.zero_grad()
+         self._loss.reset()
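To illustrate the normalization that sync_and_scale applies, here is a standalone sketch (not part of the package) showing why dividing the accumulated gradients by the total loss weight yields the gradient of the weighted-mean loss; the toy parameter, stand-in loss, and weights are illustrative only.

import torch

param = torch.zeros(3, requires_grad=True)
micro_batches = [(torch.randn(3), 2.0), (torch.randn(3), 1.0), (torch.randn(3), 3.0)]

accumulated_weight = 0.0
for data, weight in micro_batches:
    loss = (param * data).sum()   # stand-in per-micro-batch loss
    (loss * weight).backward()    # weighted loss, as tracked by add_loss_with_weight
    accumulated_weight += weight

param.grad.mul_(1.0 / accumulated_weight)  # the scaling step performed by _scale_grads

# Check against the directly computed gradient of the weighted-mean loss.
expected = sum(w * d for d, w in micro_batches) / accumulated_weight
assert torch.allclose(param.grad, expected)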