d9d 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238) hide show
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,176 @@
1
+ from typing import Annotated, Literal
2
+
3
+ from pydantic import BaseModel, Field, PositiveInt
4
+ from torch.optim import Optimizer
5
+
6
+ from d9d.core.protocol import LRSchedulerProtocol
7
+
8
+ from .builder import piecewise_schedule
9
+ from .curves import CurveBase, CurveCosine, CurveExponential, CurveLinear, CurvePoly
10
+
11
+
12
class CurveLinearConfig(BaseModel):
    """
    Config selecting straight-line interpolation between phase endpoints.
    """

    type: Literal["linear"] = "linear"
18
+
19
+
20
class CurveCosineConfig(BaseModel):
    """
    Config selecting half-period cosine interpolation.
    """

    type: Literal["cosine"] = "cosine"
26
+
27
+
28
class CurveExponentialConfig(BaseModel):
    """
    Config selecting exponential (log-space linear) interpolation.
    """

    type: Literal["exponential"] = "exponential"
34
+
35
+
36
class CurvePolyConfig(BaseModel):
    """
    Config selecting polynomial interpolation.

    Attributes:
        power: The exponent of the polynomial function.
    """

    type: Literal["poly"] = "poly"
    power: float = 2.0
46
+
47
+
48
# Discriminated union of all curve configs; pydantic dispatches on the
# `type` field declared by each member model.
AnyCurveConfig = Annotated[
    CurveLinearConfig | CurveCosineConfig | CurveExponentialConfig | CurvePolyConfig,
    Field(discriminator="type")
]
52
+
53
+
54
+ def curve_from_config(config: AnyCurveConfig) -> CurveBase:
55
+ """
56
+ Instantiates a concrete curve object from its configuration.
57
+
58
+ Args:
59
+ config: The configuration object.
60
+
61
+ Returns:
62
+ The instantiated curve.
63
+ """
64
+
65
+ match config:
66
+ case CurveLinearConfig():
67
+ return CurveLinear()
68
+ case CurvePolyConfig():
69
+ return CurvePoly(config.power)
70
+ case CurveExponentialConfig():
71
+ return CurveExponential()
72
+ case CurveCosineConfig():
73
+ return CurveCosine()
74
+
75
+
76
class StepPhaseConfig(BaseModel):
    """
    Phase whose duration is given as an absolute number of steps.

    Attributes:
        mode: Discriminator field, must be "steps".
        steps: How many steps this phase lasts.
        target_multiplier: The multiplier reached when the phase finishes.
        curve: The interpolation curve configuration.
    """

    mode: Literal["steps"] = "steps"

    steps: PositiveInt
    target_multiplier: float
    curve: AnyCurveConfig
92
+
93
+
94
class PercentagePhaseConfig(BaseModel):
    """
    Phase that runs until a given fraction of total training has elapsed.

    Attributes:
        mode: Discriminator field, must be "percentage".
        percentage: The training progress (0.0 to 1.0) at which this phase ends.
        target_multiplier: The multiplier reached when the phase finishes.
        curve: The interpolation curve configuration.
    """

    mode: Literal["percentage"] = "percentage"

    percentage: float = Field(..., ge=0.0, le=1.0)
    target_multiplier: float
    curve: AnyCurveConfig
110
+
111
+
112
class RestPhaseConfig(BaseModel):
    """
    Phase that occupies whatever training duration remains.

    Attributes:
        mode: Discriminator field, must be "rest".
        target_multiplier: The multiplier reached at the very end of training.
        curve: The interpolation curve configuration.
    """

    mode: Literal["rest"] = "rest"

    target_multiplier: float
    curve: AnyCurveConfig
126
+
127
+
128
# Discriminated union of all phase configs; pydantic dispatches on the
# `mode` field declared by each member model.
PhaseConfig = Annotated[
    StepPhaseConfig | PercentagePhaseConfig | RestPhaseConfig,
    Field(discriminator="mode")
]
132
+
133
+
134
class PiecewiseSchedulerConfig(BaseModel):
    """
    Declarative description of a piecewise learning rate schedule.

    Attributes:
        initial_multiplier: The multiplier in effect before the first phase.
        phases: Phase configurations, applied in order.
    """

    initial_multiplier: float
    phases: list[PhaseConfig]
