d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0

d9d/core/dist_ops/object.py
@@ -0,0 +1,68 @@
+ from typing import TypeVar, cast
+
+ import torch.distributed as dist
+
+ T = TypeVar("T")
+
+
+ def gather_object(
+     obj: T,
+     group: dist.ProcessGroup,
+     group_dst: int
+ ) -> list[T] | None:
+     """
+     Gathers picklable objects from the whole process group to a specific destination rank.
+
+     This acts as a wrapper around torch.distributed.gather_object that automatically
+     initializes the output buffer list on the destination rank.
+
+     Args:
+         obj: The local object to send. Must be picklable.
+         group: The process group to work on.
+         group_dst: The rank within the group that will receive the objects.
+
+     Returns:
+         A list of objects from all ranks on the destination rank; None on other ranks.
+     """
+
+     if group.rank() == group_dst:
+         # We initialize with None, but we cast to list[T] because we know
+         # dist.gather_object will populate these slots with actual objects.
+         save_list = cast(list[T], [None for _ in range(group.size())])
+     else:
+         save_list = None
+     dist.gather_object(
+         obj,
+         save_list,
+         group=group,
+         group_dst=group_dst
+     )
+     return save_list
+
+
+ def all_gather_object(
+     obj: T,
+     group: dist.ProcessGroup
+ ) -> list[T]:
+     """
+     Gathers picklable objects from the whole process group to all ranks.
+
+     This acts as a wrapper around torch.distributed.all_gather_object that automatically
+     initializes the output buffer list on all ranks.
+
+     Args:
+         obj: The local object to send. Must be picklable.
+         group: The process group to work on.
+
+     Returns:
+         A list of objects containing the data gathered from all ranks.
+     """
+     # We initialize with None, but we cast to list[T] because we know
+     # dist.all_gather_object will populate these slots with actual objects.
+     save_list = cast(list[T], [None for _ in range(group.size())])
+     dist.all_gather_object(
+         save_list,
+         obj,
+         group=group
+     )
+     return save_list
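
The two wrappers above mirror the underlying torch.distributed object collectives while hiding the output-buffer bookkeeping. A minimal usage sketch follows; it is not part of the package contents and assumes several processes launched via torchrun with a reachable backend, with the payload fields purely illustrative:

    import torch.distributed as dist

    from d9d.core.dist_ops.object import all_gather_object, gather_object

    dist.init_process_group(backend="gloo")
    group = dist.group.WORLD  # default process group after initialization

    local_stats = {"rank": group.rank(), "samples_seen": 128}

    # Every rank receives the full list of per-rank dicts.
    all_stats = all_gather_object(local_stats, group=group)

    # Only the destination rank receives the list; other ranks get None.
    maybe_stats = gather_object(local_stats, group=group, group_dst=0)
    if maybe_stats is not None:
        print(f"rank 0 collected {len(maybe_stats)} objects")

    dist.destroy_process_group()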

d9d/core/dist_ops/tensor.py
@@ -0,0 +1,192 @@
+ from collections.abc import Sequence
+ from typing import cast
+
+ import torch
+ import torch.distributed as dist
+
+
+ def gather(
+     tensor: torch.Tensor,
+     group: dist.ProcessGroup,
+     group_dst: int,
+     async_op: bool = False
+ ) -> list[torch.Tensor] | tuple[list[torch.Tensor] | None, dist.Work] | None:
+     """
+     Gathers tensors from the process group to a specific destination rank.
+
+     This function assumes that tensors on all ranks have the same shape and dtype
+     as the tensor on the current rank. It automatically allocates the output
+     buffer list on the destination.
+
+     Args:
+         tensor: The local tensor to send.
+         group: The process group to work on.
+         group_dst: The rank within the group that will receive the tensors.
+         async_op: Whether the operation should be asynchronous.
+
+     Returns:
+         If async_op is False: A list of tensors on the destination rank, None elsewhere.
+         If async_op is True: A tuple containing (buffer_list, work_handle).
+     """
+
+     if group.rank() == group_dst:
+         save_list = [torch.empty_like(tensor) for _ in range(group.size())]
+     else:
+         save_list = None
+
+     work = dist.gather(
+         tensor,
+         save_list,
+         group=group,
+         group_dst=group_dst,
+         async_op=async_op
+     )
+
+     if async_op:
+         return save_list, work
+     else:
+         return save_list
+
+
+ def all_gather(
+     tensor: torch.Tensor,
+     group: dist.ProcessGroup,
+     async_op: bool = False
+ ) -> list[torch.Tensor] | tuple[list[torch.Tensor], dist.Work]:
+     """
+     Gathers tensors from the whole process group to all ranks.
+
+     This function assumes that tensors on all ranks have the same shape and dtype
+     as the tensor on the current rank. It automatically allocates the output
+     buffer list.
+
+     Args:
+         tensor: The local tensor to send.
+         group: The process group to work on.
+         async_op: Whether the operation should be asynchronous.
+
+     Returns:
+         If async_op is False: A list of gathered tensors.
+         If async_op is True: A tuple containing (buffer_list, work_handle).
+     """
+
+     save_list = [torch.empty_like(tensor) for _ in range(group.size())]
+     work = dist.all_gather(
+         save_list,
+         tensor,
+         group=group,
+         async_op=async_op
+     )
+     if async_op:
+         return save_list, work
+     else:
+         return save_list
+
+
+ def _all_gather_shapes(
+     tensor: torch.Tensor,
+     group: dist.ProcessGroup,
+ ) -> Sequence[torch.Tensor]:
+     all_ndim = [torch.empty((), dtype=torch.long, device=tensor.device) for _ in range(group.size())]
+     all_ndim_wait = dist.all_gather(
+         all_ndim,
+         torch.tensor(tensor.ndim, dtype=torch.long, device=tensor.device),
+         group=group,
+         async_op=True
+     )
+     all_ndim_wait.wait()
+
+     all_shape = [torch.empty(cast(int, ndim.item()), dtype=torch.long, device=tensor.device) for ndim in all_ndim]
+     all_shape_wait = dist.all_gather(
+         all_shape,
+         torch.tensor(tensor.shape, dtype=torch.long, device=tensor.device),
+         group=group,
+         async_op=True
+     )
+     all_shape_wait.wait()
+
+     return all_shape
+
+
+ def all_gather_variadic_shape(
+     tensor: torch.Tensor,
+     group: dist.ProcessGroup,
+     async_op: bool = False
+ ) -> list[torch.Tensor] | tuple[list[torch.Tensor], dist.Work]:
+     """
+     Gathers tensors of different shapes from the whole process group to all ranks.
+
+     Unlike standard all_gather, this function first communicates the shape of the
+     tensor on every rank, allowing for dynamic sizing.
+
+     Args:
+         tensor: The local tensor to send.
+         group: The process group to work on.
+         async_op: Whether the final data gathering should be asynchronous.
+             Note that shape gathering is always synchronous.
+
+     Returns:
+         If async_op is False: A list of gathered tensors of varying shapes.
+         If async_op is True: A tuple containing (buffer_list, work_handle).
+     """
+
+     all_shape = _all_gather_shapes(tensor, group)
+
+     all_result = [torch.empty(tuple(shape), dtype=tensor.dtype, device=tensor.device) for shape in all_shape]
+     all_result_wait = dist.all_gather(
+         all_result,
+         tensor,
+         group=group,
+         async_op=async_op
+     )
+     if async_op:
+         return all_result, all_result_wait
+     else:
+         return all_result
+
+
+ def gather_variadic_shape(
+     tensor: torch.Tensor,
+     group: dist.ProcessGroup,
+     group_dst: int
+ ) -> list[torch.Tensor] | None:
+     """
+     Gathers tensors of different shapes from the process group to a specific rank.
+
+     This function coordinates shape exchange and uses point-to-point communication
+     (isend/irecv) to gather tensors that may differ in shape across ranks.
+
+     Currently does not support async_op.
+
+     Args:
+         tensor: The local tensor to send.
+         group: The process group to work on.
+         group_dst: The rank within the group that will receive the tensors.
+
+     Returns:
+         A list of tensors of varying shapes on the destination rank; None on other ranks.
+     """
+
+     is_current_dst = group.rank() == group_dst
+
+     all_shape = _all_gather_shapes(tensor, group)
+
+     if is_current_dst:
+         all_recv_futures: list[dist.Work] = []
+         all_result: list[torch.Tensor] = cast(list[torch.Tensor], [None for _ in range(group.size())])
+         for group_src_i in range(group.size()):
+             if group_src_i == group_dst:
+                 all_result[group_src_i] = tensor
+                 continue
+             all_result[group_src_i] = torch.empty(
+                 tuple(all_shape[group_src_i]), dtype=tensor.dtype, device=tensor.device
+             )
+             all_recv_future = dist.irecv(all_result[group_src_i], group=group, group_src=group_src_i)
+             all_recv_future = cast(dist.Work, all_recv_future)  # we know we are on dst rank
+             all_recv_futures.append(all_recv_future)
+         for recv_future in all_recv_futures:
+             recv_future.wait()
+         return all_result
+     else:
+         dist.isend(tensor=tensor, group=group, group_dst=group_dst)
+         return None
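
The variadic-shape helpers above first exchange per-rank shapes and then move the data. A short sketch, not part of the package contents, assuming an already initialized process group and CPU tensors on the gloo backend (the per-rank sizes are made up for illustration):

    import torch
    import torch.distributed as dist

    from d9d.core.dist_ops.tensor import all_gather_variadic_shape, gather_variadic_shape

    group = dist.group.WORLD  # assumes dist.init_process_group(...) has already run

    # Each rank contributes a different number of rows; shapes are exchanged first.
    local = torch.randn(group.rank() + 1, 4)

    gathered = all_gather_variadic_shape(local, group=group)         # list of tensors on every rank
    on_dst = gather_variadic_shape(local, group=group, group_dst=0)  # list on rank 0, None elsewhere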

d9d/core/protocol/__init__.py
@@ -0,0 +1,8 @@
+ """Package providing protocol definitions for standard PyTorch objects."""
+
+ from .training import LRSchedulerProtocol, OptimizerProtocol
+
+ __all__ = [
+     "LRSchedulerProtocol",
+     "OptimizerProtocol"
+ ]

d9d/core/protocol/training.py
@@ -0,0 +1,38 @@
+ from typing import Protocol, runtime_checkable
+
+ from torch.distributed.checkpoint.stateful import Stateful
+
+
+ @runtime_checkable
+ class OptimizerProtocol(Stateful, Protocol):
+     """
+     Protocol defining an interface for a standard PyTorch Optimizer object.
+
+     This protocol ensures that the wrapped optimizer supports the standard
+     API and state checkpointing via the Stateful interface.
+     """
+
+     def step(self):
+         """Performs a single optimization step."""
+
+         ...
+
+     def zero_grad(self):
+         """Sets the gradients of all optimized tensors to zero."""
+
+         ...
+
+
+ @runtime_checkable
+ class LRSchedulerProtocol(Stateful, Protocol):
+     """
+     Protocol defining an interface for a Learning Rate Scheduler.
+
+     This protocol ensures that the wrapped scheduler supports stepping
+     and state checkpointing via the Stateful interface.
+     """
+
+     def step(self):
+         """Performs a single learning rate scheduling step."""
+
+         ...
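
Because both protocols are runtime_checkable, stock PyTorch objects can be validated structurally: isinstance only checks that the required methods exist. A small sketch, not part of the package contents:

    import torch

    from d9d.core.protocol import LRSchedulerProtocol, OptimizerProtocol

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

    # Both expose step/state_dict/load_state_dict (and zero_grad for the optimizer),
    # so the structural checks should pass.
    assert isinstance(optimizer, OptimizerProtocol)
    assert isinstance(scheduler, LRSchedulerProtocol)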

d9d/core/sharding/__init__.py
@@ -0,0 +1,15 @@
+ from .auto_spec import shard_spec_nothing, shard_spec_on_dim
+ from .shard import shard_tree
+ from .spec import ShardingSpec, ShardingSpecLeaf, SpecReplicate, SpecShard
+ from .unshard import unshard_tree
+
+ __all__ = [
+     "ShardingSpec",
+     "ShardingSpecLeaf",
+     "SpecReplicate",
+     "SpecShard",
+     "shard_spec_nothing",
+     "shard_spec_on_dim",
+     "shard_tree",
+     "unshard_tree"
+ ]

d9d/core/sharding/auto_spec.py
@@ -0,0 +1,66 @@
+ from typing import Any
+
+ import torch
+ import torch.utils._pytree as pytree  # noqa: PLC2701
+
+ from d9d.core.types import PyTree
+
+ from .spec import ShardingSpec, ShardingSpecLeaf, SpecReplicate, SpecShard
+
+
+ def _tree_item_to_shard(item: Any, shard_on_dim: int) -> ShardingSpecLeaf:
+     if isinstance(item, list):
+         if shard_on_dim != 0:
+             raise ValueError(f"Cannot shard list on dim {shard_on_dim}. Lists behave as 1D sequences.")
+         return SpecShard(0)
+     elif isinstance(item, torch.Tensor):
+         if item.ndim == 0:
+             return SpecReplicate()
+         if item.ndim <= shard_on_dim:
+             raise ValueError(f"Cannot shard {item.ndim}-dimensional tensor on dim {shard_on_dim}")
+         return SpecShard(shard_on_dim)
+     else:
+         return SpecReplicate()
+
+
+ def shard_spec_on_dim(tree: PyTree[Any], dim: int) -> ShardingSpec:
+     """
+     Creates a sharding specification to split all tensors in the tree on a specific dimension.
+
+     Iterates over the input tree:
+     * If a leaf is a Tensor with enough dimensions, it is mapped to a SpecShard(dim) object.
+     * If a leaf is a list, it is mapped to a SpecShard(0) object (only dim=0 is allowed for lists).
+     * Other types and 0-dim tensors are mapped to SpecReplicate.
+
+     Args:
+         tree: The input PyTree structure.
+         dim: The dimension index to shard eligible tensors on.
+
+     Returns:
+         A new PyTree matching the input structure, containing SpecShard or SpecReplicate objects.
+
+     Raises:
+         ValueError: If a non-scalar tensor has rank less than or equal to 'dim', or if the tree contains a list and 'dim' is not 0.
+     """
+
+     return pytree.tree_map(
+         lambda x: _tree_item_to_shard(x, dim),
+         tree,
+         is_leaf=lambda x: isinstance(x, (torch.Tensor, list))
+     )
+
+
+ def shard_spec_nothing(tree: PyTree[Any]) -> ShardingSpec:
+     """
+     Creates a sharding specification where no sharding is performed.
+
+     This effectively clones the tree structure but replaces every leaf with SpecReplicate.
+
+     Args:
+         tree: The input PyTree structure.
+
+     Returns:
+         A new PyTree matching the input structure, containing strictly SpecReplicate for all leaves.
+     """
+
+     return pytree.tree_map(lambda _: SpecReplicate(), tree, is_leaf=lambda x: isinstance(x, (torch.Tensor, list)))
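
A small sketch of spec construction, not part of the package contents; the batch keys are illustrative:

    import torch

    from d9d.core.sharding import SpecReplicate, SpecShard, shard_spec_nothing, shard_spec_on_dim

    batch = {
        "input_ids": torch.zeros(8, 128, dtype=torch.long),  # tensor -> SpecShard(0)
        "lengths": [7, 12, 3, 9, 20, 5, 8, 11],              # list -> SpecShard(0)
        "pad_token_id": 0,                                    # plain leaf -> SpecReplicate()
    }

    spec = shard_spec_on_dim(batch, dim=0)
    assert spec["input_ids"] == SpecShard(0)
    assert spec["pad_token_id"] == SpecReplicate()

    replicate_all = shard_spec_nothing(batch)  # every leaf becomes SpecReplicate()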

d9d/core/sharding/shard.py
@@ -0,0 +1,154 @@
+ from collections.abc import Sequence
+ from typing import TypeVar, cast
+
+ import torch
+ import torch.utils._pytree as pytree  # noqa: PLC2701
+
+ from d9d.core.types import PyTree
+
+ from .spec import ShardingSpec, SpecReplicate, SpecShard
+
+ TLeaf = TypeVar("TLeaf")
+ TSameTree = TypeVar("TSameTree", bound=PyTree)
+
+
+ def _shard_list(
+     item: list[TLeaf],
+     spec: SpecShard,
+     num_shards: int,
+     enforce_even_split: bool
+ ) -> Sequence[list[TLeaf] | TLeaf]:
+     if spec.dim != 0:
+         raise ValueError(f"Lists can only be sharded on dim 0, got {spec.dim}")
+
+     if spec.do_stack:
+         if len(item) != num_shards:
+             raise ValueError(
+                 f"do_stack=True requires list length ({len(item)}) to match num_shards ({num_shards})"
+             )
+         return item
+
+     if enforce_even_split and len(item) % num_shards != 0:
+         raise ValueError(
+             f"Tried to shard a list with length {len(item)} "
+             f"to {num_shards} shards, but the length is not perfectly divisible."
+         )
+
+     shard_size, shard_extra = divmod(len(item), num_shards)
+     return [
+         item[
+             shard_id * shard_size + min(shard_id, shard_extra):
+             (shard_id + 1) * shard_size + min(shard_id + 1, shard_extra)
+         ]
+         for shard_id in range(num_shards)
+     ]
+
+
+ def _shard_tensor(
+     item: torch.Tensor,
+     spec: SpecShard,
+     num_shards: int,
+     enforce_even_split: bool
+ ) -> Sequence[torch.Tensor]:
+     if item.ndim == 0:
+         raise ValueError("Found a 0-dim Tensor for sharding")
+
+     if spec.do_stack:
+         if item.shape[spec.dim] != num_shards:
+             raise ValueError(
+                 f"do_stack=True requires tensor shape[{spec.dim}] ({item.shape[spec.dim]}) "
+                 f"to match num_shards ({num_shards})"
+             )
+         return torch.unbind(item, dim=spec.dim)
+
+     if enforce_even_split and item.shape[spec.dim] % num_shards != 0:
+         raise ValueError(
+             f"Tried to shard a tensor with shape {item.shape} on dim {spec.dim} "
+             f"to {num_shards} shards, but the dimension is not perfectly divisible."
+         )
+
+     return torch.tensor_split(item, sections=num_shards, dim=spec.dim)
+
+
+ def _shard_leaf_to_list(
+     item: TLeaf,
+     spec: SpecShard | SpecReplicate,
+     num_shards: int,
+     enforce_even_split: bool
+ ) -> Sequence[TLeaf]:
+     """Helper to split an item into a list of items for each rank."""
+     if isinstance(spec, SpecReplicate):
+         # Replicated: every shard receives the same object reference
+         return [item] * num_shards
+
+     if not isinstance(spec, SpecShard):
+         raise TypeError(f"Unknown sharding spec object type: {type(spec)}")
+
+     if isinstance(item, torch.Tensor):
+         return cast(Sequence[TLeaf], _shard_tensor(
+             item=item,
+             num_shards=num_shards,
+             enforce_even_split=enforce_even_split,
+             spec=spec
+         ))
+     elif isinstance(item, list):
+         return cast(Sequence[TLeaf], _shard_list(
+             item=item,
+             num_shards=num_shards,
+             enforce_even_split=enforce_even_split,
+             spec=spec
+         ))
+     else:
+         raise TypeError(
+             f"Sharding spec found a SpecShard object, but the item was not a Tensor and not a list (got {type(item)})"
+         )
+
+
+ def shard_tree(
+     tree: TSameTree,
+     sharding_spec: ShardingSpec,
+     num_shards: int,
+     enforce_even_split: bool
+ ) -> tuple[TSameTree, ...]:
+     """
+     Shards a PyTree into a tuple of PyTrees, one for each shard rank.
+
+     This function takes a single global data structure and splits it into `num_shards`
+     structures.
+
+     * If a spec leaf is a ``SpecShard(dim)``, the tensor or list is split along that dimension,
+       and the ``i``-th slice goes to the ``i``-th output tree.
+     * If a spec leaf is ``SpecReplicate``, the item is replicated (reference copy) to all
+       output trees.
+
+     Args:
+         tree: The structure containing tensors to be sharded.
+         sharding_spec: A structure matching 'tree' containing ``SpecShard`` or ``SpecReplicate`` objects.
+         num_shards: The total number of shards to split the tensors into.
+         enforce_even_split: If True, raises a ValueError if a tensor's dimension
+             size is not perfectly divisible by ``num_shards``.
+
+     Returns:
+         A tuple of length ``num_shards``. Each element is a PyTree matching
+         the structure of the input ``tree``, containing the local data for
+         that specific rank.
+
+     Raises:
+         ValueError: If tree structures do not match, or the sharding constraints
+             are not met.
+     """
+     flat_spec, spec_struct = pytree.tree_flatten(sharding_spec)
+
+     try:
+         flat_tree = spec_struct.flatten_up_to(tree)
+     except (ValueError, TypeError) as e:
+         raise ValueError("Tree structure does not match sharding spec") from e
+
+     sharded_leaves_per_node = [
+         _shard_leaf_to_list(item, spec, num_shards, enforce_even_split)
+         for item, spec in zip(flat_tree, flat_spec, strict=True)
+     ]
+
+     rank_leaves = list(zip(*sharded_leaves_per_node, strict=True))
+
+     return tuple(spec_struct.unflatten(leaves) for leaves in rank_leaves)
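
A minimal shard_tree sketch, not part of the package contents, splitting a toy batch into four per-rank trees:

    import torch

    from d9d.core.sharding import shard_spec_on_dim, shard_tree

    tree = {
        "input_ids": torch.arange(32).reshape(8, 4),  # sharded on dim 0
        "temperature": 0.7,                           # replicated by reference
    }
    spec = shard_spec_on_dim(tree, dim=0)

    shards = shard_tree(tree, spec, num_shards=4, enforce_even_split=True)
    assert len(shards) == 4
    assert shards[0]["input_ids"].shape == (2, 4)
    assert shards[3]["temperature"] == 0.7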

d9d/core/sharding/spec.py
@@ -0,0 +1,28 @@
+ import dataclasses
+
+ from d9d.core.types import PyTree
+
+
+ @dataclasses.dataclass(slots=True, frozen=True)
+ class SpecReplicate:
+     """
+     Specifies that a leaf node should be replicated across all shards.
+     """
+
+
+ @dataclasses.dataclass(slots=True, frozen=True)
+ class SpecShard:
+     """
+     Specifies that a leaf node should be split along a specific dimension.
+
+     Attributes:
+         dim: The dimension to split.
+         do_stack: If True, the sharded dimension is squeezed away when sharding (its size must equal the number of shards) and stacked back when unsharding.
+     """
+
+     dim: int
+     do_stack: bool = False
+
+
+ ShardingSpecLeaf = SpecReplicate | SpecShard
+ ShardingSpec = PyTree[ShardingSpecLeaf]
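
The do_stack flag treats the sharded dimension as the shard index itself: sharding squeezes it away and unsharding stacks it back. A short sketch, not part of the package contents:

    import torch

    from d9d.core.sharding import SpecShard, shard_tree, unshard_tree

    tree = {"per_rank_loss": torch.tensor([0.1, 0.2, 0.3, 0.4])}  # length == num_shards
    spec = {"per_rank_loss": SpecShard(dim=0, do_stack=True)}

    shards = shard_tree(tree, spec, num_shards=4, enforce_even_split=True)
    assert shards[2]["per_rank_loss"].ndim == 0  # unbind removed the stacked dim

    merged = unshard_tree(shards, spec)
    assert torch.equal(merged["per_rank_loss"], tree["per_rank_loss"])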

d9d/core/sharding/unshard.py
@@ -0,0 +1,117 @@
+ from collections.abc import Sequence
+ from typing import TypeVar, cast
+
+ import torch
+ import torch.utils._pytree as pytree  # noqa: PLC2701
+
+ from d9d.core.types import PyTree
+
+ from .spec import ShardingSpec, ShardingSpecLeaf, SpecReplicate, SpecShard
+
+ TLeaf = TypeVar("TLeaf")
+ TSameTree = TypeVar("TSameTree", bound=PyTree)
+
+
+ def _unshard_list(
+     group: Sequence[list[TLeaf] | TLeaf],
+     spec: SpecShard
+ ) -> list[TLeaf]:
+     if spec.dim != 0:
+         raise ValueError(f"Lists can only be unsharded on dim 0, got {spec.dim}")
+
+     if spec.do_stack:
+         return cast(list[TLeaf], list(group))
+
+     merged_list: list[TLeaf] = []
+     for x in group:
+         merged_list.extend(cast(list[TLeaf], x))
+     return merged_list
+
+
+ def _unshard_tensor(
+     group: list[torch.Tensor],
+     spec: SpecShard
+ ) -> torch.Tensor:
+     if spec.do_stack:
+         return torch.stack(group, dim=spec.dim)
+
+     return torch.cat(group, dim=spec.dim)
+
+
+ def _unshard_leaf_from_group(
+     group: Sequence[TLeaf],
+     spec: ShardingSpecLeaf
+ ) -> TLeaf:
+     """Helper to merge a group of items from different ranks into one."""
+     if isinstance(spec, SpecReplicate):
+         return group[0]
+
+     if not isinstance(spec, SpecShard):
+         raise TypeError(f"Unknown sharding spec object type: {type(spec)}")
+
+     first_item = group[0]
+
+     if isinstance(first_item, torch.Tensor):
+         return cast(TLeaf, _unshard_tensor(
+             cast(list[torch.Tensor], group),
+             spec
+         ))
+     elif spec.do_stack or isinstance(first_item, list):
+         return cast(TLeaf, _unshard_list(group, spec))
+     else:
+         raise TypeError(f"Expected Tensor or list instances, got {type(group[0])}")
+
+
+ def unshard_tree(
+     sharded_trees: Sequence[TSameTree],
+     sharding_spec: ShardingSpec
+ ) -> TSameTree:
+     """
+     Combines a sequence of PyTrees (one per rank) into a single global PyTree.
+
+     This is the inverse of ``shard_tree``. It iterates over the provided trees,
+     gathering corresponding leaves from each rank.
+
+     * If the spec for a leaf is ``SpecShard(dim)``, the tensors from all ranks are
+       concatenated (or stacked if ``do_stack=True``) along that dimension.
+     * If the spec is ``SpecReplicate``, it assumes the data is replicated
+       and takes the item from the first rank.
+
+     Args:
+         sharded_trees: A sequence (list or tuple) of PyTrees. There must be
+             one tree for each shard rank, and they must all share the same
+             structure as ``sharding_spec``.
+         sharding_spec: A structure matching the input trees containing
+             ``SpecShard`` or ``SpecReplicate`` objects.
+
+     Returns:
+         A single PyTree where distinct shards have been merged into full tensors.
+
+     Raises:
+         ValueError: If ``sharded_trees`` is empty, or if the tree structures do
+             not match the spec.
+     """
+     if not sharded_trees:
+         raise ValueError("sharded_trees sequence cannot be empty")
+
+     flat_spec, spec_struct = pytree.tree_flatten(sharding_spec)
+
+     flat_shards_per_rank = []
+     for i, tree in enumerate(sharded_trees):
+         try:
+             leaves = spec_struct.flatten_up_to(tree)
+         except (ValueError, TypeError) as e:
+             raise ValueError(
+                 f"Structure mismatch at shard {i}: tree does not match sharding spec structure"
+             ) from e
+
+         flat_shards_per_rank.append(leaves)
+
+     grouped_leaves = list(zip(*flat_shards_per_rank, strict=True))
+
+     reconstructed_leaves = [
+         _unshard_leaf_from_group(group, spec)
+         for group, spec in zip(grouped_leaves, flat_spec, strict=True)
+     ]
+
+     return spec_struct.unflatten(reconstructed_leaves)
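
A round-trip sketch tying the sharding helpers together, not part of the package contents; the uneven sizes show that enforce_even_split=False tolerates remainders:

    import torch

    from d9d.core.sharding import shard_spec_on_dim, shard_tree, unshard_tree

    tree = {
        "tokens": torch.arange(10).reshape(10, 1),
        "doc_ids": ["a", "b", "c", "d", "e"],
    }
    spec = shard_spec_on_dim(tree, dim=0)

    # The 10 tensor rows are split 4/3/3 and the 5-element list 2/2/1.
    shards = shard_tree(tree, spec, num_shards=3, enforce_even_split=False)

    restored = unshard_tree(shards, spec)
    assert torch.equal(restored["tokens"], tree["tokens"])
    assert restored["doc_ids"] == tree["doc_ids"]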