d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/internals/grad_sync/synchronizer.py
@@ -0,0 +1,257 @@
+ import dataclasses
+ from collections import defaultdict
+ from typing import cast
+
+ import torch
+ from torch import nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.tensor import DTensor, Replicate, Shard
+
+ from .bucket import AbstractGradientBucket, LocalGradientBucket, SyncGradientBucket
+
+
+ def _find_reduce_mesh(data: DTensor) -> DeviceMesh | None:
+     """
+     Identifies the sub-mesh required for gradient reduction based on tensor placements.
+
+     Args:
+         data: The parameter tensor.
+
+     Returns:
+         The DeviceMesh subset needed for reduction, or None if no reduction is needed.
+     """
+
+     reduce_dims: set[int] = set()
+
+     for dim_i, dim_placement in enumerate(data.placements):
+         match dim_placement:
+             case Replicate():
+                 reduce_dims.add(dim_i)
+             case Shard():
+                 pass
+             case _:
+                 raise ValueError(f"Unknown grad placement: {dim_placement}")
+
+     if len(reduce_dims) == 0:
+         return None
+
+     device_mesh: DeviceMesh = data.device_mesh
+
+     # we are sure that the device mesh contains dim names, so we cast(...)
+     mesh_dim_names = cast(tuple[str, ...], device_mesh.mesh_dim_names)
+     reduce_mesh = device_mesh[tuple(
+         mesh_dim_names[dim_i] for dim_i in reduce_dims
+     )]
+
+     return reduce_mesh
+
+
+ @dataclasses.dataclass(frozen=True)
+ class _ParameterGroupMarker:
+     """
+     Identifier for grouping compatible parameters into buckets.
+     """
+
+     group_i: int
+     reduce_mesh: DeviceMesh | None
+     device: torch.device
+     grad_dtype: torch.dtype | None
+
+
+ def _group_params_for_buckets(
+     param_groups: list[list[nn.Parameter]]
+ ) -> dict[_ParameterGroupMarker, list[nn.Parameter]]:
+     """
+     Sorts parameters into groups based on their synchronization requirements.
+
+     Args:
+         param_groups: List of parameter groups (from the optimizer).
+
+     Returns:
+         Dictionary mapping group markers to lists of parameters.
+     """
+
+     regrouped_params = defaultdict(list)
+     for param_group_i, param_group in enumerate(param_groups):
+         # iterate in reverse order to maximize overlap
+         for param in param_group[::-1]:
+             if not param.requires_grad:
+                 continue
+
+             if not isinstance(param.data, DTensor):
+                 raise TypeError("All params should be DTensors in a distributed setup")
+
+             reduce_mesh = _find_reduce_mesh(param.data)
+
+             group = _ParameterGroupMarker(
+                 group_i=param_group_i,
+                 reduce_mesh=reduce_mesh,
+                 device=param.data.device,
+                 grad_dtype=param.grad_dtype
+             )
+
+             regrouped_params[group].append(param)
+
+     return regrouped_params
+
+
+ def _make_bucket(
+     require_accumulations: int,
+     group_marker: _ParameterGroupMarker,
+     parameters: list[nn.Parameter],
+     communicate_stream: torch.cuda.Stream
+ ) -> AbstractGradientBucket:
+     """
+     Factory function to create the appropriate bucket type.
+     """
+
+     if group_marker.reduce_mesh is None:
+         return LocalGradientBucket(parameters)
+     else:
+         if group_marker.grad_dtype is None:
+             raise ValueError("Gradient dtype cannot be None")
+
+         return SyncGradientBucket(
+             parameters=parameters,
+             require_accumulations=require_accumulations,
+             device=group_marker.device,
+             grad_dtype=group_marker.grad_dtype,
+             reduce_mesh=group_marker.reduce_mesh,
+             communicate_stream=communicate_stream
+         )
+
+
+ def _fill_buckets(
+     param_groups: dict[_ParameterGroupMarker, list[nn.Parameter]],
+     bucket_size_mb: int,
+     require_accumulations: int,
+     communicate_stream: torch.cuda.Stream
+ ) -> list[AbstractGradientBucket]:
+     """
+     Splits grouped parameters into buckets based on size constraints.
+
+     Args:
+         param_groups: Parameters grouped by sync requirements.
+         bucket_size_mb: Max size for each bucket in megabytes.
+         require_accumulations: Number of gradient accumulations required before syncing gradients.
+         communicate_stream: CUDA stream used for asynchronous gradient reduction.
+
+     Returns:
+         List of configured gradient buckets.
+     """
+
+     # TODO: Better grouping - probably we could trace the autograd graph and use some topological clustering here
+     # TODO: to maximize overlap even better - the current implementation just iterates over parameters in reverse order
+     buckets = []
+
+     bucket_size = bucket_size_mb * 1024 * 1024
+
+     for param_group_marker, param_group in param_groups.items():
+         current_bucket_size = 0
+         unfinished_bucket: list[nn.Parameter] = []
+         for param in param_group:
+             param_bytes = param.numel() * param.element_size()
+             if current_bucket_size + param_bytes >= bucket_size and unfinished_bucket:
+                 buckets.append(_make_bucket(
+                     require_accumulations=require_accumulations,
+                     group_marker=param_group_marker,
+                     parameters=unfinished_bucket,
+                     communicate_stream=communicate_stream
+                 ))
+                 unfinished_bucket = []
+                 current_bucket_size = 0
+
+             unfinished_bucket.append(param)
+             current_bucket_size += param_bytes
+
+         if unfinished_bucket:
+             buckets.append(_make_bucket(
+                 require_accumulations=require_accumulations,
+                 group_marker=param_group_marker,
+                 parameters=unfinished_bucket,
+                 communicate_stream=communicate_stream
+             ))
+     return buckets
+
+
+ class GradientSynchronizer:
+     """
+     Manages gradient synchronization for distributed training.
+
+     This class handles the bucketing of parameters, memory allocation for flat
+     gradient buffers, and the orchestration of asynchronous all-reduce operations
+     during the backward pass.
+     """
+
+     def __init__(
+         self,
+         param_groups: list[list[nn.Parameter]],
+         bucket_size_mb: int,
+         require_accumulations: int
+     ):
+         """
+         Constructs a GradientSynchronizer.
+
+         Args:
+             param_groups: List of parameter groups.
+             bucket_size_mb: Maximal size of a single gradient bucket in MB.
+             require_accumulations: Number of micro-batches to accumulate before reducing.
+         """
+
+         self._param_groups = param_groups
+         self._bucket_size_mb = bucket_size_mb
+         self._require_accumulations = require_accumulations
+
+         self._communicate_stream: torch.cuda.Stream | None = None
+         self._can_sync: bool
+         self._buckets: list[AbstractGradientBucket] = []
+
+     def bind(self):
+         """
+         Initializes the synchronizer for training.
+
+         Groups parameters, creates buckets, allocates memory, and registers hooks.
+         Must be called before the backward pass.
+         """
+
+         stream = torch.cuda.Stream()
+         self._communicate_stream = stream
+         self._buckets = _fill_buckets(
+             _group_params_for_buckets(self._param_groups),
+             bucket_size_mb=self._bucket_size_mb,
+             require_accumulations=self._require_accumulations,
+             communicate_stream=stream
+         )
+
+         for bucket in self._buckets:
+             bucket.bind()
+
+     def unbind(self):
+         """
+         Releases resources.
+
+         Destroys buckets, frees memory buffers, and removes hooks.
+         """
+
+         for bucket in self._buckets:
+             bucket.unbind()
+
+         self._buckets = []
+         self._communicate_stream = None
+
+     def wait(self):
+         """
+         Waits for all bucket operations (async reductions) to complete.
+         """
+
+         torch.cuda.current_stream().wait_stream(self._communicate_stream)
+
+         for bucket in self._buckets:
+             bucket.mark_sync()
+
+     def zero_grad(self):
+         """
+         Resets gradients and accumulation counters for all managed parameters.
+         """
+
+         for bucket in self._buckets:
+             bucket.zero_grad()
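
For orientation, a minimal training-loop sketch around GradientSynchronizer follows. It assumes a model whose parameters are already DTensors and the bucket API shown above; the names model, optimizer, loader, and compute_loss are placeholders, not parts of the package.

synchronizer = GradientSynchronizer(
    param_groups=[list(model.parameters())],  # placeholder model with DTensor params
    bucket_size_mb=25,
    require_accumulations=4
)
synchronizer.bind()  # group params, allocate flat buffers, register grad hooks

for micro_batches in loader:
    for micro_batch in micro_batches:      # 4 micro-batches per optimizer step
        loss = compute_loss(model, micro_batch)
        loss.backward()                    # hooks launch async reductions per bucket
    synchronizer.wait()                    # current stream waits on the comm stream
    optimizer.step()
    synchronizer.zero_grad()

synchronizer.unbind()  # free buffers and remove hooks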
d9d/internals/pipeline_state/__init__.py
@@ -0,0 +1,14 @@
+ """
+ Pipeline State management package.
+
+ This package provides mechanisms to store, retrieve, and synchronize state
+ across different stages of a distributed pipeline, providing global and sharded views of these states.
+ """
+
+ from .api import PipelineState
+ from .handler import PipelineStateHandler
+
+ __all__ = [
+     "PipelineState",
+     "PipelineStateHandler"
+ ]
d9d/internals/pipeline_state/api.py
@@ -0,0 +1,45 @@
+ import abc
+ from typing import Any
+
+
+ class PipelineState(abc.ABC):
+     """
+     Object representing the state of a pipeline.
+
+     This class defines a dictionary-like interface for accessing state variables,
+     abstracting away whether the underlying storage is local, sharded, or global.
+     """
+
+     @abc.abstractmethod
+     def __setitem__(self, key: str, value: Any):
+         """
+         Sets a state value for a given key.
+
+         Args:
+             key: The identifier for the state variable.
+             value: The value to store.
+         """
+
+     @abc.abstractmethod
+     def __getitem__(self, item: str) -> Any:
+         """
+         Retrieves a state value for a given key.
+
+         Args:
+             item: The identifier for the state variable.
+
+         Returns:
+             The value associated with the key.
+         """
+
+     @abc.abstractmethod
+     def __contains__(self, item: str) -> bool:
+         """
+         Checks if a key exists in the state.
+
+         Args:
+             item: The identifier to check.
+
+         Returns:
+             True if the key exists, False otherwise.
+         """
d9d/internals/pipeline_state/handler.py
@@ -0,0 +1,111 @@
+ from typing import Any
+
+ from d9d.core.sharding import ShardingSpecLeaf
+
+ from .api import PipelineState
+ from .storage import PipelineStateStorage
+
+
+ class PipelineStateGlobal(PipelineState):
+     """
+     Represents the global (unsharded) view of the pipeline state.
+     """
+
+     def __init__(self, storage: PipelineStateStorage):
+         """
+         Constructs a PipelineStateGlobal object.
+
+         Args:
+             storage: The underlying storage backend.
+         """
+
+         self._storage = storage
+
+     def __setitem__(self, key: str, value: Any):
+         self._storage.store_global((key,), value)
+
+     def __getitem__(self, item: str) -> Any:
+         return self._storage.acquire_global((item,))
+
+     def __contains__(self, item: str) -> bool:
+         return self._storage.contains((item,))
+
+
+ class PipelineStateShard(PipelineState):
+     """
+     Represents a sharded view of the pipeline state for a specific shard ID.
+     """
+
+     def __init__(self, storage: PipelineStateStorage, current_shard: int):
+         """
+         Constructs a PipelineStateShard object.
+
+         Args:
+             storage: The underlying storage backend.
+             current_shard: The index of the partial shard this view represents.
+         """
+
+         self._storage = storage
+         self._current_shard = current_shard
+
+     def __setitem__(self, key: str, value: Any):
+         self._storage.store_shard((key,), value, self._current_shard)
+
+     def __getitem__(self, item: str) -> Any:
+         return self._storage.acquire_shard((item,), self._current_shard)
+
+     def __contains__(self, item: str) -> bool:
+         return self._storage.contains((item,))
+
+
+ class PipelineStateHandler:
+     """
+     Manages the lifecycle and access patterns of pipeline states.
+
+     This handler initializes the underlying storage and provides specific views
+     (global or sharded) into that storage.
+     """
+
+     def __init__(self, sharding_spec: dict[str, ShardingSpecLeaf], num_shards: int):
+         """
+         Constructs a PipelineStateHandler object.
+
+         Args:
+             sharding_spec: A definition of how specific keys should be sharded.
+             num_shards: The total number of shards in the pipeline.
+         """
+
+         self._storage = PipelineStateStorage(
+             sharding_spec={(k,): v for k, v in sharding_spec.items()},
+             num_shards=num_shards
+         )
+
+     def global_state(self) -> PipelineState:
+         """
+         Returns a view interface for accessing global state.
+
+         Returns:
+             A PipelineState interface that accesses the full, aggregated data.
+         """
+
+         return PipelineStateGlobal(self._storage)
+
+     def sharded_state(self, shard_id: int) -> PipelineState:
+         """
+         Returns a view interface for accessing state specific to a shard ID.
+
+         Args:
+             shard_id: The index of the shard to access.
+
+         Returns:
+             A PipelineState interface that accesses partial data for the given shard.
+         """
+
+         return PipelineStateShard(self._storage, shard_id)
+
+     def reset(self):
+         """
+         Resets the underlying storage, clearing all state.
+         """
+
+         self._storage.reset()
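
A brief usage sketch of the handler, assuming two pipeline shards; the key names and the SpecShard(dim=0) spec are illustrative choices, and the exact aggregation semantics come from d9d.core.sharding.

import torch

from d9d.core.sharding import SpecShard

handler = PipelineStateHandler(
    sharding_spec={"logits": SpecShard(dim=0)},  # illustrative key and spec
    num_shards=2
)

# each pipeline shard records its partial result under the same key
handler.sharded_state(0)["logits"] = torch.zeros(4, 8)
handler.sharded_state(1)["logits"] = torch.ones(4, 8)

# the global view aggregates the shards according to the sharding spec
full_logits = handler.global_state()["logits"]

# keys without an explicit spec fall back to the storage layer's guessed spec
handler.global_state()["step"] = 10
assert "step" in handler.sharded_state(0)

handler.reset()  # clear everything before the next pipeline iteration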
d9d/internals/pipeline_state/storage.py
@@ -0,0 +1,236 @@
+ import copy
+ from collections import UserDict
+ from typing import Any, TypeVar, cast
+
+ import torch
+ import torch.utils._pytree as pytree  # noqa: PLC2701
+
+ from d9d.core.sharding import ShardingSpecLeaf, SpecReplicate, SpecShard, shard_tree, unshard_tree
+
+ StateKey = tuple[str, ...]
+
+
+ TMap = TypeVar("TMap")
+
+
+ def _detach_leaf(x: TMap) -> TMap:
+     """
+     Detaches a tensor from the computation graph if the input is a tensor.
+
+     Args:
+         x: The input object.
+
+     Returns:
+         The detached tensor or the original object.
+     """
+
+     if isinstance(x, torch.Tensor):
+         return cast(TMap, x.detach())
+     return x
+
+
+ class ShardedState(UserDict):
+     """
+     Container for holding state broken down by shard indices.
+     """
+
+
+ class PipelineStateStorage:
+     """
+     Low-level storage backend handling sharding and aggregation of state data.
+
+     This class manages the transition between sharded data
+     and global data. It uses sharding specifications to determine
+     how to split or join data.
+     """
+
+     def __init__(
+         self,
+         sharding_spec: dict[StateKey, ShardingSpecLeaf],
+         num_shards: int
+     ):
+         """
+         Constructs a PipelineStateStorage object.
+
+         Args:
+             sharding_spec: Dictionary mapping state keys to their sharding behaviors.
+             num_shards: Total number of shards involved in the storage.
+         """
+
+         self._sharding_spec_orig = copy.deepcopy(sharding_spec)
+
+         self._state: dict[StateKey, Any] = {}
+         self._state_sharding_spec: dict[StateKey, ShardingSpecLeaf] = {}
+
+         self._num_shards = num_shards
+
+     def _guess_sharding_spec_for_shard(self, key: StateKey, shard: Any) -> ShardingSpecLeaf:
+         # Stack if scalar (tensor or item), cat otherwise
+
+         if key in self._sharding_spec_orig:
+             return self._sharding_spec_orig[key]
+
+         if isinstance(shard, torch.Tensor):
+             do_stack = shard.ndim == 0
+             return SpecShard(dim=0, do_stack=do_stack)
+         elif isinstance(shard, list):
+             return SpecShard(dim=0)
+         else:
+             return SpecShard(dim=0, do_stack=True)
+
+     def _guess_sharding_spec_for_global(self, key: StateKey, state: Any) -> ShardingSpecLeaf:
+         if key in self._sharding_spec_orig:
+             return self._sharding_spec_orig[key]
+
+         if isinstance(state, torch.Tensor):
+             if state.ndim == 0:
+                 return SpecReplicate()
+             else:
+                 return SpecShard(dim=0)
+         elif isinstance(state, list):
+             return SpecShard(dim=0)
+         else:
+             return SpecReplicate()
+
+     def store_global(self, key: StateKey, state: Any):
+         """
+         Stores a value in the global scope.
+
+         If the key does not have a sharding spec, one will be inferred. Detaches tensors.
+
+         Args:
+             key: The identifier key.
+             state: The unified value to store.
+         """
+
+         state = pytree.tree_map(_detach_leaf, state)
+
+         if key not in self._state_sharding_spec:
+             self._state_sharding_spec[key] = self._guess_sharding_spec_for_global(key, state)
+
+         self._state[key] = state
+
+     def store_shard(self, key: StateKey, state: Any, shard_id: int):
+         """
+         Stores a value for a specific shard index.
+
+         Raises an error if attempting to store a shard into a key that already holds global state.
+
+         Args:
+             key: The identifier key.
+             state: The partial value for the shard.
+             shard_id: The index of the shard.
+
+         Raises:
+             ValueError: If trying to store sharded state into an unsharded container.
+         """
+
+         if key not in self._state:
+             self._state[key] = ShardedState()
+
+         container = self._state[key]
+
+         if not isinstance(container, ShardedState):
+             raise ValueError(f"Trying to store sharded state into an unsharded one: {key}")
+
+         state = pytree.tree_map(_detach_leaf, state)
+
+         # dynamically populate the sharding spec to know whether it is stacking or not
+         if key not in self._state_sharding_spec:
+             self._state_sharding_spec[key] = self._guess_sharding_spec_for_shard(key, state)
+
+         self._state[key][shard_id] = state
+
+     def _ensure_global(self, key: StateKey):
+         if key not in self._state:
+             raise ValueError(f"Cannot access non-existing state {key}")
+
+         state = self._state[key]
+
+         if not isinstance(state, ShardedState):
+             return
+
+         # here we know we are in ShardedState
+
+         shards = [state[shard_id] for shard_id in range(self._num_shards)]
+         resharded = unshard_tree(shards, self._state_sharding_spec[key])
+
+         self._state[key] = resharded
+
+     def _ensure_sharded(self, key: StateKey):
+         if key not in self._state:
+             raise ValueError(f"Cannot access non-existing state {key}")
+
+         state = self._state[key]
+
+         if isinstance(state, ShardedState):
+             return
+
+         # here we know we are in global state
+
+         sharded = shard_tree(
+             state,
+             self._state_sharding_spec[key],
+             num_shards=self._num_shards,
+             enforce_even_split=True
+         )
+
+         sharded_state = ShardedState({
+             shard_idx: shard for shard_idx, shard in enumerate(sharded)
+         })
+
+         self._state[key] = sharded_state
+
+     def acquire_global(self, key: StateKey) -> Any:
+         """
+         Retrieves data for a key in its global form.
+
+         Args:
+             key: The state key.
+
+         Returns:
+             The aggregated global data.
+         """
+
+         self._ensure_global(key)
+         return self._state[key]
+
+     def acquire_shard(self, key: StateKey, shard: int) -> Any:
+         """
+         Retrieves data for a key specific to a shard index.
+
+         Args:
+             key: The state key.
+             shard: The shard index.
+
+         Returns:
+             The data slice corresponding to the shard.
+         """
+
+         self._ensure_sharded(key)
+         state = self._state[key]
+
+         if isinstance(state, ShardedState):
+             return state[shard]
+         else:
+             return state
+
+     def contains(self, key: StateKey) -> bool:
+         """
+         Checks if a key exists in storage.
+
+         Args:
+             key: The state key.
+
+         Returns:
+             True if present.
+         """
+
+         return key in self._state
+
+     def reset(self):
+         """
+         Clears all stored state.
+         """
+
+         self._state.clear()
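
The lazy conversion between the sharded and global forms can also be exercised on the storage object directly. The expected results below assume that unshard_tree stacks scalars and concatenates along the sharded dim, and that shard_tree splits evenly, as the guessed specs above suggest; the key names are illustrative.

import torch

storage = PipelineStateStorage(sharding_spec={}, num_shards=2)

# per-shard scalars: the guessed spec stacks them once the global view is requested
storage.store_shard(("loss",), torch.tensor(0.5), shard_id=0)
storage.store_shard(("loss",), torch.tensor(0.7), shard_id=1)
losses = storage.acquire_global(("loss",))                    # expected shape: (2,)

# a globally stored tensor is split along dim 0 on the first sharded access
storage.store_global(("token_ids",), torch.arange(8))
first_half = storage.acquire_shard(("token_ids",), shard=0)   # expected: elements 0..3

storage.reset()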
d9d/internals/profiling/__init__.py
@@ -0,0 +1,7 @@
+ """Exposes the internal distributed profiler."""
+
+ from .profile import Profiler
+
+ __all__ = [
+     "Profiler"
+ ]