145
+
146
+
147
def piecewise_scheduler_from_config(
        config: PiecewiseSchedulerConfig,
        optimizer: Optimizer,
        total_steps: int | None
) -> LRSchedulerProtocol:
    """
    Constructs a PyTorch scheduler from the provided configuration.

    Args:
        config: The scheduler configuration.
        optimizer: The optimizer to wrap.
        total_steps: The total number of training steps. Required if using percentage-based phases.

    Returns:
        A configured learning rate scheduler.
    """

    builder = piecewise_schedule(config.initial_multiplier, total_steps)

    for phase_config in config.phases:
        phase_curve = curve_from_config(phase_config.curve)

        # Dispatch on the concrete phase type to pick the builder call.
        if isinstance(phase_config, StepPhaseConfig):
            builder.for_steps(phase_config.steps, phase_config.target_multiplier, phase_curve)
        elif isinstance(phase_config, PercentagePhaseConfig):
            builder.until_percentage(phase_config.percentage, phase_config.target_multiplier, phase_curve)
        elif isinstance(phase_config, RestPhaseConfig):
            builder.fill_rest(phase_config.target_multiplier, phase_curve)

    return builder.build(optimizer)
@@ -0,0 +1,75 @@
1
+ import abc
2
+ import math
3
+
4
+
5
class CurveBase(abc.ABC):
    """
    Interface for interpolation curves used by the piecewise scheduler.
    """

    @abc.abstractmethod
    def compute(self, start: float, end: float, step_p: float) -> float:
        """
        Returns the value interpolated between `start` and `end`.

        Args:
            start: The value at the beginning of the phase.
            end: The value at the end of the phase.
            step_p: Progress fraction through the phase (0.0 to 1.0).

        Returns:
            The interpolated value.
        """
23
+
24
+
25
class CurveLinear(CurveBase):
    """
    Straight-line interpolation between the endpoint values.
    """

    def compute(self, start: float, end: float, step_p: float) -> float:
        delta = end - start
        return start + delta * step_p
32
+
33
+
34
class CurveCosine(CurveBase):
    """
    Half-period cosine annealing between the endpoint values.
    """

    def compute(self, start: float, end: float, step_p: float) -> float:
        # Weight falls smoothly from 1 at step_p=0 to 0 at step_p=1.
        weight = 0.5 * (1.0 + math.cos(math.pi * step_p))
        return end + (start - end) * weight
42
+
43
+
44
class CurvePoly(CurveBase):
    """
    Polynomial interpolation between the endpoint values.
    """

    def __init__(self, power: float):
        """
        Constructs a polynomial curve.

        Args:
            power: The exponent of the polynomial. 1.0 is linear, 2.0 is quadratic, etc.
        """

        self._power = power

    def compute(self, start: float, end: float, step_p: float) -> float:
        warped_progress = step_p ** self._power
        return start + (end - start) * warped_progress
62
+
63
+
64
class CurveExponential(CurveBase):
    """
    Geometric (log-space linear) interpolation between the endpoint values.
    """

    def compute(self, start: float, end: float, step_p: float) -> float:
        # Clamp both endpoints away from zero so log() is always defined.
        floor = 1e-8
        log_start = math.log(max(start, floor))
        log_end = math.log(max(end, floor))

        return math.exp(log_start + (log_end - log_start) * step_p)
