d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/core/types/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """
+ Common type definitions used throughout the framework.
+ """
+ from .data import CollateFn
+ from .pytree import PyTree, ScalarTree, TensorTree
+
+ __all__ = [
+     "CollateFn",
+     "PyTree",
+     "ScalarTree",
+     "TensorTree"
+ ]
d9d/core/types/data.py ADDED
@@ -0,0 +1,14 @@
+ from collections.abc import Callable, Sequence
+ from typing import TypeAlias, TypeVar
+
+ from .pytree import PyTree
+
+ TDataTree = TypeVar("TDataTree", bound=PyTree)
+
+ CollateFn: TypeAlias = Callable[[Sequence[TDataTree]], TDataTree]
+ """
+ Type alias for a function that collates a sequence of samples into a batch.
+
+ The function receives a sequence of individual data point structures (PyTrees)
+ and is responsible for stacking or merging them into a single batched structure.
+ """
d9d/core/types/pytree.py ADDED
@@ -0,0 +1,26 @@
+ from typing import TypeAlias, TypeVar
+
+ import torch
+
+ TLeaf = TypeVar("TLeaf")
+
+ PyTree: TypeAlias = TLeaf | list["PyTree[TLeaf]"] | dict[str, "PyTree[TLeaf]"] | tuple["PyTree[TLeaf]", ...]
+ """
+ A recursive type definition representing a tree of data.
+
+ This type alias covers standard Python containers (dictionaries, lists, tuples)
+ nested arbitrarily deep, terminating in a leaf node of type `TLeaf`.
+
+ This is commonly used for handling nested state dictionaries or arguments
+ passed to functions that support recursive traversal (similar to `torch.utils._pytree`).
+ """
+
+ TensorTree: TypeAlias = PyTree[torch.Tensor]
+ """
+ A recursive tree structure where the leaf nodes are PyTorch Tensors.
+ """
+
+ ScalarTree: TypeAlias = PyTree[str | float | int | bool]
+ """
+ A recursive tree structure where the leaf nodes are Python scalars (str, float, int, bool).
+ """
d9d/dataset/__init__.py ADDED
@@ -0,0 +1,17 @@
+ """
+ This package provides utilities and torch.utils.data.Dataset implementations.
+ """
+
+ from .buffer_sorted import BufferSortedDataset, DatasetImplementingSortKeyProtocol
+ from .padding import PaddingSide1D, pad_stack_1d
+ from .sharded import ShardedDataset, ShardIndexingMode, shard_dataset_data_parallel
+
+ __all__ = [
+     "BufferSortedDataset",
+     "DatasetImplementingSortKeyProtocol",
+     "PaddingSide1D",
+     "ShardIndexingMode",
+     "ShardedDataset",
+     "pad_stack_1d",
+     "shard_dataset_data_parallel"
+ ]
d9d/dataset/buffer_sorted.py ADDED
@@ -0,0 +1,143 @@
+ import pickle  # noqa: S403
+ import random
+ from typing import Any, Protocol, TypeVar
+
+ from torch.distributed.checkpoint.stateful import Stateful
+ from torch.utils.data import Dataset
+
+ _T_co = TypeVar("_T_co", covariant=True)
+
+
+ class DatasetImplementingSortKeyProtocol(Protocol[_T_co]):
+     """
+     Protocol for datasets that support retrieval of a specific key for sorting purposes.
+
+     This is typically used for length-based bucketing/sorting where the dataset
+     needs to expose the length of an item without loading the full item.
+     """
+
+     def __len__(self) -> int:
+         """Returns the total number of items in the dataset."""
+         ...
+
+     def sort_key(self, index: int) -> Any:
+         """
+         Returns a value used for sorting the dataset at the given index.
+
+         Args:
+             index: The index of the item.
+
+         Returns:
+             A comparable value (e.g., int length) used for sorting.
+         """
+         ...
+
+     def __getitem__(self, item: int) -> _T_co:
+         """Retrieves the item at the specific index."""
+         ...
+
+
+ class BufferSortedDataset(Dataset[_T_co], Stateful):
+     """
+     A dataset wrapper that groups items into buffers, sorts them, and yields them with local shuffling.
+
+     This prevents extreme padding in variable-length training (by grouping similar lengths)
+     while maintaining enough randomness to ensure statistical variance in updates.
+
+     Algorithm:
+
+     1. Select a range of indices (size `buffer_size`).
+     2. Sort these indices based on `base_dataset.sort_key()`.
+     3. Break the sorted list into packs of size `pack_size`.
+     4. Shuffle the order of these packs.
+     5. Flatten the list and serve items.
+     """
+
+     def __init__(
+         self,
+         base_dataset: DatasetImplementingSortKeyProtocol[_T_co],
+         buffer_size: int,
+         pack_size: int,
+         init_seed: int | None = None
+     ):
+         """
+         Constructs a BufferSortedDataset object.
+
+         Args:
+             base_dataset: The underlying dataset implementing the `DatasetImplementingSortKeyProtocol` protocol.
+             buffer_size: The number of items to load into the buffer for sorting.
+             pack_size: The size of local groups (batches/micro-batches) that remain
+                 contiguous after sorting, but are shuffled relative to other packs.
+             init_seed: Seed for the random number generator used for shuffling packs.
+         """
+
+         self._base_dataset = base_dataset
+         self._buffer_size = buffer_size
+         self._pack_size = pack_size
+
+         self._rng = random.Random((init_seed ^ 0x105E7) if init_seed is not None else None)
+         self._buffer_indices: list[int] = []
+         self._buffer_idx: int = -1
+
+     def _update_buffer_idx(self, buffer_idx: int):
+         select_start = buffer_idx * self._buffer_size
+         select_end = (buffer_idx + 1) * self._buffer_size
+         select_end = min(select_end, len(self._base_dataset))
+
+         base_idx = list(range(select_start, select_end))
+         base_sort_keys = [self._base_dataset.sort_key(idx) for idx in range(select_start, select_end)]
+
+         local_idx = list(range(len(base_idx)))
+         local_idx = sorted(local_idx, key=lambda local_id: base_sort_keys[local_id])
+
+         local_idx_batch = [
+             local_idx[i: i + self._pack_size]
+             for i in range(0, len(local_idx), self._pack_size)
+         ]
+         self._rng.shuffle(local_idx_batch)
+         local_idx = [y for x in local_idx_batch for y in x]
+
+         self._buffer_indices = [base_idx[local_id] for local_id in local_idx]
+
+         self._buffer_idx = buffer_idx
+
+     def __getitem__(self, index: int) -> _T_co:
+         """
+         Retrieves an item from the locally sorted/shuffled buffer.
+
+         Args:
+             index: The global index.
+
+         Returns:
+             The dataset item.
+         """
+
+         needs_buffer_idx = index // self._buffer_size
+         if self._buffer_idx != needs_buffer_idx:
+             self._update_buffer_idx(needs_buffer_idx)
+
+         take_id = self._buffer_indices[index % self._buffer_size]
+
+         return self._base_dataset[take_id]
+
+     def __len__(self) -> int:
+         """Returns the length of the base dataset."""
+
+         return len(self._base_dataset)
+
+     def state_dict(self) -> dict[str, Any]:
+         ret = {
+             "seed": pickle.dumps(self._rng.getstate()),
+             "buffer_idx": self._buffer_idx,
+             "buffer_indices": self._buffer_indices,
+         }
+         if isinstance(self._base_dataset, Stateful):
+             ret["base_dataset"] = self._base_dataset.state_dict()
+         return ret
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         self._rng.setstate(pickle.loads(state_dict["seed"]))  # noqa: S301
+         self._buffer_idx = state_dict["buffer_idx"]
+         self._buffer_indices = state_dict["buffer_indices"]
+         if isinstance(self._base_dataset, Stateful):
+             self._base_dataset.load_state_dict(state_dict["base_dataset"])
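A rough usage sketch of `BufferSortedDataset` with a toy base dataset; the `ToyTokenDataset` class below is hypothetical and exists only to satisfy the sort-key protocol:

import torch

from d9d.dataset import BufferSortedDataset


class ToyTokenDataset:
    """Hypothetical dataset exposing sequence lengths via sort_key()."""

    def __init__(self, lengths: list[int]):
        self._lengths = lengths

    def __len__(self) -> int:
        return len(self._lengths)

    def sort_key(self, index: int) -> int:
        # The length is known without materialising the item.
        return self._lengths[index]

    def __getitem__(self, item: int) -> torch.Tensor:
        return torch.zeros(self._lengths[item], dtype=torch.long)


dataset = BufferSortedDataset(
    base_dataset=ToyTokenDataset(lengths=[5, 30, 7, 28, 6, 29, 8, 27]),
    buffer_size=4,  # sort four items at a time
    pack_size=2,    # keep pairs of similar length adjacent
    init_seed=0,
)
# Within each buffer of four, similarly sized sequences end up next to each other.
lengths_served = [dataset[i].shape[0] for i in range(len(dataset))]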
d9d/dataset/padding.py ADDED
@@ -0,0 +1,79 @@
+ from collections.abc import Sequence
+ from enum import StrEnum
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class PaddingSide1D(StrEnum):
+     """
+     Enum specifying the side for padding 1D sequences.
+
+     Attributes:
+         left: Pad on the left side.
+         right: Pad on the right side.
+     """
+
+     left = "left"
+     right = "right"
+
+
+ def _padding_side_1d_to_config(side: PaddingSide1D, difference: int) -> tuple[int, ...]:
+     match side:
+         case PaddingSide1D.left:
+             return difference, 0
+         case PaddingSide1D.right:
+             return 0, difference
+         case _:
+             raise ValueError("Unknown padding side")
+
+
+ def pad_stack_1d(
+     items: Sequence[torch.Tensor],
+     pad_value: int,
+     padding_side: PaddingSide1D = PaddingSide1D.right,
+     pad_to_multiple_of: int | None = None
+ ) -> torch.Tensor:
+     """
+     Stacks 1D tensors into a batch, applying padding.
+
+     Calculates the maximum length among the input tensors (optionally aligning to a multiple),
+     pads elements to match this length on the specified side, and stacks them.
+
+     Args:
+         items: A sequence of 1D tensors to be stacked.
+         pad_value: The value used for padding.
+         padding_side: The side on which to apply padding (left or right).
+         pad_to_multiple_of: Optional integer. If provided, ensures the target length
+             is a multiple of this value.
+
+     Returns:
+         A single stacked tensor of shape (batch, max_length).
+
+     Raises:
+         ValueError: If no items are provided or if `pad_to_multiple_of` is <= 0.
+     """
+
+     if not items:
+         raise ValueError("Cannot stack 0 items")
+     if pad_to_multiple_of is not None and pad_to_multiple_of <= 0:
+         raise ValueError("pad_to_multiple_of should be > 0")
+
+     max_len = max(x.shape[0] for x in items)
+
+     if pad_to_multiple_of is not None and (remainder := max_len % pad_to_multiple_of) != 0:
+         max_len = max_len + (pad_to_multiple_of - remainder)
+
+     padded_items = []
+
+     for x in items:
+         difference = max_len - x.shape[0]
+
+         if difference == 0:
+             padded_items.append(x)
+         else:
+             padded_items.append(
+                 F.pad(x, _padding_side_1d_to_config(padding_side, difference), value=pad_value)
+             )
+
+     return torch.stack(padded_items, dim=0)
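A short usage sketch of `pad_stack_1d` (illustrative only):

import torch

from d9d.dataset import PaddingSide1D, pad_stack_1d

batch = pad_stack_1d(
    items=[torch.tensor([1, 2, 3]), torch.tensor([4, 5])],
    pad_value=0,
    padding_side=PaddingSide1D.right,
    pad_to_multiple_of=4,
)
# The longest item has length 3, rounded up to the next multiple of 4:
# batch == tensor([[1, 2, 3, 0],
#                  [4, 5, 0, 0]])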
d9d/dataset/sharded.py ADDED
@@ -0,0 +1,195 @@
+ import math
+ from collections.abc import Sized
+ from enum import StrEnum
+ from typing import Any, TypeVar
+
+ from torch.distributed.checkpoint.stateful import Stateful
+ from torch.utils.data import Dataset
+
+ from d9d.core.dist_context import BATCH_DOMAIN, DistributedContext
+
+
+ class ShardIndexingMode(StrEnum):
+     """
+     Defines how the dataset is split across shards.
+
+     Modes:
+         sequential: Round-robin distribution.
+
+             shard0: 0, 4, 8, 12
+             shard1: 1, 5, 9, 13
+             shard2: 2, 6, 10
+             shard3: 3, 7, 11
+
+         chunked: Contiguous blocks.
+
+             shard0: 0, 1, 2, 3
+             shard1: 4, 5, 6, 7
+             shard2: 8, 9, 10, 11
+             shard3: 12, 13
+     """
+
+     sequential = "sequential"
+     chunked = "chunked"
+
+
+ _T_co = TypeVar("_T_co", covariant=True)
+
+
+ class ShardedDataset(Dataset[_T_co], Stateful):
+     """
+     A dataset wrapper that acts as a view onto a specific shard of the underlying dataset.
+
+     This is useful for Data Parallel training where each process should only see
+     a subset of the data. It supports different indexing modes and optional padding
+     to ensure all shards have equal length (preventing hangs in distributed collectives).
+     """
+
+     def __init__(
+         self,
+         dataset: Dataset[_T_co],
+         total_shards: int,
+         current_shard: int,
+         indexing_mode: ShardIndexingMode,
+         pad_to_equal_size_across_shards: bool
+     ):
+         """
+         Constructs a ShardedDataset object.
+
+         Args:
+             dataset: The underlying dataset to shard.
+             total_shards: The total number of shards (e.g., number of DP ranks).
+             current_shard: The index of the current shard (e.g., current DP rank).
+             indexing_mode: How indices are assigned to shards (sequential/round-robin or chunked).
+             pad_to_equal_size_across_shards: If True, the length of the dataset will be padded
+                 so that all shards report the same length. The last element of the underlying dataset is repeated.
+         """
+
+         if not isinstance(dataset, Sized):
+             raise ValueError("Dataset should implement __len__ method")
+
+         self._dataset = dataset
+
+         self._total_shards = total_shards
+         self._current_shard = current_shard
+
+         self._indexing_mode = indexing_mode
+         self._pad_to_equal_size_across_shards = pad_to_equal_size_across_shards
+
+     def _compute_real_index_sequential(self, index: int) -> int:
+         return index * self._total_shards + self._current_shard
+
+     def _get_base_index_unsafe(self, index: int) -> int:
+         """
+         Calculates the underlying dataset index for a given shard index,
+         without boundary checking.
+         """
+
+         match self._indexing_mode:
+             case ShardIndexingMode.sequential:
+                 base_index = index * self._total_shards + self._current_shard
+
+                 return base_index
+             case ShardIndexingMode.chunked:
+                 ceil_len = math.ceil(len(self._dataset) / self._total_shards)
+                 shard_start_offset = ceil_len * self._current_shard
+
+                 return shard_start_offset + index
+             case _:
+                 raise ValueError(f"Unknown shard indexing mode: {self._indexing_mode}")
+
+     def __getitem__(self, index: int) -> _T_co:
+         """
+         Retrieves an item from the underlying dataset, mapping the logical shard index to a physical index.
+
+         If padding is enabled and the index exceeds the valid data for this shard,
+         the last item in the dataset is returned.
+
+         Args:
+             index: The index relative to this shard.
+
+         Returns:
+             The data item.
+         """
+
+         base_index = self._get_base_index_unsafe(index)
+         if base_index >= len(self._dataset):
+             base_index = len(self._dataset) - 1
+         return self._dataset[base_index]
+
+     def __len__(self) -> int:
+         """
+         Returns the number of items in this specific shard.
+
+         If `pad_to_equal_size_across_shards` is True, this returns the ceiling
+         length (max length across all shards).
+         """
+
+         ceil_len = math.ceil(len(self._dataset) / self._total_shards)
+
+         if self._pad_to_equal_size_across_shards:
+             return ceil_len
+
+         shards_remainder = len(self._dataset) % self._total_shards
+         match self._indexing_mode:
+             case ShardIndexingMode.sequential:
+                 shards_full = len(self._dataset) // self._total_shards
+                 return shards_full + 1 if self._current_shard < shards_remainder else shards_full
+             case ShardIndexingMode.chunked:
+                 is_shard_last = self._current_shard == self._total_shards - 1
+                 if not is_shard_last or shards_remainder == 0:
+                     return ceil_len
+                 else:
+                     return ceil_len - (self._total_shards - shards_remainder)
+
+     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+         if isinstance(self._dataset, Stateful):
+             self._dataset.load_state_dict(state_dict["dataset"])
+
+         # check whether the sharding environment has changed
+         if state_dict["total_shards"] != self._total_shards:
+             raise ValueError("Shard count mismatch")
+         self._total_shards = state_dict["total_shards"]
+
+         self._current_shard = state_dict["current_shard"]
+
+     def state_dict(self) -> dict[str, Any]:
+         dct: dict[str, Any] = {
+             "total_shards": self._total_shards,
+             "current_shard": self._current_shard
+         }
+         if isinstance(self._dataset, Stateful):
+             dct["dataset"] = self._dataset.state_dict()
+         return dct
+
+
+ def shard_dataset_data_parallel(
+     dataset: Dataset[_T_co],
+     dist_context: DistributedContext,
+     indexing_mode: ShardIndexingMode = ShardIndexingMode.sequential,
+     pad_to_equal_size_across_shards: bool = True
+ ) -> Dataset[_T_co]:
+     """
+     Wraps a dataset into a ShardedDataset based on the Data Parallel dimension of the distributed context.
+
+     This is a helper function to automatically determine the correct rank and world size
+     from the 'dp' (Data Parallel) mesh dimension within the batch domain DeviceMesh.
+
+     Args:
+         dataset: The source dataset to shard.
+         dist_context: The distributed context.
+         indexing_mode: The strategy for splitting data indices (sequential/round-robin or chunked).
+         pad_to_equal_size_across_shards: If True, ensures all shards have the same length by padding.
+
+     Returns:
+         A dataset instance representing the local shard.
+     """
+
+     dp_mesh = dist_context.mesh_for(BATCH_DOMAIN)["dp"]
+     return ShardedDataset(
+         dataset=dataset,
+         total_shards=dp_mesh.size(),
+         current_shard=dp_mesh.get_local_rank(),
+         indexing_mode=indexing_mode,
+         pad_to_equal_size_across_shards=pad_to_equal_size_across_shards
+     )
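An illustrative sketch of sharding a ten-element dataset across four shards without a distributed context; the `RangeDataset` helper below is hypothetical:

from torch.utils.data import Dataset

from d9d.dataset import ShardedDataset, ShardIndexingMode


class RangeDataset(Dataset[int]):
    def __len__(self) -> int:
        return 10

    def __getitem__(self, index: int) -> int:
        return index


shard0 = ShardedDataset(
    dataset=RangeDataset(),
    total_shards=4,
    current_shard=0,
    indexing_mode=ShardIndexingMode.sequential,
    pad_to_equal_size_across_shards=True,
)
# Round-robin indexing: shard 0 sees items 0, 4 and 8; len(shard0) == 3.
items = [shard0[i] for i in range(len(shard0))]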
d9d/internals/__init__.py ADDED (empty file, no content to show)
d9d/internals/determinism/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ This package provides utilities for making your distributed setup deterministic.
+ """
+
+
+ from .seed import set_seeds
+
+ __all__ = [
+     "set_seeds"
+ ]
d9d/internals/determinism/seed.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ import random
+ from typing import cast
+
+ import torch
+ import torch.distributed.tensor
+
+ from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
+
+
+ def set_seeds(
+     dist_context: DistributedContext,
+     seed: int,
+     distinct_seed_mesh_dim: str = "pp",
+ ) -> None:
+     """
+     Sets random seeds for Python, NumPy, and PyTorch.
+
+     This function sets seeds deterministically based on the provided base seed and the
+     process's rank within a specific mesh dimension.
+
+     The seed is shifted by the rank in the `distinct_seed_mesh_dim` (e.g., Pipeline Parallel rank).
+     This ensures that processes in different pipeline stages operate with different random states,
+     while processes that should share randomness (like Expert Parallel peers) can be synchronized.
+
+     Args:
+         dist_context: The distributed context.
+         seed: The base random seed.
+         distinct_seed_mesh_dim: The name of the mesh dimension along which seeds should
+             be distinct (e.g., 'pp' for pipeline parallelism). Ranks along other dimensions
+             will share the seed.
+     """
+
+     # Mutate seed based on PP rank if distributed
+     if dist_context.mesh_params.is_distributed:
+         distinct_mesh = dist_context.mesh_for(REGULAR_DOMAIN)[distinct_seed_mesh_dim]
+         seed = (seed + distinct_mesh.get_local_rank()) % 2**64
+
+     dist_context.logger.info(f"Set seed {seed}")
+
+     torch.manual_seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
+     random.seed(seed)
+
+     try:
+         import numpy as np  # noqa: PLC0415
+         np.random.seed(seed)
+     except ImportError:
+         pass
+
+     # Set DTensor seeding if distributed
+     if dist_context.mesh_params.is_distributed:
+         mesh_regular = dist_context.mesh_for(REGULAR_DOMAIN)
+         duplicate_seed_mesh_dim = tuple(
+             name
+             for name
+             in cast(list[str], mesh_regular.mesh_dim_names)
+             if name != distinct_seed_mesh_dim
+         )
+         duplicate_seed_mesh = mesh_regular[duplicate_seed_mesh_dim] if len(duplicate_seed_mesh_dim) != 0 else None
+
+         if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None:
+             torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh)  # noqa: SLF001
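The rank-based shift can be illustrated in isolation (conceptual sketch, not package code):

base_seed = 1234
for pp_rank in range(4):
    stage_seed = (base_seed + pp_rank) % 2**64
    # Each pipeline-parallel rank derives a distinct seed; ranks that only differ
    # along other mesh dimensions (e.g. data-parallel peers) reuse the same value.
    print(pp_rank, stage_seed)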
d9d/internals/grad_norm/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .group import ParametersForNorm, group_parameters_for_norm
+ from .norm import clip_grad_norm_distributed_
+
+ __all__ = [
+     "ParametersForNorm",
+     "clip_grad_norm_distributed_",
+     "group_parameters_for_norm"
+ ]
d9d/internals/grad_norm/group.py ADDED
@@ -0,0 +1,87 @@
+ import dataclasses
+ from collections import defaultdict
+ from collections.abc import Iterable
+ from typing import Any
+
+ import torch
+ from torch import nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.tensor import DTensor, Shard
+
+
+ @dataclasses.dataclass(kw_only=True, frozen=True)
+ class GradNormGroup:
+     """
+     Defines a group of parameters that share the same distributed properties.
+
+     This grouping is used to batch gradient norm reductions efficiently. Parameters
+     sharing the same device mesh shards can be reduced in a single communication collective.
+
+     Attributes:
+         shard_meshes: A tuple of device meshes where the parameters are sharded, or None if replicated/local.
+         device: The device where parameters reside.
+         grad_dtype: The data type of the gradients.
+     """
+
+     shard_meshes: tuple[DeviceMesh, ...] | None
+     device: torch.device
+     grad_dtype: torch.dtype | None
+
+
+ ParametersForNorm = dict[GradNormGroup, list[nn.Parameter]]
+
+
+ def _extract_shard_meshes(param: nn.Parameter) -> tuple[DeviceMesh, ...] | None:
+     data = param.data
+
+     if not isinstance(data, DTensor):
+         return None
+
+     mesh = data.device_mesh
+     mesh_dim_names = mesh.mesh_dim_names
+     if mesh_dim_names is None:
+         raise ValueError("Only named meshes are supported.")
+
+     shard_placement_dim_names: list[str] = []
+
+     for dim_i, placement in enumerate(data.placements):
+         if isinstance(placement, Shard):
+             shard_placement_dim_names.append(mesh_dim_names[dim_i])
+
+     if len(shard_placement_dim_names) == 0:
+         return None
+
+     return tuple(mesh[name] for name in shard_placement_dim_names)
+
+
+ def _group_sort_key(item: tuple[GradNormGroup, list[nn.Parameter]]) -> Any:
+     # Put groups WITH shard_meshes first so they are processed earlier and we benefit from communication/computation overlap.
+     return item[0].shard_meshes is None
+
+
+ def group_parameters_for_norm(parameters: Iterable[nn.Parameter]) -> ParametersForNorm:
+     """
+     Groups parameters based on their distributed tensor characteristics.
+
+     Groups parameters by their sharding meshes, device, and gradient data type.
+
+     Args:
+         parameters: The iterable of parameters to group.
+
+     Returns:
+         A dictionary mapping synchronization groups to lists of parameters.
+     """
+
+     grouped_params: ParametersForNorm = defaultdict(list)
+     for param in parameters:
+         if not param.requires_grad:
+             continue
+
+         group = GradNormGroup(
+             shard_meshes=_extract_shard_meshes(param),
+             grad_dtype=param.grad_dtype,
+             device=param.device
+         )
+         grouped_params[group].append(param)
+     # dicts preserve insertion order, so sorting the items and rebuilding the dict fixes the processing order
+     return dict(sorted(grouped_params.items(), key=_group_sort_key))
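The grouping idea can be illustrated with plain local tensors (a simplified sketch keyed only on device and dtype, not the package's `GradNormGroup`):

from collections import defaultdict

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2).to(torch.float64))

groups: dict[tuple[torch.device, torch.dtype], list[nn.Parameter]] = defaultdict(list)
for param in model.parameters():
    if param.requires_grad:
        groups[(param.device, param.dtype)].append(param)

# Parameters sharing the same key can have their norms reduced in one collective.
for (device, dtype), params in groups.items():
    print(device, dtype, len(params))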