d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/model_state/io/dto.py
@@ -0,0 +1,30 @@
+ from pydantic import BaseModel
+
+
+ class ModelStateIndexMeta(BaseModel):
+     """
+     Metadata for the model state index.
+
+     Attributes:
+         total_size: Total size of the model parameters in bytes.
+     """
+
+     total_size: int
+
+
+ class ModelStateIndex(BaseModel):
+     """
+     Represents the content of the `model.safetensors.index.json` file.
+
+     This index maps every weight name to the specific .safetensors file containing it.
+
+     Attributes:
+         metadata: Global metadata about the checkpoint.
+         weight_map: Mapping from parameter name to filename.
+     """
+
+     metadata: ModelStateIndexMeta
+     weight_map: dict[str, str]
+
+
+ MODEL_STATE_INDEX_FILE_NAME = "model.safetensors.index.json"
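The `ModelStateIndex` schema above mirrors the usual `model.safetensors.index.json` layout. As a quick illustration (not part of the package), here is a minimal sketch of building and round-tripping an index; the weight names and shard file names are invented, and only the pydantic v2 methods already used elsewhere in this diff are assumed:

```python
# Illustrative only: weight names and shard file names below are made up.
from d9d.model_state.io.dto import ModelStateIndex, ModelStateIndexMeta

index = ModelStateIndex(
    metadata=ModelStateIndexMeta(total_size=2 * 1024**3),  # total parameter bytes
    weight_map={
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.mlp.gate.weight": "model-00002-of-00002.safetensors",
    },
)

# Round-trip through JSON using the same pydantic v2 calls the reader/writer use.
as_json = index.model_dump_json(indent=4)
assert ModelStateIndex.model_validate_json(as_json) == index
```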
d9d/model_state/io/module_reader.py
@@ -0,0 +1,75 @@
+ from pathlib import Path
+
+ import torch
+ from torch import nn
+ from torch.distributed.tensor import DTensor
+
+ from d9d.model_state.mapper import ModelStateMapper
+ from d9d.model_state.mapper.compose import (
+     ModelStateMapperParallel,
+     ModelStateMapperSequential,
+ )
+ from d9d.model_state.mapper.leaf import (
+     ModelStateMapperDistribute,
+     ModelStateMapperIdentity,
+ )
+
+ from .reader import read_model_state
+
+
+ def _build_injection_mapper(name: str, state: torch.Tensor) -> ModelStateMapper:
+     if isinstance(state, DTensor):
+         return ModelStateMapperDistribute(name=name, placements=state.placements, device_mesh=state.device_mesh)
+     else:
+         return ModelStateMapperIdentity(name)
+
+
+ def _augment_mapper_for_injection(model: nn.Module, mapper: ModelStateMapper) -> ModelStateMapper:
+     states_to_load = {output for group in mapper.state_dependency_groups() for output in group.outputs}
+     current_state_dict = model.state_dict()
+     mapper = ModelStateMapperSequential([
+         mapper,
+         ModelStateMapperParallel([_build_injection_mapper(name, current_state_dict[name]) for name in states_to_load])
+     ])
+     return mapper
+
+
+ def load_model_state(
+     src_dir: Path,
+     mapper: ModelStateMapper,
+     device: str,
+     model: nn.Module,
+     show_progress: bool = True,
+ ):
+     """
+     High-level utility to stream a checkpoint directly into a PyTorch module.
+
+     This function orchestrates the full loading lifecycle:
+
+     1. Topology Mapping: Uses `mapper` to rename/stack/reshape on-disk states to model states.
+
+     2. Automatic Distribution: If the `model` contains `DTensor`s, the loaded local tensors are automatically
+        sharded/replicated to match the model's placement schema.
+
+     3. Streaming Read & Inject: After loading and transforming a model state, it will be injected into `model`
+        using `load_state_dict(...)`.
+
+     NOTICE: Only states specified in `mapper` will be loaded! You can use
+     `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will load every
+     model state without changing it.
+
+     Args:
+         src_dir: Directory containing .safetensors and index files.
+         mapper: The topology defining how mapping from disk keys to model keys works.
+         device: The device to load tensors onto (usually "cpu" or "cuda").
+         model: The model instance to load weights into.
+         show_progress: Whether to display the loading progress bar.
+     """
+
+     for state_name, state_value in read_model_state(
+         src_dir=src_dir,
+         mapper=_augment_mapper_for_injection(model, mapper),
+         device=device,
+         show_progress=show_progress
+     ):
+         model.load_state_dict({state_name: state_value}, strict=False)
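For context, a plausible call site for `load_model_state` could look like the sketch below. The import path follows the file listing above (`d9d/model_state/io/module_reader.py`), `identity_mapper_from_module` is the helper named in the docstring's NOTICE, and the checkpoint directory and toy model are placeholders:

```python
# Hypothetical usage sketch; the checkpoint directory and toy model are placeholders.
from pathlib import Path

from torch import nn

from d9d.model_state.io.module_reader import load_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module

model = nn.Linear(16, 16)  # stands in for a real (possibly DTensor-sharded) model

load_model_state(
    src_dir=Path("/checkpoints/my-model"),      # holds *.safetensors + model.safetensors.index.json
    mapper=identity_mapper_from_module(model),  # load every state under its existing name
    device="cpu",
    model=model,
)
```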
d9d/model_state/io/module_writer.py
@@ -0,0 +1,123 @@
+ from collections.abc import Iterable
+ from pathlib import Path
+
+ import torch
+ from torch import nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.tensor import DTensor
+
+ from d9d.model_state.mapper import ModelStateMapper
+ from d9d.model_state.mapper.compose import (
+     ModelStateMapperParallel,
+     ModelStateMapperSequential,
+ )
+ from d9d.model_state.mapper.leaf import (
+     ModelStateMapperGatherFullTensor,
+     ModelStateMapperIdentity,
+ )
+
+ from .writer import (
+     write_model_state_local,
+     write_model_state_pipeline_parallel,
+ )
+
+
+ def _build_extraction_mapper(name: str, state: torch.Tensor) -> ModelStateMapper:
+     if isinstance(state, DTensor):
+         return ModelStateMapperGatherFullTensor(name)
+     else:
+         return ModelStateMapperIdentity(name)
+
+
+ def _augment_mapper_for_extraction(models: list[nn.Module], mapper: ModelStateMapper) -> ModelStateMapper:
+     states_to_save = {input_state for group in mapper.state_dependency_groups() for input_state in group.inputs}
+
+     current_state_dict = {}
+     for model in models:
+         current_state_dict.update(model.state_dict())
+     mapper = ModelStateMapperSequential([
+         ModelStateMapperParallel([_build_extraction_mapper(name, current_state_dict[name]) for name in states_to_save]),
+         mapper
+     ])
+     return mapper
+
+
+ def _state_generator(models: list[nn.Module]) -> Iterable[tuple[str, torch.Tensor]]:
+     for model in models:
+         yield from model.state_dict().items()
+
+
+ def save_model_state(
+     dest_dir: Path,
+     mapper: ModelStateMapper,
+     model: nn.Module,
+     shard_size_gb: float = 4.0,
+     show_progress: bool = True
+ ):
+     """
+     High-level utility to save a PyTorch model to disk on a **single** process.
+
+     NOTICE: Only states specified in `mapper` will be saved! You can use
+     `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will save every
+     model state without changing it.
+
+     Args:
+         dest_dir: The directory to save .safetensors shards and index.
+         mapper: Topology defining how model keys map to disk keys.
+         model: The PyTorch module to save.
+         shard_size_gb: Max size per shard file in Gigabytes.
+         show_progress: Whether to display a progress bar.
+     """
+
+     write_model_state_local(
+         dest_dir=dest_dir,
+         mapper=_augment_mapper_for_extraction([model], mapper),
+         state_generator=_state_generator([model]),
+         shard_size_gb=shard_size_gb,
+         show_progress=show_progress
+     )
+
+
+ def save_model_state_pipeline_parallel(
+     dest_dir: Path,
+     mapper: ModelStateMapper,
+     device_mesh: DeviceMesh,
+     pipeline_dim_name: str,
+     models: list[nn.Module],
+     shard_size_gb: float = 4.0,
+     show_progress: bool = True
+ ):
+     """
+     High-level utility to save a model in a Distributed Pipeline Parallel environment to disk.
+
+     Features:
+
+     1. **Auto-Gather**: Converts `DTensor` parameters to full tensors before saving.
+
+     2. **Distribution Awareness**: Uses the `device_mesh` to ensure that for a given pipeline stage,
+        only the master rank writes the checkpoint, preventing Write-After-Write conflicts.
+
+     3. **Index Merging**: Aggregates metadata from all independent pipeline stages into one global index file.
+
+     NOTICE: Only states specified in `mapper` will be saved! You can use
+     `d9d.model_state.mapper.adapters.identity_mapper_from_module(module)` to create a mapper that will save every
+     model state without changing it.
+
+     Args:
+         dest_dir: directory to save .safetensors shards and index file.
+         mapper: Topology defining how model keys map to disk keys.
+         device_mesh: The cluster topology mesh.
+         pipeline_dim_name: The specific dimension name in the mesh used for pipelining.
+         models: A list of modules (pipeline stages) processed by this PP rank.
+         shard_size_gb: Max size per shard file in Gigabytes.
+         show_progress: Whether to display a progress bar.
+     """
+     write_model_state_pipeline_parallel(
+         dest_dir=dest_dir,
+         mapper=_augment_mapper_for_extraction(models, mapper),
+         state_generator=_state_generator(models),
+         device_mesh=device_mesh,
+         pipeline_dim_name=pipeline_dim_name,
+         shard_size_gb=shard_size_gb,
+         show_progress=show_progress
+     )
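A corresponding single-process save might be invoked as sketched below. The import path is again inferred from the file listing (`d9d/model_state/io/module_writer.py`), and the destination path and toy model are placeholders:

```python
# Hypothetical usage sketch; destination path and toy model are placeholders.
from pathlib import Path

from torch import nn

from d9d.model_state.io.module_writer import save_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module

model = nn.Linear(16, 16)

save_model_state(
    dest_dir=Path("/checkpoints/export"),
    mapper=identity_mapper_from_module(model),  # save every state under its existing name
    model=model,
    shard_size_gb=4.0,  # keep each .safetensors shard under ~4 GB
)
```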
d9d/model_state/io/reader.py
@@ -0,0 +1,125 @@
+ from collections import defaultdict
+ from collections.abc import Generator, Iterable
+ from pathlib import Path
+
+ import torch
+ from safetensors import safe_open
+ from tqdm import tqdm
+
+ from d9d.model_state.io.dto import MODEL_STATE_INDEX_FILE_NAME, ModelStateIndex
+ from d9d.model_state.mapper import ModelStateMapper
+
+
+ class _StateLoadingFlow:
+     """
+     Internal orchestration logic for loading and transforming model states in a streamed manner.
+     """
+
+     def __init__(
+         self,
+         src_dir: Path,
+         mapper: ModelStateMapper,
+         device: str,
+         show_progress: bool
+     ):
+         self._src_dir = src_dir
+         self._mapper = mapper
+         self._device = device
+
+         # I/O in constructor!
+         self._index = self._load_index()
+         self._groups_to_process = set(mapper.state_dependency_groups())
+
+         self._stored_states: dict[str, torch.Tensor] = {}
+
+         self._check_index()
+
+         self._pbar = tqdm(
+             desc="Loading Model States",
+             total=len([output_name for group in self._groups_to_process for output_name in group.outputs]),
+             disable=not show_progress
+         )
+
+     def _load_index(self) -> ModelStateIndex:
+         index_file = self._src_dir / MODEL_STATE_INDEX_FILE_NAME
+         index_data = index_file.read_text(encoding="utf-8")
+         index = ModelStateIndex.model_validate_json(index_data)
+         return index
+
+     def _check_index(self):
+         will_process_inputs: set[str] = set()
+         for group in self._groups_to_process:
+             will_process_inputs.update(group.inputs)
+
+         on_disk_inputs = set(self._index.weight_map.keys())
+
+         missing_inputs = will_process_inputs.difference(on_disk_inputs)
+
+         if len(missing_inputs) > 0:
+             raise ValueError(f"Cannot run state loading: states {missing_inputs} are missing!")
+
+     def _update_in_memory_states(self, file_to_load: str, params_to_load: set[str]):
+         with safe_open(str(self._src_dir / file_to_load), framework="pt", device=str(self._device)) as st:
+             for param_to_load in params_to_load:
+                 self._stored_states[param_to_load] = st.get_tensor(param_to_load)
+
+     def _process_available_groups(self) -> Generator[tuple[str, torch.Tensor], None, None]:
+         for group in self._groups_to_process.copy():
+             if not group.inputs.issubset(self._stored_states.keys()):
+                 continue
+
+             self._groups_to_process.remove(group)
+
+             loaded_states = self._mapper.apply(
+                 {k: v for k, v in self._stored_states.items() if k in group.inputs}
+             )
+             yield from loaded_states.items()
+             self._pbar.update(len(loaded_states))
+
+             for input_name in group.inputs:
+                 del self._stored_states[input_name]
+
+     def _build_file_loading_plan(self) -> dict[str, set[str]]:
+         plan = defaultdict(set)
+         for group in self._mapper.state_dependency_groups():
+             for key in group.inputs:
+                 require_file = self._index.weight_map[key]
+                 plan[require_file].add(key)
+         return plan
+
+     def load(self) -> Iterable[tuple[str, torch.Tensor]]:
+         with self._pbar:
+             for file_to_load, params_to_load in self._build_file_loading_plan().items():
+                 self._update_in_memory_states(file_to_load, params_to_load)
+                 yield from self._process_available_groups()
+
+
+ def read_model_state(
+     src_dir: Path,
+     mapper: ModelStateMapper,
+     device: str,
+     show_progress: bool = True
+ ) -> Iterable[tuple[str, torch.Tensor]]:
+     """
+     Reads a model checkpoint from disk, transforming it on-the-fly according to the state mapper.
+
+     This function uses a streaming approach. It analyzes the mapper to determine which files
+     need to be loaded. Tensors are loaded into memory only when needed and evicted immediately
+     after the mapper processes them.
+
+     Args:
+         src_dir: The directory containing .safetensors files and `model.safetensors.index.json` file.
+         mapper: The transformation graph defining how to map on-disk keys to output keys.
+         device: The device to load tensors onto (e.g., "cpu", "cuda:0").
+         show_progress: Whether to display a progress bar.
+
+     Yields:
+         A tuple containing the transformed parameter name and its tensor value.
+     """
+
+     yield from _StateLoadingFlow(
+         src_dir=src_dir,
+         device=device,
+         mapper=mapper,
+         show_progress=show_progress
+     ).load()
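Because `read_model_state` is a generator, a caller can scan a large checkpoint without materialising it all at once. The following sketch (placeholder checkpoint path, toy module, `identity_mapper_from_module` as named in the docstrings above) illustrates the streaming pattern:

```python
# Hypothetical usage sketch; the checkpoint directory and toy module are placeholders.
from pathlib import Path

from torch import nn

from d9d.model_state.io.reader import read_model_state
from d9d.model_state.mapper.adapters import identity_mapper_from_module

model = nn.Linear(16, 16)

total_params = 0
for name, tensor in read_model_state(
    src_dir=Path("/checkpoints/my-model"),
    mapper=identity_mapper_from_module(model),
    device="cpu",
    show_progress=False,
):
    # the flow drops its internal copies after each dependency group is processed,
    # so keep only what you actually need here
    total_params += tensor.numel()

print(f"streamed {total_params} parameters")
```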
d9d/model_state/io/writer.py
@@ -0,0 +1,309 @@
+ import warnings
+ from collections.abc import Iterable
+ from pathlib import Path
+ from typing import cast
+
+ import torch
+ from safetensors.torch import save_file
+ from torch.distributed import DeviceMesh, ProcessGroup
+ from tqdm import tqdm
+
+ from d9d.core.dist_ops import all_gather_object
+ from d9d.model_state.io.dto import (
+     MODEL_STATE_INDEX_FILE_NAME,
+     ModelStateIndex,
+     ModelStateIndexMeta,
+ )
+ from d9d.model_state.mapper import ModelStateMapper
+
+
+ class _StateWritingFlowLocal:
+     """
+     Internal orchestration logic for buffering, transforming, and sharding model states during save.
+     """
+
+     def __init__(
+         self,
+         dest_dir: Path,
+         mapper: ModelStateMapper,
+         shard_size_gb: float,
+         show_progress: bool,
+         sharding_rank: int,
+         # the writing flow has to be called from all processes, but only the master rank buffers and writes shards
+         is_current_process_rank_master: bool
+     ):
+         self._dest_dir = dest_dir
+         self._mapper = mapper
+         self._shard_size_bytes = int(shard_size_gb * (1024 ** 3))
+
+         self._groups_to_process = set(mapper.state_dependency_groups())
+
+         self._available_source_states: dict[str, torch.Tensor] = {}
+
+         self._total_size = 0
+         self._pending_write_tensors: dict[str, torch.Tensor] = {}
+         self._current_shard_size = 0
+
+         self._sharding_rank = sharding_rank
+         self._weight_name_to_local_shard_idx: dict[str, int] = {}
+         self._local_shard_idx_to_tmp_path: dict[int, Path] = {}
+
+         self._is_current_process_rank_master = is_current_process_rank_master
+         total_num_outputs = len([out_name for group in self._groups_to_process for out_name in group.outputs])
+         self._pbar = tqdm(
+             desc="Saving Model States",
+             total=total_num_outputs,
+             disable=not (show_progress and is_current_process_rank_master)
+         )
+
+     def _flush_shard(self):
+         if not self._pending_write_tensors:
+             return
+
+         local_shard_num = len(self._local_shard_idx_to_tmp_path) + 1
+         shard_tmp_path = self._dest_dir / f".tmp-rank{self._sharding_rank}-shard-{local_shard_num}.safetensors"
+
+         self._local_shard_idx_to_tmp_path[local_shard_num] = shard_tmp_path
+         save_file(self._pending_write_tensors, str(shard_tmp_path))
+
+         for state_name in self._pending_write_tensors:
+             self._weight_name_to_local_shard_idx[state_name] = local_shard_num
+
+         self._pbar.update(len(self._pending_write_tensors))
+
+         self._total_size += self._current_shard_size
+
+         self._pending_write_tensors.clear()
+         self._current_shard_size = 0
+
+     def _process_available_groups(self):
+         for group in self._groups_to_process.copy():
+             if not group.inputs.issubset(self._available_source_states.keys()):
+                 continue
+
+             self._groups_to_process.remove(group)
+
+             states_to_save = self._mapper.apply(
+                 {k: self._available_source_states[k] for k in group.inputs}
+             )
+
+             for input_name in group.inputs:
+                 del self._available_source_states[input_name]
+
+             # proceed with stateful saving only on master rank
+             if self._is_current_process_rank_master:
+                 for name, tensor in states_to_save.items():
+                     update_size = tensor.numel() * tensor.element_size()
+
+                     if update_size > self._shard_size_bytes:
+                         raise ValueError(f"Cannot save state {name} that is larger than shard size")
+
+                     if self._current_shard_size + update_size > self._shard_size_bytes:
+                         self._flush_shard()
+
+                     self._pending_write_tensors[name] = tensor
+                     self._current_shard_size += update_size
+
+     def _finalize_locally(self) -> ModelStateIndex:
+         self._flush_shard()
+
+         if self._groups_to_process:
+             missing_groups = {g.inputs for g in self._groups_to_process}
+             raise ValueError(
+                 f"Writing failed: not all source tensors were provided to satisfy mapper dependencies. "
+                 f"Missing inputs for groups: {missing_groups}"
+             )
+
+         if self._available_source_states:
+             warnings.warn(
+                 f"State Writing: The following source tensors were provided but not consumed by any "
+                 f"mapper group and will be ignored: {sorted(self._available_source_states.keys())}",
+                 stacklevel=2
+             )
+
+         weight_map_local = {
+             name: self._local_shard_idx_to_tmp_path[shard_idx].name
+             for name, shard_idx in self._weight_name_to_local_shard_idx.items()
+         }
+
+         return ModelStateIndex(
+             metadata=ModelStateIndexMeta(total_size=self._total_size),
+             weight_map=weight_map_local
+         )
+
+     def write(self, state_generator: Iterable[tuple[str, torch.Tensor]]) -> ModelStateIndex | None:
+         with self._pbar:
+             self._dest_dir.mkdir(parents=True, exist_ok=True)
+
+             for name, tensor in state_generator:
+                 self._available_source_states[name] = tensor
+                 self._process_available_groups()
+
+             if self._is_current_process_rank_master:
+                 return self._finalize_locally()
+             else:
+                 return None
+
+
+ def _finalize_master(dest_dir: Path, indices: list[ModelStateIndex]):
+     total_size = sum(index.metadata.total_size for index in indices)
+     total_weight_map_local = dict(pair for index in indices for pair in index.weight_map.items())
+     shard_count = len({file_name for index in indices for _, file_name in index.weight_map.items()})
+
+     total_weight_map = {}
+
+     local_file_to_global_file = {}
+     used_global_files = 0
+
+     for weight_name, old_file_name in total_weight_map_local.items():
+         if old_file_name not in local_file_to_global_file:
+             used_global_files += 1
+             new_file_name = f"model-{used_global_files:05d}-of-{shard_count:05d}.safetensors"
+
+             (dest_dir / old_file_name).rename(dest_dir / new_file_name)
+
+             local_file_to_global_file[old_file_name] = new_file_name
+
+         total_weight_map[weight_name] = local_file_to_global_file[old_file_name]
+
+     index_path = dest_dir / MODEL_STATE_INDEX_FILE_NAME
+     index_path.write_text(
+         ModelStateIndex(
+             metadata=ModelStateIndexMeta(total_size=total_size),
+             weight_map=total_weight_map
+         ).model_dump_json(indent=4),
+         encoding="utf-8"
+     )
+
+
+ def write_model_state_local(
+     dest_dir: Path,
+     mapper: ModelStateMapper,
+     state_generator: Iterable[tuple[str, torch.Tensor]],
+     shard_size_gb: float = 4.0,
+     show_progress: bool = True
+ ):
+     """
+     Saves model states to disk in a single local process.
+
+     This function uses a streaming approach. It analyzes the mapper to determine which files
+     need to be saved. Tensors are loaded into memory only when needed and evicted immediately
+     after the mapper processes them.
+
+     Args:
+         dest_dir: Destination directory.
+         mapper: Mapping to apply to states before saving.
+         state_generator: Stream of (name, tensor) pairs to save.
+         shard_size_gb: Maximum size of a single .safetensors file in GB.
+         show_progress: Whether to show the progress bar.
+     """
+     idx = _StateWritingFlowLocal(
+         dest_dir=dest_dir,
+         mapper=mapper,
+         shard_size_gb=shard_size_gb,
+         show_progress=show_progress,
+         sharding_rank=0,
+         is_current_process_rank_master=True
+     ).write(state_generator=state_generator)
+
+     idx = cast(ModelStateIndex, idx)  # we are sure is_current_process_rank_master=True
+
+     _finalize_master(dest_dir, [idx])
+
+
+ def write_model_state_distributed(
+     dest_dir: Path,
+     mapper: ModelStateMapper,
+     state_generator: Iterable[tuple[str, torch.Tensor]],
+     process_group: ProcessGroup,
+     shard_size_gb: float = 4.0,
+     show_progress: bool = True
+ ):
+     """
+     Saves model states in a distributed setup (multiple processes).
+
+     This function uses a streaming approach. It analyzes the mapper to determine which files
+     need to be saved. Tensors are loaded into memory only when needed and evicted immediately
+     after the mapper processes them.
+
+     Each rank writes its own shard. Rank 0 gathers indices and finalizes the checkpoint.
+
+     Args:
+         dest_dir: Destination directory.
+         mapper: Mapping to apply to states before saving.
+         state_generator: Stream of (name, tensor) pairs from the model.
+         process_group: The distributed process group.
+         shard_size_gb: Maximum shard size in GB.
+         show_progress: Whether to show the progress bar.
+     """
+
+     current_idx = _StateWritingFlowLocal(
+         dest_dir=dest_dir,
+         mapper=mapper,
+         shard_size_gb=shard_size_gb,
+         show_progress=show_progress,
+         sharding_rank=process_group.rank(),
+         is_current_process_rank_master=True
+     ).write(state_generator=state_generator)
+     gather_idx = all_gather_object(current_idx, process_group)
+     gather_idx_filter = [x for x in gather_idx if x is not None]
+     if process_group.rank() == 0:
+         _finalize_master(dest_dir, gather_idx_filter)
+
+
+ def write_model_state_pipeline_parallel(
+     dest_dir: Path,
+     mapper: ModelStateMapper,
+     state_generator: Iterable[tuple[str, torch.Tensor]],
+     device_mesh: DeviceMesh,
+     pipeline_dim_name: str,
+     shard_size_gb: float = 4.0,
+     show_progress: bool = True
+ ):
+     """
+     Saves model states in a complex ND distributed training setting.
+
+     This function uses a streaming approach. It analyzes the mapper to determine which files
+     need to be saved. Tensors are loaded into memory only when needed and evicted immediately
+     after the mapper processes them.
+
+     This handles Pipeline Parallelism by ensuring that only one rank per pipeline stage
+     actually writes data to disk to avoid duplication.
+
+     Args:
+         dest_dir: Destination directory.
+         mapper: Mapping to apply to states before saving.
+         state_generator: Stream of (name, tensor) pairs from the model.
+         device_mesh: The PyTorch DeviceMesh representing the cluster layout.
+         pipeline_dim_name: The name of the mesh dimension responsible for pipeline parallelism.
+         shard_size_gb: Maximum shard size in GB.
+         show_progress: Whether to show the progress bar.
+     """
+
+     pipeline_rank = device_mesh[pipeline_dim_name].get_rank()
+
+     mesh_dim_names = device_mesh.mesh_dim_names
+     coords = device_mesh.get_coordinate()
+     if mesh_dim_names is None or coords is None:
+         raise ValueError("Cannot save state using a DeviceMesh with no dim names or coords")
+
+     non_pipeline_coord_sum = sum(
+         coord
+         for name, coord
+         in zip(mesh_dim_names, coords, strict=True)
+         if name != pipeline_dim_name
+     )
+     master_within_pipeline_rank = non_pipeline_coord_sum == 0
+
+     current_idx = _StateWritingFlowLocal(
+         dest_dir=dest_dir,
+         mapper=mapper,
+         shard_size_gb=shard_size_gb,
+         show_progress=show_progress,
+         sharding_rank=pipeline_rank,
+         is_current_process_rank_master=master_within_pipeline_rank
+     ).write(state_generator=state_generator)
+     gather_idx = all_gather_object(current_idx, device_mesh.get_group(0))
+     gather_idx_filter = [x for x in gather_idx if x is not None]
+     if pipeline_rank == 0 and master_within_pipeline_rank:
+         _finalize_master(dest_dir, gather_idx_filter)
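The lower-level `write_model_state_local` accepts any `(name, tensor)` stream, so it can shard arbitrary tensors rather than only `nn.Module` state dicts. A minimal sketch follows, with made-up tensor names and destination path, using only the mapper classes imported elsewhere in this diff:

```python
# Hypothetical usage sketch; tensor names, shapes, and the destination path are invented.
from pathlib import Path

import torch

from d9d.model_state.io.writer import write_model_state_local
from d9d.model_state.mapper.compose import ModelStateMapperParallel
from d9d.model_state.mapper.leaf import ModelStateMapperIdentity

states = {
    "layer.weight": torch.zeros(8, 8),
    "layer.bias": torch.zeros(8),
}

write_model_state_local(
    dest_dir=Path("/checkpoints/raw-export"),
    # pass each named tensor through unchanged; only names listed here get written
    mapper=ModelStateMapperParallel([ModelStateMapperIdentity(name) for name in states]),
    state_generator=iter(states.items()),
    shard_size_gb=1.0,
    show_progress=False,
)
```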
d9d/model_state/mapper/__init__.py
@@ -0,0 +1,10 @@
+ """
+ This package provides core components of the state mapping system.
+ """
+
+ from .abc import ModelStateMapper, StateGroup
+
+ __all__ = [
+     "ModelStateMapper",
+     "StateGroup"
+ ]