@@ -0,0 +1,76 @@
1
+ import dataclasses
2
+
3
+ from .curves import CurveBase
4
+
5
+
6
@dataclasses.dataclass
class SchedulePhase:
    """
    Data container representing a single phase in a piecewise schedule.

    A phase covers the half-open step interval [start_step, end_step) and
    interpolates its multiplier between start_value and end_value via `curve`.

    Attributes:
        start_step: The absolute step index where this phase begins.
        end_step: The absolute step index where this phase ends.
        start_value: The multiplier value at start_step.
        end_value: The multiplier value at end_step.
        curve: The interpolation logic for this phase.
    """

    start_step: int
    end_step: int
    start_value: float
    end_value: float
    curve: CurveBase
24
+
25
+
26
class PiecewiseScheduleEngine:
    """
    Evaluates a piecewise multiplier schedule at arbitrary training steps.
    """

    def __init__(self, phases: list[SchedulePhase]):
        """
        Constructs the schedule engine.

        Args:
            phases: A sequential list of schedule phases.

        Raises:
            ValueError: If the phases list is empty.
        """

        if not phases:
            raise ValueError("Scheduler should contain at least one phase")

        self._phases = phases

    def get_factor(self, step: int) -> float:
        """
        Computes the learning rate multiplier for the given step.

        Args:
            step: The global training step.

        Returns:
            The calculated multiplier. Steps before the first phase clamp to
            its start value; steps past the last phase clamp to its end value.
        """

        if step < 0:
            return self._phases[0].start_value

        for phase in self._phases:
            if phase.start_step <= step < phase.end_step:
                span = phase.end_step - phase.start_step
                progress = (step - phase.start_step) / span

                return phase.curve.compute(
                    start=phase.start_value,
                    end=phase.end_value,
                    step_p=progress
                )

        # Past the final phase: hold the terminal value.
        return self._phases[-1].end_value
@@ -0,0 +1,74 @@
1
+ from collections.abc import Callable
2
+
3
+ from torch import nn
4
+ from torch.optim import SGD, Optimizer
5
+ from torch.optim.lr_scheduler import LRScheduler
6
+
7
+ SchedulerFactory = Callable[[Optimizer], LRScheduler]
8
+
9
+
10
+ def _get_history(factory: SchedulerFactory, num_steps: int, init_lr: float) -> list[float]:
11
+ optimizer = SGD(nn.Linear(1, 1).parameters(), lr=init_lr)
12
+
13
+ scheduler = factory(optimizer)
14
+
15
+ lrs = []
16
+
17
+ for _ in range(num_steps):
18
+ current_lr = optimizer.param_groups[0]["lr"]
19
+ lrs.append(current_lr)
20
+ scheduler.step()
21
+
22
+ return lrs
23
+
24
+
25
def visualize_lr_scheduler(factory: SchedulerFactory, num_steps: int, init_lr: float = 1.0):
    """
    Visualizes the learning rate schedule using Plotly.

    Simulates `num_steps` of training against a dummy optimizer to record the
    LR trajectory, then renders it as an interactive line chart.

    Args:
        factory: A callable that accepts an Optimizer and returns an LRScheduler.
        num_steps: The number of steps to simulate.
        init_lr: The initial learning rate to set on the dummy optimizer.

    Raises:
        ImportError: If the `plotly` library is not installed.
    """

    try:
        import plotly.graph_objects as go  # noqa: PLC0415
    except ImportError as e:
        raise ImportError("You have to install `plotly` dependency to use scheduler visualization") from e

    history = _get_history(factory, num_steps, init_lr)

    lr_trace = go.Scatter(
        x=list(range(num_steps)),
        y=history,
        mode="lines",
        name="Learning Rate",
        line={"color": "#636EFA", "width": 3},
        hovertemplate="<b>Step:</b> %{x}<br><b>LR:</b> %{y:.6f}<extra></extra>"
    )

    figure = go.Figure()
    figure.add_trace(lr_trace)
    figure.update_layout(
        title={
            "text": "Scheduler",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top"
        },
        xaxis_title="Steps",
        yaxis_title="Learning Rate",
        template="plotly_white",
        hovermode="x unified",
        height=500
    )

    figure.show()
