d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/loop/component/job_logger.py
@@ -0,0 +1,146 @@
+ from collections.abc import Generator
+ from contextlib import contextmanager
+ from typing import Any
+
+ import torch
+ import torch.utils._pytree as pytree # noqa: PLC2701
+ from torch.distributed.checkpoint.stateful import Stateful
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.core.types import PyTree, ScalarTree
+ from d9d.internals.state import load_state_dict_main_process, state_dict_main_process
+ from d9d.loop.config import JobLoggerConfig
+ from d9d.metric.impl import ComposeMetric
+ from d9d.tracker import BaseTracker, BaseTrackerRun, RunConfig, tracker_from_config
+ from d9d.tracker.provider.null import NullTrackerConfig
+
+ from .stepper import Stepper
+
+
+ def _flatten_pytree_for_metrics(tree: PyTree[float]) -> dict[str, float]:
+     flat_dict = {}
+
+     for path_tuple, value in pytree.tree_leaves_with_path(tree):
+         path_segments = []
+
+         for key in path_tuple:
+             match key:
+                 case pytree.MappingKey(k):
+                     path_segments.append(str(k))
+                 case pytree.SequenceKey(idx):
+                     path_segments.append(str(idx))
+                 case pytree.GetAttrKey(name):
+                     path_segments.append(name)
+                 case _:
+                     path_segments.append(str(key))
+
+         flat_key = "/".join(path_segments)
+         flat_dict[flat_key] = value
+
+     return flat_dict
+
+
+ class JobLogger(Stateful):
+     """
+     Handles the logging of training metrics and loss values.
+
+     This class coordinates with the distributed context and metric calculators
+     to log instantaneous loss values and periodic aggregated metrics to the
+     configured experiment tracker.
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         config: JobLoggerConfig,
+         metrics: ComposeMetric,
+         stepper: Stepper,
+         run_config: RunConfig,
+         additional_hparams: ScalarTree
+     ):
+         """
+         Constructs JobLogger object.
+
+         Args:
+             dist_context: The distributed context.
+             config: Configuration settings.
+             metrics: The composite metric collection to be computed and logged.
+             stepper: Object tracking the current global step.
+             run_config: Run configuration.
+         """
+
+         self._dist_context = dist_context
+         self._config = config
+         self._metrics = metrics
+         self._stepper = stepper
+         self._run_config = run_config.model_copy(deep=True, update={"hparams": {
+             "run": run_config.hparams,
+             "params": additional_hparams
+         }})
+
+         self._tracker = self._build_tracker()
+
+     def _build_tracker(self) -> BaseTracker:
+         if self._dist_context.is_main_process:
+             return tracker_from_config(self._config.tracker)
+         else:
+             return tracker_from_config(NullTrackerConfig())
+
+     @contextmanager
+     def new_run(self) -> Generator[BaseTrackerRun, None, None]:
+         with self._tracker.open(self._run_config) as run:
+             yield run
+
+     def trigger_sync(self):
+         """
+         Conditionally initiates the synchronization of distributed metrics.
+
+         Checks if the current step is scheduled for metric logging. If so, it
+         triggers the asynchronous communication required to aggregate metric values
+         across ranks. This allows communication to overlap with other operations
+         before `log` is called.
+         """
+
+         if not self._stepper.should_do_action(self._config.period_steps, enable_on_last_step_if_periodic=True):
+             return
+
+         self._metrics.trigger_sync(self._dist_context)
+
+     def log(self, run: BaseTrackerRun, loss_value: torch.Tensor):
+         """
+         Logs the current loss and conditionally processes aggregated metrics.
+
+         This method always logs the provided loss value. Periodically (determined
+         by the stepper and configuration), it waits for the synchronization of
+         metrics to complete (initiated by `trigger_sync`), computes their values,
+         flattens the result structure, logs them to the tracker, and resets the
+         metrics for the next window.
+
+         Args:
+             run: The active tracker run interface for sending data.
+             loss_value: Tensor containing the scalar loss for the current step.
+         """
+
+         run.scalar("loss", loss_value.item())
+
+         if not self._stepper.should_do_action(self._config.period_steps, enable_on_last_step_if_periodic=True):
+             return
+
+         self._metrics.wait_sync(self._dist_context)
+
+         results_tree = self._metrics.compute()
+         results_tree = pytree.tree_map(lambda x: x.item(), results_tree)
+         results_flat = _flatten_pytree_for_metrics(results_tree)
+
+         for name, value in results_flat.items():
+             run.scalar(name, value)
+
+         self._metrics.reset()
+
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             "tracker": state_dict_main_process(self._dist_context, self._tracker),
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         load_state_dict_main_process(self._dist_context, self._tracker, state_dict["tracker"])
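Note: the flattened metric keys produced by `_flatten_pytree_for_metrics` are simply the tree path segments joined with "/". As a point of reference, here is a minimal, self-contained sketch of that key-flattening convention using plain Python containers instead of `torch.utils._pytree`; the helper name and sample values are illustrative only and not part of the package.

    # Sketch only: reproduces the "/"-joined key convention of
    # _flatten_pytree_for_metrics with plain recursion (no torch, no d9d).
    def flatten_metrics(tree, prefix=""):
        if isinstance(tree, dict):
            items = tree.items()
        elif isinstance(tree, (list, tuple)):
            items = enumerate(tree)
        else:
            return {prefix: tree}  # leaf value
        flat = {}
        for key, value in items:
            child_prefix = f"{prefix}/{key}" if prefix else str(key)
            flat.update(flatten_metrics(value, child_prefix))
        return flat

    print(flatten_metrics({"train": {"ppl": 12.3, "acc": [0.1, 0.2]}}))
    # {'train/ppl': 12.3, 'train/acc/0': 0.1, 'train/acc/1': 0.2}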
d9d/loop/component/job_profiler.py
@@ -0,0 +1,62 @@
+ from collections.abc import Generator
+ from contextlib import contextmanager
+
+ import torch.profiler
+
+ from d9d.core.dist_context import DistributedContext
+ from d9d.internals.profiling import Profiler
+ from d9d.loop.config import ProfilingConfig
+
+ from .stepper import Stepper
+
+
+ class JobProfiler:
+     """
+     Manages profiling sessions during a job loop.
+
+     This class coordinates the initialization and activation of the internal
+     profiler based on the current step count provided by the stepper.
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         config: ProfilingConfig | None,
+         stepper: Stepper
+     ):
+         """
+         Constructs JobProfiler object.
+
+         Args:
+             dist_context: The distributed context.
+             config: Configuration settings for profiling.
+             stepper: Object tracking the current global step of the training loop.
+         """
+
+         self._config = config
+         if config is None or not config.enabled:
+             self._profiler = None
+         else:
+             self._profiler = Profiler(
+                 save_dir=config.traces_dir,
+                 active_steps=config.active_steps,
+                 warmup_steps=config.warmup_steps,
+                 period_steps=config.period_steps,
+                 dist_context=dist_context
+             )
+         self._stepper = stepper
+
+     @contextmanager
+     def open(self) -> Generator[torch.profiler.profile | None]:
+         """
+         Context manager to activate profiling for the job loop.
+
+         Yields:
+             The active Profiler instance if profiling is enabled, otherwise None.
+         """
+
+         if self._profiler is None:
+             yield None
+         else:
+             with self._profiler.open(self._stepper.current_step) as prof:
+                 yield prof
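Note: `JobProfiler` defers the actual warmup/active/period step gating to the internal `Profiler`. For orientation, the sketch below shows how step-gated profiling of this shape is commonly expressed with the standard `torch.profiler` schedule API; it is not d9d's implementation, and the step counts and trace directory are made up.

    # Sketch of step-gated profiling with the stock torch.profiler API
    # (not d9d's internal Profiler); numbers are illustrative.
    import torch
    import torch.profiler

    prof = torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=torch.profiler.schedule(wait=8, warmup=2, active=4, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./traces"),
    )

    with prof:
        for step in range(16):
            torch.randn(128, 128) @ torch.randn(128, 128)  # stand-in for one training step
            prof.step()  # advance the profiling schedule once per step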
d9d/loop/component/loss_computer.py
@@ -0,0 +1,86 @@
+ import torch
+
+ from d9d.internals.pipeline_state import PipelineStateHandler
+ from d9d.loop.control import ComputeLossContext, TrainTask
+
+ from .stepper import Stepper
+
+ STATE_LOSS = "__internal_loss"
+ STATE_LOSS_WEIGHT = "__internal_loss_weight"
+
+
+ class LossComputer:
+     """
+     Handles the computation of loss values and their integration into the pipeline state.
+
+     This component acts as a bridge between the raw outputs of the model pipeline
+     and the user-defined training task. It retrieves the appropriate state context
+     (potentially sharded per microbatch), executes the user's loss logic, persists
+     metrics into the state for logging, and returns the loss*weight term for backpropagation.
+     """
+
+     def __init__(
+         self,
+         state: PipelineStateHandler,
+         task: TrainTask,
+         stepper: Stepper
+     ):
+         """
+         Constructs a new LossComputer.
+
+         Args:
+             state: Handler for managing global and sharded pipeline states.
+             task: The user-defined training task containing loss computation logic.
+             stepper: Component tracking current step and progress.
+         """
+
+         self._state = state
+         self._task = task
+         self._stepper = stepper
+
+     def compute_loss_mul_weight(
+         self,
+         pipeline_outputs: dict[str, torch.Tensor],
+         microbatch_idx: int | None
+     ) -> torch.Tensor:
+         """
+         Computes the weighted loss for a specific sharded microbatch or the full microbatch.
+
+         This method retrieves the appropriate state context based on the microbatch
+         index, delegates calculation to the training task, saves the raw loss and
+         weight into the state for later retrieval, and returns the final scalar
+         product used for backward passes.
+
+         You can retrieve states by using `STATE_LOSS` and `STATE_LOSS_WEIGHT` keys.
+
+         Args:
+             pipeline_outputs: Dictionary containing model output tensors.
+             microbatch_idx: Index of the current microbatch, or `None` for full microbatch execution.
+
+         Returns:
+             The calculated loss multiplied by its weight.
+         """
+
+         if microbatch_idx is None:
+             state = self._state.global_state()
+         else:
+             state = self._state.sharded_state(
+                 shard_id=microbatch_idx
+             )
+
+         computation = self._task.compute_loss(ComputeLossContext(
+             pipeline_results=pipeline_outputs,
+             state=state,
+             stepper=self._stepper
+         ))
+
+         loss = computation.loss
+         loss_weight = computation.loss_weight
+
+         if loss_weight is None:
+             loss_weight = torch.ones_like(loss)
+
+         state[STATE_LOSS] = loss[None]
+         state[STATE_LOSS_WEIGHT] = loss_weight[None]
+
+         return loss * loss_weight
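Note: because `compute_loss_mul_weight` stores the raw loss and weight per microbatch under `STATE_LOSS` / `STATE_LOSS_WEIGHT` and returns `loss * weight`, a properly weighted mean loss can be recovered later. The snippet below is only a numeric illustration of that weighted-mean recovery, not code from the package.

    # Illustration: recovering a weighted mean from per-microbatch (loss, weight) pairs.
    import torch

    losses = torch.tensor([2.0, 1.0, 4.0])    # raw per-microbatch losses
    weights = torch.tensor([0.5, 1.0, 0.25])  # per-microbatch loss weights

    weighted_terms = losses * weights                 # what compute_loss_mul_weight returns
    mean_loss = weighted_terms.sum() / weights.sum()
    print(mean_loss)  # tensor(1.7143)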
d9d/loop/component/model_stage_exporter.py
@@ -0,0 +1,37 @@
+ from pathlib import Path
+
+ from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
+ from d9d.loop.control import ModelProvider, PrepareExportModelStageContext
+ from d9d.model_state.io import save_model_state_pipeline_parallel
+ from d9d.model_state.mapper.compose import ModelStateMapperParallel
+
+ from .model_stage_factory import TrackedModules
+
+
+ class ModelStageExporter:
+     def __init__(
+         self,
+         model_provider: ModelProvider,
+         modules: TrackedModules,
+         dist_context: DistributedContext
+     ):
+         self._model_provider = model_provider
+         self._modules = modules
+         self._dist_context = dist_context
+
+     def export(self, save_dir: Path):
+         mappers = []
+         for stage in self._modules.modules:
+             result = self._model_provider.prepare_export_model_stage(PrepareExportModelStageContext(
+                 model=stage,
+                 dist_context=self._dist_context
+             ))
+             mappers.append(result.state_mapper)
+         save_model_state_pipeline_parallel(
+             dest_dir=save_dir,
+             mapper=ModelStateMapperParallel(mappers),
+             device_mesh=self._dist_context.mesh_for(REGULAR_DOMAIN),
+             pipeline_dim_name="pp",
+             models=self._modules.modules,
+             show_progress=True,
+         )
d9d/loop/component/model_stage_factory.py
@@ -0,0 +1,261 @@
+ import itertools
+ from collections.abc import Callable
+ from typing import Any
+
+ import torch
+ from torch import nn
+ from torch.distributed.checkpoint.stateful import Stateful
+
+ from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
+ from d9d.loop.config import ModelStageFactoryConfig, PipeliningConfig
+ from d9d.loop.control import InitializeModelStageContext, ModelProvider, ParallelizeModelStageContext
+ from d9d.model_state.io import load_model_state
+ from d9d.module.base import ModuleLateInit
+ from d9d.pipelining.api import PipelineStageInfo
+ from d9d.pipelining.factory.factory import PipelineScheduleInfo, build_schedule
+
+ from .batch_maths import BatchMaths
+ from .loss_computer import LossComputer
+
+ StatefulPredicate = Callable[[str, torch.Tensor], bool]
+ """Determines if a specific parameter or buffer should be included in the state dictionary."""
+
+
+ def _stateful_predicate_requires_grad(key: str, value: torch.Tensor) -> bool:
+     """Predicate that allows saving only tensors that require gradients."""
+     return value.requires_grad
+
+
+ def _stateful_predicate_always(key: str, value: torch.Tensor) -> bool:
+     """Predicate that always allows saving."""
+     return True
+
+
+ class TrackedModules(Stateful):
+     """
+     Wraps a list of model stages and manages their state for distributed checkpointing.
+
+     This class implements the PyTorch Distributed `Stateful` protocol, aggregating
+     the state dictionaries of multiple pipeline stages assigned to the current rank.
+     It handles namespacing to ensure uniqueness across pipeline ranks and stages.
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         modules: list[nn.Module],
+         stateful_predicate: StatefulPredicate
+     ):
+         """Constructs a TrackedModules object."""
+         self._dist_context = dist_context
+         self._modules = modules
+         self._stateful_predicate = stateful_predicate
+
+     def __call__(self, *args: Any, **kwargs: Any) -> Any:
+         """
+         Forwards execution to the only pipeline stage.
+
+         This method is only valid when pipeline parallelism is disabled.
+
+         Args:
+             *args: Positional arguments passed to the module.
+             **kwargs: Keyword arguments passed to the module.
+
+         Returns:
+             The output of the model execution.
+
+         Raises:
+             ValueError: If pipeline parallelism is configured.
+         """
+
+         if self._dist_context.mesh_params.has_pipeline_parallel:
+             raise ValueError("You cannot call tracked modules when using pipelining")
+
+         return self._modules[0](*args, **kwargs)
+
+     @property
+     def modules(self) -> list[nn.Module]:
+         """Returns the list of underlying PyTorch model modules."""
+         return self._modules
+
+     def _whitelisted_params(self, module: nn.Module) -> set[str]:
+         allow_saving = set()
+         for param_name, param in itertools.chain(module.named_parameters(), module.named_buffers()):
+             if self._stateful_predicate(param_name, param):
+                 allow_saving.add(param_name)
+         return allow_saving
+
+     def _state_dict_stage(self, module: nn.Module) -> dict[str, Any]:
+         whitelist = self._whitelisted_params(module)
+         result = {
+             k: v for k, v in module.state_dict().items() if k in whitelist
+         }
+         return result
+
+     def state_dict(self) -> dict[str, Any]:
+         """
+         Generates the state dictionary for all tracked modules.
+
+         The keys are namespaced using the current pipeline rank and stage index
+         (e.g., `pp_0_stage_0`). Only parameters satisfying the `stateful_predicate`
+         are included.
+
+         Returns:
+             A dictionary containing the states of all managed modules.
+         """
+
+         pp_rank = self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"].get_local_rank()
+         ret = {
+             f"pp_{pp_rank}_stage_{i}": self._state_dict_stage(module)
+             for i, module in enumerate(self._modules)
+         }
+         return ret
+
+     def _load_state_dict_stage(self, module: nn.Module, state_dict: dict[str, Any]):
+         whitelist = self._whitelisted_params(module)
+
+         loading_result = module.load_state_dict(state_dict, strict=False)
+         missing_keys = set(loading_result.missing_keys)
+         extra_keys = set(loading_result.unexpected_keys)
+
+         if len(whitelist.intersection(missing_keys)) > 0:
+             raise ValueError(f"Missing keys: {whitelist.intersection(missing_keys)}")
+         if len(extra_keys) > 0:
+             raise ValueError(f"Extra keys: {extra_keys}")
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         """
+         Loads the state dictionary into the tracked modules.
+
+         Args:
+             state_dict: The state dictionary to load. Must contain keys corresponding
+                 to the pipeline rank and stage indices managed by this instance.
+
+         Raises:
+             ValueError: If required keys are missing or unexpected keys are present
+                 based on the allow-list predicate.
+         """
+
+         pp_rank = self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"].get_local_rank()
+         for i, module in enumerate(self._modules):
+             self._load_state_dict_stage(module, state_dict[f"pp_{pp_rank}_stage_{i}"])
+
+
+ class ModelStageFactory:
+     """
+     Factory class responsible for creating, initializing, and parallelizing model stages.
+
+     This class coordinates the `ModelProvider` with the distributed context to:
+
+     1. Initialize models on a meta device.
+     2. Apply horizontal distribution strategy (TP, DP, FSDP, etc).
+     3. Materialize weights on the target device.
+     4. Load initial model states from checkpoints.
+     """
+
+     def __init__(
+         self,
+         model_provider: ModelProvider,
+         dist_context: DistributedContext,
+         batch_maths: BatchMaths,
+         config_model: ModelStageFactoryConfig,
+         config_pipelining: PipeliningConfig | None,
+         loss_computer: LossComputer | None
+     ):
+         """Constructs a ModelStageFactory object."""
+
+         self._model_provider = model_provider
+         self._dist_context = dist_context
+         self._config_model = config_model
+         self._config_pipelining = config_pipelining
+         self._batch_maths = batch_maths
+         self._loss_computer = loss_computer
+
+     def _build_model_stage(self, stage: PipelineStageInfo) -> nn.Module:
+         # create a model with no real memory occupied
+         with torch.device("meta"):
+             factored = self._model_provider.initialize_model_stage(
+                 InitializeModelStageContext(
+                     dist_context=self._dist_context,
+                     stage=stage,
+                 )
+             )
+
+         model = factored.model
+
+         if not isinstance(model, ModuleLateInit) or not isinstance(model, nn.Module):
+             raise ValueError("Model stage is required to be nn.Module instance implementing ModuleLateInit protocol")
+
+         # if current context is distributed - parallelize this model
+         if self._dist_context.mesh_params.is_distributed:
+             self._model_provider.parallelize_model_stage(
+                 ParallelizeModelStageContext(
+                     model=model,
+                     stage=stage,
+                     dist_context=self._dist_context
+                 )
+             )
+
+         # move state that is bound to current device to it
+         model.to_empty(device=self._dist_context.current_device)
+
+         # reinitialize model parameters (only these are on current device)
+         with torch.no_grad():
+             model.reset_parameters()
+
+         if self._config_model.source_checkpoint:
+             load_model_state(
+                 src_dir=self._config_model.source_checkpoint,
+                 model=model,
+                 mapper=factored.state_mapper,
+                 device=f"cuda:{torch.cuda.current_device()}"
+             )
+
+         # set training state
+         model.train()
+
+         return model
+
+     def build_pipeline_and_modules(
+         self
+     ) -> tuple[PipelineScheduleInfo | None, TrackedModules]:
+         """
+         Constructs the execution schedule and the model container.
+
+         If pipeline parallelism is enabled, this orchestrates the creation of a
+         distributed pipeline schedule.
+
+         Otherwise, it simply builds a standalone model stage.
+
+         Returns:
+             The pipeline schedule information (or None if no pipelining).
+             The `TrackedModules` instance wrapping the created model stage(s).
+
+         Raises:
+             ValueError: If pipelining configuration is missing but a pipeline is requested.
+         """
+
+         if self._config_model.checkpoint_only_trainable_parameters:
+             stateful_predicate = _stateful_predicate_requires_grad
+         else:
+             stateful_predicate = _stateful_predicate_always
+
+         if self._dist_context.mesh_params.has_pipeline_parallel:
+             if self._config_pipelining is None:
+                 raise ValueError("Pipelining is enabled, but not configured")
+
+             loss_fn = self._loss_computer.compute_loss_mul_weight if self._loss_computer is not None else None
+
+             schedule, modules = build_schedule(
+                 dist_context=self._dist_context,
+                 n_microbatches=self._batch_maths.num_microbatches_pipelining,
+                 schedule_config=self._config_pipelining.schedule,
+                 model_provider=self._build_model_stage,
+                 loss_fn=loss_fn
+             )
+
+             return schedule, TrackedModules(self._dist_context, modules, stateful_predicate)
+         else:
+             model = self._build_model_stage(PipelineStageInfo(num_stages=1, current_stage=0))
+
+             return None, TrackedModules(self._dist_context, [model], stateful_predicate)
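Note: `_build_model_stage` follows the usual PyTorch meta-device recipe: build the module under `torch.device("meta")`, apply parallelism, materialize storage with `to_empty`, then call `reset_parameters` to fill real values. The standalone sketch below demonstrates that recipe with a toy module and a CPU target; it omits the d9d provider, parallelization, and checkpoint-loading steps.

    # Standalone sketch of the meta-device initialization pattern; the toy module
    # and CPU target are illustrative, provider/parallelize/load steps are omitted.
    import torch
    from torch import nn

    class ToyStage(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def reset_parameters(self):
            self.proj.reset_parameters()

        def forward(self, x):
            return self.proj(x)

    with torch.device("meta"):
        model = ToyStage()          # parameters allocated on "meta": no real memory yet

    model.to_empty(device="cpu")    # allocate uninitialized storage on the target device
    with torch.no_grad():
        model.reset_parameters()    # fill the freshly allocated storage with init values

    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 16])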
d9d/loop/component/optimizer_factory.py
@@ -0,0 +1,88 @@
+ from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
+ from d9d.core.protocol import LRSchedulerProtocol, OptimizerProtocol
+ from d9d.loop.control import (
+     InitializeLRSchedulerContext,
+     InitializeOptimizerStageContext,
+     LRSchedulerProvider,
+     OptimizerProvider,
+ )
+ from d9d.pipelining.training import PipelinedLRScheduler, PipelinedOptimizer
+
+ from .model_stage_factory import TrackedModules
+ from .stepper import Stepper
+
+
+ class OptimizerFactory:
+     """
+     Factory for creating and configuring distributed optimizers and learning rate schedulers.
+
+     This factory handles the orchestration of optimizer creation for models potentially split across
+     pipeline stages. It uses the providers to instantiate underlying PyTorch optimizers and schedulers for each
+     tracked module, and wraps them in pipeline-aware interfaces.
+     """
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         tracked_modules: TrackedModules,
+         optimizer_provider: OptimizerProvider,
+         lr_scheduler_provider: LRSchedulerProvider,
+         stepper: Stepper
+     ):
+         """
+         Constructs the OptimizerFactory.
+
+         Args:
+             dist_context: The distributed context.
+             tracked_modules: A container of model modules owned by the current rank.
+             optimizer_provider: A callable responsible for creating optimizer instances for a given model.
+             lr_scheduler_provider: A callable responsible for creating LR scheduler instances.
+             stepper: The training stepper providing information about total training steps.
+         """
+         self._dist_context = dist_context
+         self._tracked_modules = tracked_modules
+         self._optimizer_provider = optimizer_provider
+         self._lr_scheduler_provider = lr_scheduler_provider
+         self._stepper = stepper
+
+     def build_optimizer_and_scheduler(self) -> tuple[OptimizerProtocol, LRSchedulerProtocol]:
+         """
+         Builds both the optimizer and learning rate scheduler.
+
+         This method iterates through all local model modules. For each module, it creates an
+         optimizer and scheduler using the configured providers. Finally, it aggregates these individual
+         instances into a single `PipelinedOptimizer` and `PipelinedLRScheduler` capable of coordinated
+         stepping across the pipeline parallel dimension.
+
+         Returns:
+             A tuple containing the initialized pipeline-aware optimizer and scheduler.
+         """
+
+         optimizers: list[OptimizerProtocol] = []
+         lr_schedulers: list[LRSchedulerProtocol] = []
+         for module in self._tracked_modules.modules:
+             optimizer = self._optimizer_provider(
+                 InitializeOptimizerStageContext(
+                     dist_context=self._dist_context,
+                     model=module
+                 )
+             )
+             optimizers.append(optimizer)
+
+             scheduler = self._lr_scheduler_provider(
+                 InitializeLRSchedulerContext(
+                     dist_context=self._dist_context,
+                     total_steps=self._stepper.total_steps,
+                     optimizer=optimizer
+                 )
+             )
+             lr_schedulers.append(scheduler)
+         pipe_optimizer = PipelinedOptimizer(
+             mesh_pp=self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"],
+             optimizers=optimizers
+         )
+         pipe_scheduler = PipelinedLRScheduler(
+             mesh_pp=self._dist_context.mesh_for(REGULAR_DOMAIN)["pp"],
+             schedulers=lr_schedulers
+         )
+         return pipe_optimizer, pipe_scheduler
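Note: `build_optimizer_and_scheduler` creates one optimizer and one LR scheduler per local stage module and then wraps them so they step together. The sketch below mimics that per-stage pattern with plain PyTorch objects; it stands in for `PipelinedOptimizer`/`PipelinedLRScheduler` and is not the package's code.

    # Per-stage optimizers/schedulers stepped in lockstep (plain-PyTorch stand-in).
    import torch
    from torch import nn

    stages = [nn.Linear(8, 8), nn.Linear(8, 8)]   # stand-ins for local pipeline stages
    optimizers = [torch.optim.AdamW(m.parameters(), lr=1e-3) for m in stages]
    schedulers = [torch.optim.lr_scheduler.LambdaLR(o, lambda s: 1.0) for o in optimizers]

    # one training step: every stage's optimizer and scheduler advance together
    loss = sum(m(torch.randn(4, 8)).pow(2).mean() for m in stages)
    loss.backward()
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        opt.zero_grad()
        sched.step()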