d9d 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,70 @@
1
+ import abc
2
+ import dataclasses
3
+
4
+ import torch
5
+
6
+
7
@dataclasses.dataclass(frozen=True)
class StateGroup:
    """
    An atomic dependency unit within the model state transformation graph.

    Each group is a strict contract between source and destination keys:
    consuming every key in ``inputs`` from the source state dictionary yields
    exactly the keys in ``outputs``.

    Attributes:
        inputs: All keys this dependency consumes from the source state dictionary.
        outputs: All keys this transformation produces.
    """

    inputs: frozenset[str]
    outputs: frozenset[str]
22
+
23
+
24
class ModelStateMapper(abc.ABC):
    """
    Abstract base for every model state transformation operation.

    It is the interface between defining a transformation topology and actually
    executing tensor operations, and it enforces a declarative/imperative
    separation of concerns:

    1. Declarative (topology): `state_dependency_groups()` announces *what* the
       mapper intends to do without touching any data, so the system can build
       execution graphs, validate chains, detect collisions, and shard tasks
       *before* allocating memory.
    2. Imperative (execution): `apply()` performs the actual logic (PyTorch
       operations) on model states.
    """

    @abc.abstractmethod
    def state_dependency_groups(self) -> frozenset[StateGroup]:
        """
        Returns the set of independent dependency groups this mapper handles.

        Returns:
            A frozenset of `StateGroup` objects, each describing a disjoint
            operation. A mapper that renames ten independent tensors, for
            example, reports ten distinct groups so they can be sharded or
            processed individually.
        """
        ...

    @abc.abstractmethod
    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Executes the transformation logic on one dictionary of tensors.

        The orchestration system guarantees that `group` contains every key
        listed in the `inputs` of the active `StateGroup`; implementations must
        guarantee the result contains every key listed in that group's
        `outputs`.

        Args:
            group: Source data, keyed exactly by `StateGroup.inputs`.

        Returns:
            Transformed data, keyed exactly by `StateGroup.outputs`.
        """
        ...
@@ -0,0 +1,12 @@
1
+ """
2
+ This package provides utility functions that are used to create simple ModelStateMapper instances from objects
3
+ such as PyTorch modules or other StateMappers
4
+ """
5
+
6
+ from .mapper import identity_mapper_from_mapper_outputs
7
+ from .module import identity_mapper_from_module
8
+
9
+ __all__ = [
10
+ "identity_mapper_from_mapper_outputs",
11
+ "identity_mapper_from_module"
12
+ ]
@@ -0,0 +1,27 @@
1
+ from d9d.model_state.mapper import ModelStateMapper
2
+ from d9d.model_state.mapper.compose import ModelStateMapperParallel
3
+ from d9d.model_state.mapper.leaf import ModelStateMapperIdentity
4
+
5
+
6
def identity_mapper_from_mapper_outputs(mapper: ModelStateMapper) -> ModelStateMapper:
    """
    Creates an identity mapper covering all outputs produced by `mapper`.

    Inspects the `state_dependency_groups()` of the input `mapper` and creates
    one `ModelStateMapperIdentity` for every key listed in the `outputs` set of
    each group.

    Args:
        mapper: The mapper whose output signature is inspected to generate the
            new identity mapper.

    Returns:
        A composite mapper that acts as a pass-through for every key produced
        by the source `mapper`.
    """

    return ModelStateMapperParallel([
        ModelStateMapperIdentity(output_name)
        for state_group in mapper.state_dependency_groups()
        for output_name in state_group.outputs
    ])
@@ -0,0 +1,22 @@
1
+ from torch import nn
2
+
3
+ from d9d.model_state.mapper import ModelStateMapper
4
+ from d9d.model_state.mapper.compose import ModelStateMapperParallel
5
+ from d9d.model_state.mapper.leaf import ModelStateMapperIdentity
6
+
7
+
8
def identity_mapper_from_module(module: nn.Module) -> ModelStateMapper:
    """
    Creates an identity mapper for every entry in a PyTorch module's state dict.

    Useful for defining a "pass-through" pipeline where the source checkpoint
    keys are expected to exactly match the model's current parameter names
    (standard `load_state_dict` behavior).

    Args:
        module: The instantiated PyTorch model to inspect.
    """

    identity_mappers: list[ModelStateMapper] = []
    for key in module.state_dict():
        identity_mappers.append(ModelStateMapperIdentity(key))

    return ModelStateMapperParallel(identity_mappers)
@@ -0,0 +1,17 @@
1
+ """
2
+ Complex state mappers are built using composition. This package provides ModelStateMapper implementations that
3
+ are composed of other mappers.
4
+ """
5
+
6
+
7
+ from .helper import filter_empty_mappers
8
+ from .parallel import ModelStateMapperParallel
9
+ from .sequential import ModelStateMapperSequential
10
+ from .shard import ModelStateMapperShard
11
+
12
+ __all__ = [
13
+ "ModelStateMapperParallel",
14
+ "ModelStateMapperSequential",
15
+ "ModelStateMapperShard",
16
+ "filter_empty_mappers"
17
+ ]
@@ -0,0 +1,22 @@
1
+ from collections.abc import Sequence
2
+
3
+ from d9d.model_state.mapper.abc import ModelStateMapper
4
+
5
+
6
def filter_empty_mappers(mappers: Sequence[ModelStateMapper]) -> list[ModelStateMapper]:
    """
    Drops mappers that have no effect (no inputs and no outputs in any group).

    Args:
        mappers: The list of mappers to filter.

    Returns:
        A new list containing only mappers with at least one non-empty group.
    """
    return [
        mapper
        for mapper in mappers
        if any(group.inputs or group.outputs for group in mapper.state_dependency_groups())
    ]
@@ -0,0 +1,58 @@
1
+ from collections.abc import Sequence
2
+
3
+ import torch
4
+
5
+ from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
6
+ from d9d.model_state.mapper.compose.helper import filter_empty_mappers
7
+
8
+
9
class ModelStateMapperParallel(ModelStateMapper):
    """
    Runs several state mappers independently alongside each other.

    Aggregates multiple mappers into a single logical unit while enforcing
    strict isolation between them: no two mappers may consume the same input
    key (input collision) or produce the same output key (output collision).

    During execution (`apply`), the incoming dictionary is routed to whichever
    sub-mapper declared exactly that set of input keys.
    """

    def __init__(self, mappers: Sequence[ModelStateMapper]):
        active_mappers = filter_empty_mappers(mappers)

        collected_groups: set[StateGroup] = set()
        routing: dict[frozenset[str], ModelStateMapper] = {}

        claimed_inputs: set[str] = set()
        claimed_outputs: set[str] = set()

        for sub_mapper in active_mappers:
            for sub_group in sub_mapper.state_dependency_groups():
                # Reject overlaps with any previously registered group.
                if claimed_inputs & sub_group.inputs:
                    raise ValueError(f"Found a colliding input group: {sub_group.inputs}")
                claimed_inputs |= sub_group.inputs

                if claimed_outputs & sub_group.outputs:
                    raise ValueError(f"Found colliding output keys: {sub_group.outputs}")
                claimed_outputs |= sub_group.outputs

                collected_groups.add(sub_group)
                routing[sub_group.inputs] = sub_mapper

        self._all_groups = frozenset(collected_groups)
        self._inputs_to_mapper = routing

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return self._all_groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Route by the exact key set; partial or mixed groups are rejected.
        target = self._inputs_to_mapper.get(frozenset(group))

        if target is None:
            raise ValueError("Tried to run a parallel mapper with undefined group. Perhaps you sent groups that are "
                             "not isolated?")

        return target.apply(group)
@@ -0,0 +1,131 @@
1
+ from collections.abc import Sequence
2
+ from collections.abc import Set as AbstractSet
3
+
4
+ import torch
5
+
6
+ from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
7
+ from d9d.model_state.mapper.compose.helper import filter_empty_mappers
8
+ from d9d.model_state.mapper.compose.parallel import ModelStateMapperParallel
9
+ from d9d.model_state.mapper.leaf.identity import ModelStateMapperIdentity
10
+
11
+
12
class ModelStateMapperSequential(ModelStateMapper):
    """
    Executes a list of mappers in a specific sequence (pipeline).

    This class manages the data flow from one mapper to the next. It abstracts
    away intermediate states, exposing only the inputs required by the first
    relevant stage and the outputs produced by the final relevant stage.

    Key Features:

    1. **Gap Filling**: Automatically injects `Identity` mappers if a tensor needs
       to pass through a stage without modification to reach a later stage or
       the final output.

    2. **Group Merging**: Computes the net dependency graph. If Stage A requires 'x'
       and produces 'y', and Stage B requires 'y' and produces 'z', the
       Sequential mapper reports a single group `{x} -> {z}`.
    """

    def __init__(self, mappers: list[ModelStateMapper]):
        """
        Args:
            mappers: Pipeline stages in execution order.

        Raises:
            ValueError: If no mapper remains after dropping no-op mappers.
        """
        mappers = filter_empty_mappers(mappers)
        if not mappers:
            raise ValueError("Mappers list cannot be empty.")

        mappers = self._fill_gaps(mappers)

        self._groups = self._compute_pipeline_groups(mappers)
        self._mappers = mappers

    @staticmethod
    def _fill_gaps(mappers: list[ModelStateMapper]) -> list[ModelStateMapper]:
        # Wraps stages in Parallel(stage, Identity...) so every key a stage
        # needs or produces is carried through its neighbours.
        mappers = mappers.copy()

        # propagate inputs from bottom to top: anything a stage requires that
        # the previous stage does not produce must pass through the previous stage
        for stage_i in range(1, len(mappers))[::-1]:
            groups_current = mappers[stage_i].state_dependency_groups()
            groups_prev = mappers[stage_i - 1].state_dependency_groups()
            current_stage_requires = frozenset.union(*(x.inputs for x in groups_current))
            prev_stage_produces = frozenset.union(*(x.outputs for x in groups_prev))

            needs_to_pass_through = current_stage_requires - prev_stage_produces

            mappers[stage_i - 1] = ModelStateMapperParallel(
                [mappers[stage_i - 1]] + [ModelStateMapperIdentity(x) for x in needs_to_pass_through]
            )

        # propagate outputs from top to bottom: anything a stage produces that
        # the next stage does not consume must pass through the next stage
        for stage_i in range(0, len(mappers) - 1):
            groups_current = mappers[stage_i].state_dependency_groups()
            groups_next = mappers[stage_i + 1].state_dependency_groups()
            current_stage_produces = frozenset.union(*(x.outputs for x in groups_current))
            next_stage_requires = frozenset.union(*(x.inputs for x in groups_next))

            needs_to_pass_through = current_stage_produces - next_stage_requires

            mappers[stage_i + 1] = ModelStateMapperParallel(
                [mappers[stage_i + 1]] + [ModelStateMapperIdentity(x) for x in needs_to_pass_through]
            )

        return mappers

    @staticmethod
    def _compute_pipeline_groups(mappers: list[ModelStateMapper]) -> frozenset[StateGroup]:
        # Traces every final output group back to the source inputs it
        # transitively depends on, then merges overlapping dependency sets.
        outputs_depend_on_inputs = {}

        # after gap filling the graph is fully connected, so we can just walk
        # upwards from each group of the last stage
        for last_group_traced in mappers[-1].state_dependency_groups():
            required_inputs = last_group_traced.inputs

            for mapper_i in range(0, len(mappers) - 1)[::-1]:
                next_visit_groups = [x for x in mappers[mapper_i].state_dependency_groups()
                                     if not x.outputs.isdisjoint(required_inputs)]

                required_inputs = frozenset.union(*(x.inputs for x in next_visit_groups))

            outputs_depend_on_inputs[last_group_traced.outputs] = required_inputs

        return ModelStateMapperSequential._merge_groups(list(outputs_depend_on_inputs.items()))

    @staticmethod
    def _merge_groups(groups: Sequence[tuple[AbstractSet[str], AbstractSet[str]]]) -> frozenset[StateGroup]:
        # Coalesces (outputs, inputs) pairs that share inputs or outputs into
        # disjoint StateGroups.
        #
        # BUG FIX: the previous fixed-point loop fed its own (inputs, outputs)
        # accumulator back into a loop that unpacked (outputs, inputs), so
        # whenever a second pass ran (i.e. any merge happened) the resulting
        # StateGroups had inputs and outputs swapped. This implementation keeps
        # one canonical orientation throughout and merges transitively
        # overlapping groups in a single pass.
        merged: list[tuple[set[str], set[str]]] = []

        for output_names, input_names in groups:
            new_inputs = set(input_names)
            new_outputs = set(output_names)

            # Pull every overlapping existing group into the new one. The kept
            # groups remain pairwise disjoint, so one pass over `merged` is
            # enough to coalesce transitively.
            kept: list[tuple[set[str], set[str]]] = []
            for existing_inputs, existing_outputs in merged:
                if existing_inputs & new_inputs or existing_outputs & new_outputs:
                    new_inputs |= existing_inputs
                    new_outputs |= existing_outputs
                else:
                    kept.append((existing_inputs, existing_outputs))

            kept.append((new_inputs, new_outputs))
            merged = kept

        return frozenset(
            StateGroup(inputs=frozenset(inputs), outputs=frozenset(outputs))
            for inputs, outputs in merged
        )

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return self._groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        current_state = group
        next_state = {}
        for mapper in self._mappers:
            for deps in mapper.state_dependency_groups():
                # Skip groups whose inputs are absent — under sharded execution
                # a stage may only receive a subset of its declared groups.
                if not deps.inputs <= current_state.keys():
                    continue

                next_state.update(mapper.apply({k: v for k, v in current_state.items() if k in deps.inputs}))

            current_state = next_state
            next_state = {}

        return current_state
@@ -0,0 +1,36 @@
1
+ import torch
2
+
3
+ from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
4
+
5
+
6
class ModelStateMapperShard(ModelStateMapper):
    """
    Wraps another state mapper and restricts its execution to a specific
    subset (shard) of dependency groups.

    Primarily used to parallelize model loading across multiple processes or
    nodes: assigning a distinct `current_shard` index to each process splits
    the total set of tensors required by `sub_mapper` evenly, so no process
    has to load the entire checkpoint.
    """

    def __init__(self, sub_mapper: ModelStateMapper, total_shards: int, current_shard: int):
        self._groups = self._shard_groups(
            sub_mapper.state_dependency_groups(),
            n_shards=total_shards, shard=current_shard
        )
        self._sub_mapper = sub_mapper
        self._total_shards = total_shards
        self._current_shard = current_shard

    @staticmethod
    def _shard_groups(groups: frozenset[StateGroup], n_shards: int, shard: int) -> frozenset[StateGroup]:
        # Sort by input names so every process derives the same deterministic
        # order, then pick this shard's slice round-robin.
        ordered = sorted(groups, key=lambda g: sorted(g.inputs))
        return frozenset(ordered[shard::n_shards])

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        return self._groups

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Delegate unchanged; sharding only restricts which groups are declared.
        return self._sub_mapper.apply(group)
@@ -0,0 +1,18 @@
1
+ """
2
+ This package provides leaf mapper implementations.
3
+ """
4
+
5
+ from .dtensor import ModelStateMapperDistribute, ModelStateMapperGatherFullTensor
6
+ from .identity import ModelStateMapperIdentity
7
+ from .rename import ModelStateMapperRename
8
+ from .select_child import ModelStateMapperSelectChildModules
9
+ from .stack import ModelStateMapperStackTensors
10
+
11
+ __all__ = [
12
+ "ModelStateMapperDistribute",
13
+ "ModelStateMapperGatherFullTensor",
14
+ "ModelStateMapperIdentity",
15
+ "ModelStateMapperRename",
16
+ "ModelStateMapperSelectChildModules",
17
+ "ModelStateMapperStackTensors",
18
+ ]
@@ -0,0 +1,56 @@
1
+ from collections.abc import Sequence
2
+
3
+ import torch
4
+ from torch._C._distributed import Placement
5
+ from torch.distributed import DeviceMesh
6
+ from torch.distributed.tensor import DTensor, distribute_tensor
7
+
8
+ from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
9
+
10
+
11
class ModelStateMapperDistribute(ModelStateMapper):
    """
    Converts a single local Tensor into a DTensor with the specified
    `device_mesh` and `placements`.
    """

    def __init__(self, name: str, device_mesh: DeviceMesh | None, placements: Sequence[Placement] | None):
        self._name = name
        self._device_mesh = device_mesh
        self._placements = placements

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        key = frozenset([self._name])
        return frozenset([StateGroup(inputs=key, outputs=key)])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        distributed = distribute_tensor(
            group[self._name],
            device_mesh=self._device_mesh,
            placements=self._placements,
            src_data_rank=None  # do not communicate here
        )
        return {self._name: distributed}
35
+
36
+
37
class ModelStateMapperGatherFullTensor(ModelStateMapper):
    """
    Gathers a single DTensor object into a full Tensor object.
    """

    def __init__(self, name: str):
        self._name = name

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        key = frozenset([self._name])
        return frozenset([StateGroup(inputs=key, outputs=key)])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        tensor = group[self._name]

        # Only DTensor exposes full_tensor(); anything else is a usage error.
        if not isinstance(tensor, DTensor):
            raise ValueError("Cannot gather anything but DTensor")

        return {self._name: tensor.full_tensor()}
@@ -0,0 +1,23 @@
1
+ import torch
2
+
3
+ from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
4
+
5
+
6
class ModelStateMapperIdentity(ModelStateMapper):
    """
    Passes a single state tensor through unchanged.
    """

    def __init__(self, name: str):
        self._name = name

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        key = frozenset([self._name])
        return frozenset([StateGroup(inputs=key, outputs=key)])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Identity: the input dictionary already matches the declared outputs.
        return group
@@ -0,0 +1,26 @@
1
+ import torch
2
+
3
+ from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
4
+
5
+
6
class ModelStateMapperRename(ModelStateMapper):
    """
    Renames a single state tensor from `name_from` to `name_to`.
    """

    def __init__(self, name_from: str, name_to: str):
        self._name_from = name_from
        self._name_to = name_to

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        rename_group = StateGroup(
            inputs=frozenset([self._name_from]),
            outputs=frozenset([self._name_to])
        )
        return frozenset([rename_group])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Re-key the single tensor; the value itself is untouched.
        return {self._name_to: group[self._name_from]}
@@ -0,0 +1,37 @@
1
+ import torch
2
+
3
+ from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
4
+
5
+
6
class ModelStateMapperSelectChildModules(ModelStateMapper):
    """
    Selects a set of keys belonging to a specific parent module (prefix) and
    renames them by removing that prefix.

    This is effectively a batch rename operation that "hoists" parameters
    from a submodule scope to the current scope.
    """

    def __init__(self, base_names: list[str], parent_name: str):
        """
        Args:
            base_names: Child key names without the parent prefix.
            parent_name: Name of the parent module whose prefix is stripped.
        """
        self._base_names = base_names
        self._parent_prefix = f"{parent_name}."

    def state_dependency_groups(self) -> frozenset[StateGroup]:
        # One single-key group per child so each rename can be sharded independently.
        return frozenset([
            StateGroup(
                inputs=frozenset([self._parent_prefix + name]),
                outputs=frozenset([name])
            )
            for name in self._base_names
        ])

    def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Generalized over the original implementation, which consumed only the
        # first item of `group`: handle every key (the declared groups are
        # single-key, so behavior there is unchanged) and skip keys outside the
        # parent scope, matching the original empty-dict fallback.
        return {
            name.removeprefix(self._parent_prefix): tensor
            for name, tensor in group.items()
            if name.startswith(self._parent_prefix)
        }
@@ -0,0 +1,29 @@
1
+ import torch
2
+
3
+ from d9d.model_state.mapper.abc import ModelStateMapper, StateGroup
4
+
5
+
6
+ class ModelStateMapperStackTensors(ModelStateMapper):
7
+ """
8
+ Stacks multiple input tensors with names `source_names` into a single output tensor with name `target_name`
9
+ producing new `stack_dim` dimension.
10
+ """
11
+
12
+ def __init__(self, source_names: list[str], target_name: str, stack_dim: int):
13
+ self._source_names = source_names
14
+ self._target_name = target_name
15
+ self._stack_dim = stack_dim
16
+
17
+ def state_dependency_groups(self) -> frozenset[StateGroup]:
18
+ return frozenset([
19
+ StateGroup(
20
+ inputs=frozenset(self._source_names),
21
+ outputs=frozenset([self._target_name])
22
+ )
23
+ ])
24
+
25
+ def apply(self, group: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
26
+ source_tensors = [group[name] for name in self._source_names]
27
+ return {
28
+ self._target_name: torch.stack(source_tensors, dim=self._stack_dim)
29
+ }
d9d/module/__init__.py ADDED
File without changes
@@ -0,0 +1,7 @@
1
+ """Defines structural protocols and base classes for PyTorch modules used within the d9d framework."""
2
+
3
+ from .late_init import ModuleLateInit
4
+
5
+ __all__ = [
6
+ "ModuleLateInit"
7
+ ]
@@ -0,0 +1,10 @@
1
+ import typing
2
+ from typing import Protocol
3
+
4
+
5
@typing.runtime_checkable
class ModuleLateInit(Protocol):
    """Structural protocol for modules that support late parameter initialization."""

    def reset_parameters(self):
        """Re-initializes the module's parameters (i.e. performs random initialization)."""
File without changes
@@ -0,0 +1,7 @@
1
+ """Provides attention layer implementations."""
2
+
3
+ from .grouped_query import GroupedQueryAttention
4
+
5
+ __all__ = [
6
+ "GroupedQueryAttention"
7
+ ]