d9d/metric/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Distributed metric abstractions and implementations.
3
+ """
4
+
5
+
6
+ from .abc import Metric
7
+
8
+ __all__ = [
9
+ "Metric"
10
+ ]
d9d/metric/abc.py ADDED
@@ -0,0 +1,79 @@
1
+ import abc
2
+ from typing import Any, Generic, TypeVar
3
+
4
+ import torch
5
+ from torch.distributed.checkpoint.stateful import Stateful
6
+
7
+ from d9d.core.dist_context import DistributedContext
8
+ from d9d.core.types import TensorTree
9
+
10
+ TComputeResult = TypeVar("TComputeResult", bound=TensorTree)
11
+
12
+
13
class Metric(abc.ABC, Stateful, Generic[TComputeResult]):
    """
    Base interface for distributed-aware metrics.

    A metric accumulates statistics via `update`, aggregates them across
    processes with the `trigger_sync` / `wait_sync` pair, exposes the result
    through `compute`, and persists its state via the `Stateful` interface.
    """

    @abc.abstractmethod
    def update(self, *args: Any, **kwargs: Any):
        """
        Folds a new batch of observations into the metric state.

        Args:
            *args: Positional arguments required by the specific metric implementation.
            **kwargs: Keyword arguments required by the specific metric implementation.
        """

    @abc.abstractmethod
    def trigger_sync(self, dist_context: DistributedContext):
        """
        Kicks off cross-process aggregation of the metric state.

        Implementations should launch the required collectives (e.g. all-reduce),
        preferably asynchronously, without blocking for completion.

        Args:
            dist_context: The distributed context.
        """

    @abc.abstractmethod
    def wait_sync(self, dist_context: DistributedContext):
        """
        Blocks until the aggregation started by `trigger_sync` has finished.

        Once this returns, the metric state must be fully aggregated and
        consistent across ranks.

        Args:
            dist_context: The distributed context.
        """

    @abc.abstractmethod
    def compute(self) -> TComputeResult:
        """
        Produces the current metric value.

        Returns:
            The computed result (of type `TComputeResult`) - a single
            `torch.Tensor` or a `PyTree` structure containing tensors,
            depending on how the subclass was typed.
        """

    @abc.abstractmethod
    def reset(self):
        """
        Restores the internal metric state to its initial values.
        """

    def to(self, device: str | torch.device | int):
        """
        Moves the metric state to the given device.

        Base implementation is a no-op; subclasses holding tensors override it.

        Args:
            device: The device to move the metric state to.
        """
@@ -0,0 +1,7 @@
1
+ from .compose import ComposeMetric
2
+ from .mean import WeightedMeanMetric
3
+
4
+ __all__ = [
5
+ "ComposeMetric",
6
+ "WeightedMeanMetric"
7
+ ]
@@ -0,0 +1,54 @@
1
+ from collections.abc import Mapping
2
+ from typing import Any
3
+
4
+ import torch
5
+
6
+ from d9d.core.dist_context import DistributedContext
7
+ from d9d.metric import Metric
8
+
9
+
10
class ComposeMetric(Metric[dict[str, Any]]):
    """Groups named child metrics and fans every operation out to all of them."""

    def __init__(self, children: Mapping[str, Metric]):
        """
        Constructs the composite metric.

        Args:
            children: Mapping from metric name to child metric.
        """

        self._children = children

    def update(self, *args: Any, **kwargs: Any):
        # The composite holds no state of its own; callers must update children.
        raise ValueError("Cannot update ComposeMetric directly - you can only update its children")

    def __getitem__(self, item: str) -> Metric:
        return self._children[item]

    @property
    def children(self) -> Mapping[str, Metric]:
        return self._children

    def trigger_sync(self, dist_context: DistributedContext):
        for child in self._children.values():
            child.trigger_sync(dist_context)

    def wait_sync(self, dist_context: DistributedContext):
        for child in self._children.values():
            child.wait_sync(dist_context)

    def compute(self) -> dict[str, Any]:
        results: dict[str, Any] = {}
        for name, child in self._children.items():
            results[name] = child.compute()
        return results

    def reset(self):
        for child in self._children.values():
            child.reset()

    def to(self, device: str | torch.device | int):
        for child in self._children.values():
            child.to(device)

    def state_dict(self) -> dict[str, Any]:
        state: dict[str, Any] = {}
        for name, child in self._children.items():
            state[name] = child.state_dict()
        return state

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        for name, child in self._children.items():
            child.load_state_dict(state_dict[name])
@@ -0,0 +1,94 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+ from d9d.core.dist_context import DistributedContext
7
+ from d9d.metric import Metric
8
+
9
+
10
class WeightedMeanMetric(Metric[torch.Tensor]):
    """
    Weighted running mean.

    Maintains the running sum of `value * weight` and the running sum of
    weights; the mean is their ratio.
    """

    def __init__(self):
        """Constructs a WeightedMeanMetric object."""

        super().__init__()
        self._value = torch.scalar_tensor(0, dtype=torch.float32)
        self._weight = torch.scalar_tensor(0, dtype=torch.float32)

        # Reduced copies of the accumulators; meaningful only while
        # _is_synced is True.
        self._is_synced = False
        self._synced_value = torch.scalar_tensor(0, dtype=torch.float32)
        self._synced_weight = torch.scalar_tensor(0, dtype=torch.float32)

        # Outstanding async all-reduce work between trigger_sync and wait_sync.
        self._handles: list[dist.Work] | None = None

    def update(self, values: torch.Tensor, weights: torch.Tensor):
        # In-place accumulation keeps state_dict() views of these tensors live.
        self._value += (values * weights).sum()
        self._weight += weights.sum()

        # New local data invalidates any previously reduced copies.
        self._is_synced = False

    def trigger_sync(self, dist_context: DistributedContext):
        self._synced_value = self._value.clone()
        self._synced_weight = self._weight.clone()
        self._is_synced = True

        self._handles = [
            dist.all_reduce(self._synced_value, op=dist.ReduceOp.SUM, async_op=True),
            dist.all_reduce(self._synced_weight, op=dist.ReduceOp.SUM, async_op=True)
        ]

    def wait_sync(self, dist_context: DistributedContext):
        if self._handles is None:
            raise RuntimeError("Sync was not triggered before")

        for work in self._handles:
            work.wait()
        self._handles = None

    def compute(self) -> torch.Tensor:
        if self._is_synced:
            return self._synced_value / self._synced_weight
        return self._value / self._weight

    def reset(self):
        self._value.fill_(0)
        self._weight.fill_(0)
        self._is_synced = False
        self._handles = None

    def to(self, device: str | torch.device | int):
        self._weight = self._weight.to(device)
        self._value = self._value.to(device)
        self._synced_weight = self._synced_weight.to(device)
        self._synced_value = self._synced_value.to(device)

    @property
    def accumulated_weight(self) -> torch.Tensor:
        """
        Returns the total weight accumulated so far.

        Returns:
            Scalar tensor with total weight (the reduced value if synced).
        """

        return self._synced_weight if self._is_synced else self._weight

    def state_dict(self) -> dict[str, Any]:
        return {
            "value": self._value,
            "weight": self._weight
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._value = state_dict["value"]
        self._weight = state_dict["weight"]
File without changes
@@ -0,0 +1,21 @@
1
+ from .module_reader import load_model_state
2
+ from .module_writer import (
3
+ save_model_state,
4
+ save_model_state_pipeline_parallel,
5
+ )
6
+ from .reader import read_model_state
7
+ from .writer import (
8
+ write_model_state_distributed,
9
+ write_model_state_local,
10
+ write_model_state_pipeline_parallel,
11
+ )
12
+
13
+ __all__ = [
14
+ "load_model_state",
15
+ "read_model_state",
16
+ "save_model_state",
17
+ "save_model_state_pipeline_parallel",
18
+ "write_model_state_distributed",
19
+ "write_model_state_local",
20
+ "write_model_state_pipeline_parallel"
21
+ ]