d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/module/parallelism/model/qwen3_moe.py ADDED
@@ -0,0 +1,99 @@
+ from d9d.core.dist_context import DENSE_DOMAIN, EXPERT_DOMAIN, DistributedContext
+ from d9d.module.model.qwen3_moe import Qwen3MoEForCausalLM, Qwen3MoEModel
+ from d9d.module.parallelism.api import parallelize_expert_parallel, parallelize_hsdp
+ from d9d.pipelining.api import PipelineStageInfo
+
+
+ def parallelize_qwen3_moe_model(
+     dist_context: DistributedContext,
+     model: Qwen3MoEModel,
+     stage: PipelineStageInfo
+ ):
+     """
+     Parallelizes the base Qwen3 MoE model components.
+
+     This function configures the model layers for distributed execution within a pipeline
+     stage. It applies Hybrid Sharded Data Parallelism (HSDP) to dense components (embeddings,
+     norms, attention) and Expert Parallelism (EP) to the Mixture-of-Experts MLP layers.
+
+     Current usage constraints:
+     * Tensor Parallelism is not supported (we may implement it later).
+     * Context Parallelism is not supported (we will implement it later).
+
+     Args:
+         dist_context: The distributed context.
+         model: The Qwen3 MoE base model to parallelize.
+         stage: Information about the current pipeline stage.
+
+     Raises:
+         ValueError: If Tensor Parallel or Context Parallel is enabled in the context.
+     """
+
+     dims = dist_context.mesh_params
+     dense_mesh = dist_context.mesh_for(DENSE_DOMAIN)
+     expert_mesh = dist_context.mesh_for(EXPERT_DOMAIN)
+
+     if dims.has_tensor_parallel:
+         raise ValueError("Tensor Parallel is currently not supported for this model.")
+     if dims.has_context_parallel_replicate or dims.has_context_parallel_shard:
+         raise ValueError("Context Parallel is currently not supported for this model.")
+
+     if stage.is_current_stage_first:
+         parallelize_hsdp(
+             model.embed_tokens,
+             mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"]
+         )
+
+     if stage.is_current_stage_last:
+         parallelize_hsdp(
+             model.norm,
+             mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+         )
+
+     for layer in model.layers.values():
+         parallelize_expert_parallel(
+             layer.mlp,
+             mesh_experts=expert_mesh["ep_replicate", "ep_shard"]
+         )
+
+         parallelize_hsdp(
+             layer.self_attn,
+             mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+         )
+         parallelize_hsdp(
+             layer.input_layernorm,
+             mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+         )
+         parallelize_hsdp(
+             layer.post_attention_layernorm,
+             mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+         )
+
+
+ def parallelize_qwen3_moe_for_causal_lm(
+     dist_context: DistributedContext,
+     model: Qwen3MoEForCausalLM,
+     stage: PipelineStageInfo
+ ):
+     """
+     Parallelizes the Qwen3 MoE Causal LM model.
+
+     This function delegates backbone parallelization to ``parallelize_qwen3_moe_model``
+     and additionally configures the language model head with Hybrid Sharded Data
+     Parallelism (HSDP).
+
+     Args:
+         dist_context: The distributed context containing device meshes and topology info.
+         model: The Qwen3 MoE Causal LM model to parallelize.
+         stage: Information about the current pipeline stage.
+     """
+
+     dense_mesh = dist_context.mesh_for(DENSE_DOMAIN)
+
+     parallelize_qwen3_moe_model(dist_context, model.model, stage)
+
+     if stage.is_current_stage_last:
+         parallelize_hsdp(
+             model.lm_head,
+             mesh=dense_mesh["dp_replicate", "dp_cp_shard", "cp_replicate"],
+         )
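
Usage sketch (illustrative, not part of the package diff): wiring the functions above into a training setup, assuming the surrounding loop already provides a ``DistributedContext`` and a ``PipelineStageInfo``. The ``make_dist_context`` and ``make_stage_info`` helpers are placeholders; only ``parallelize_qwen3_moe_for_causal_lm`` comes from this file.

    from d9d.module.model.qwen3_moe import Qwen3MoEForCausalLM
    from d9d.module.parallelism.model.qwen3_moe import parallelize_qwen3_moe_for_causal_lm

    dist_context = make_dist_context()   # placeholder: obtain a DistributedContext
    stage = make_stage_info()            # placeholder: obtain a PipelineStageInfo
    model = Qwen3MoEForCausalLM(...)     # placeholder: construct the stage's model

    # HSDP for dense blocks and the LM head, Expert Parallelism for MoE MLPs.
    parallelize_qwen3_moe_for_causal_lm(dist_context, model, stage)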
d9d/module/parallelism/style/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .shard_experts import ShardMoESparseExpertsParallel
+ from .to_local import ToLocalParallel
+
+ __all__ = [
+     "ShardMoESparseExpertsParallel",
+     "ToLocalParallel"
+ ]
d9d/module/parallelism/style/shard_experts.py ADDED
@@ -0,0 +1,60 @@
+ from torch import nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.tensor import (
+     Replicate,
+     Shard,
+     distribute_module,
+     distribute_tensor,
+ )
+ from torch.distributed.tensor.parallel import ParallelStyle
+
+ from d9d.module.block.moe import GroupedLinear, MoELayer
+
+
+ class ShardMoESparseExpertsParallel(ParallelStyle):
+     """
+     Parallel style that shards MoE experts across a specific mesh dimension.
+
+     This style is designed for ``MoELayer`` instances using ``GroupedLinear`` for experts.
+     It shards the experts along the specified dimension of the device mesh
+     (Expert Parallelism); the parameters are replicated along all other mesh dimensions.
+
+     It also initializes the necessary distributed communication groups within the
+     MoE layer to handle token dispatching.
+     """
+
+     def __init__(self, shard_dim_name: str):
+         self._shard_dim_name = shard_dim_name
+
+     def _partition_experts(self, module_name: str, mod: nn.Module, device_mesh: DeviceMesh):
+         if not isinstance(mod, GroupedLinear):
+             raise TypeError("This plan should be applied only on GroupedLinear")
+
+         mesh_dim_names = device_mesh.mesh_dim_names
+
+         if mesh_dim_names is None:
+             raise ValueError("This plan should be applied only on named DeviceMeshes")
+
+         placements = [
+             Shard(0) if dim_name == self._shard_dim_name else Replicate()
+             for dim_name
+             in mesh_dim_names
+         ]
+         weight = nn.Parameter(
+             distribute_tensor(mod.weight, device_mesh, placements),
+             requires_grad=mod.weight.requires_grad
+         )
+         mod.weight = weight
+
+     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+         if not isinstance(module, MoELayer):
+             raise TypeError("This plan should be applied only on MoELayer")
+
+         module.enable_distributed_communicator(device_mesh.get_group(self._shard_dim_name))
+
+         for submod in module.modules():
+             if isinstance(submod, GroupedLinear):
+                 distribute_module(submod, device_mesh, self._partition_experts)
+
+         return module
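
Usage sketch (illustrative, not part of the package diff): because this is a standard ``ParallelStyle``, it can be applied with ``torch.distributed.tensor.parallel.parallelize_module``. The mesh shape and dimension names below are illustrative, and ``moe_layer`` stands in for a ``MoELayer`` whose experts use ``GroupedLinear``; the package's own ``parallelize_expert_parallel`` helper is presumably the higher-level entry point.

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module

    from d9d.module.parallelism.style import ShardMoESparseExpertsParallel

    # Illustrative 2D expert mesh over 8 ranks: replication x expert sharding.
    expert_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("ep_replicate", "ep_shard"))

    moe_layer = ...  # a MoELayer instance built elsewhere

    # Shard expert weights along dim 0 on "ep_shard"; replicate on "ep_replicate".
    parallelize_module(moe_layer, expert_mesh, ShardMoESparseExpertsParallel("ep_shard"))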
d9d/module/parallelism/style/to_local.py ADDED
@@ -0,0 +1,86 @@
+ from typing import Any
+
+ from torch import nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.tensor import Placement, distribute_module, distribute_tensor
+ from torch.distributed.tensor.parallel import ParallelStyle
+
+
+ def _build_to_local_patched_class(
+     module: nn.Module,
+     grad_placement: tuple[Placement, ...],
+     param_names: list[str]
+ ) -> type:
+     param_name_to_property = {
+         param_name: property(
+             lambda self, pn=param_name: self._parameters[pn].to_local(grad_placements=grad_placement)  # type: ignore
+         )
+         for param_name in param_names
+     }
+     return type(
+         f"Replicate{module.__class__.__name__}",
+         (module.__class__,),
+         param_name_to_property,
+     )
+
+
+ class _ModulePatch:
+     def __init__(self, class_mapper: dict[str, type]):
+         self._class_mapper = class_mapper
+
+     def __call__(self, mod: nn.Module, *args: Any, **kwargs: Any):
+         for submod_name, submod in mod.named_modules():
+             submod.__class__ = self._class_mapper[submod_name]
+
+
+ class ToLocalParallel(ParallelStyle):
+     """
+     Parallel style that distributes parameters and gradients but executes with local tensors.
+
+     This style wraps standard tensor distribution (via ``DTensor``) but injects
+     runtime hooks that temporarily unwrap ``DTensor`` parameters into local ``torch.Tensor``
+     objects during the forward pass.
+
+     This is useful for parallel strategies (such as Replicate) where the underlying
+     computation logic is not DTensor-aware, but the parameters must remain distributed
+     for gradient synchronization and distributed checkpointing.
+     """
+
+     def __init__(self, param_placement: tuple[Placement, ...], grad_placement: tuple[Placement, ...]):
+         """
+         Constructs a ToLocalParallel object.
+
+         Args:
+             param_placement: Tuple of placements defining how parameters are distributed.
+             grad_placement: Tuple of placements defining how gradients are synchronized.
+         """
+
+         self._grad_placement = grad_placement
+         self._param_placement = param_placement
+
+     def _distribute_params(self, name: str, module: nn.Module, device_mesh: DeviceMesh):
+         for param_name, param in module.named_parameters(recurse=False):
+             new_param = nn.Parameter(
+                 distribute_tensor(param.data, device_mesh, self._param_placement),
+                 requires_grad=param.requires_grad
+             )
+
+             module.register_parameter(param_name, new_param)
+
+     def _apply(self, master_module: nn.Module, device_mesh: DeviceMesh):
+         patched_classes = {}
+         original_classes = {}
+
+         for submod_name, submod in master_module.named_modules():
+             param_names = [name for name, p in submod.named_parameters(recurse=False)]
+             patched_classes[submod_name] = _build_to_local_patched_class(submod, self._grad_placement, param_names)
+             original_classes[submod_name] = submod.__class__
+
+             distribute_module(
+                 submod,
+                 device_mesh,
+                 self._distribute_params
+             )
+
+         master_module.register_forward_pre_hook(_ModulePatch(patched_classes))
+         master_module.register_forward_hook(_ModulePatch(original_classes))
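
Usage sketch (illustrative, not part of the package diff): applying this style to a small module via ``parallelize_module``. The mesh and the placements below are illustrative choices, not values prescribed by this file.

    from torch import nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Replicate
    from torch.distributed.tensor.parallel import parallelize_module

    from d9d.module.parallelism.style import ToLocalParallel

    mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_replicate",))

    module = nn.Linear(16, 16)
    # Parameters become replicated DTensors, but the forward pass sees plain
    # local tensors thanks to the class-patching pre/post forward hooks.
    parallelize_module(
        module,
        mesh,
        ToLocalParallel(param_placement=(Replicate(),), grad_placement=(Replicate(),)),
    )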
d9d/optim/__init__.py ADDED
File without changes
d9d/optim/stochastic/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .adamw import StochasticAdamW
+
+ __all__ = [
+     "StochasticAdamW"
+ ]
d9d/optim/stochastic/adamw.py ADDED
@@ -0,0 +1,158 @@
+ from typing import cast
+
+ import torch
+ from torch.distributed.tensor import DTensor
+ from torch.optim import Optimizer
+ from torch.optim.optimizer import ParamsT, StateDict
+
+ from d9d.kernel.stochastic import adamw_stochastic_bf16_
+
+ _GENERATOR_STATE_KEY = "_d9d_generator_state"
+
+
+ def _new_buffer(p: torch.Tensor, dtype_override: torch.dtype) -> torch.Tensor:
+     if isinstance(p, DTensor):
+         local_p = p.to_local()
+     else:
+         local_p = p
+
+     out = torch.zeros_like(local_p, dtype=dtype_override).contiguous()
+
+     if isinstance(p, DTensor):
+         out = DTensor.from_local(
+             local_tensor=out,
+             device_mesh=p.device_mesh,
+             placements=p.placements,
+             run_check=False,
+             shape=p.shape,
+             stride=p.stride(),
+         )
+
+     return out
+
+
+ def _tensor_to_local(tensor: torch.Tensor) -> torch.Tensor:
+     if isinstance(tensor, DTensor):
+         return tensor.to_local()
+     return tensor
+
+
+ class StochasticAdamW(Optimizer):
+     """Implements the AdamW algorithm with Stochastic Rounding.
+
+     This optimizer is designed to handle stochastic rounding primarily for BF16 training,
+     leveraging a custom kernel.
+
+     Parameters must be in BF16. Gradients may be in either BF16 or FP32.
+
+     It natively supports PyTorch distributed ``DTensor`` parameters.
+
+     It maintains its own random number generator state to ensure reproducibility.
+     """
+
+     def __init__(
+         self,
+         params: ParamsT,
+         lr: float,
+         betas: tuple[float, float] = (0.9, 0.999),
+         eps: float = 1e-8,
+         weight_decay: float = 1e-2,
+         generator: torch.Generator | None = None,
+         state_dtype: torch.dtype = torch.float32,
+     ):
+         """Constructs a new StochasticAdamW optimizer.
+
+         Args:
+             params: Iterable of parameters to optimize or dicts defining parameter groups.
+             lr: Learning rate.
+             betas: Coefficients used for computing running averages of the gradient and its square.
+             eps: Term added to the denominator to improve numerical stability.
+             weight_decay: Weight decay coefficient.
+             generator: Pseudorandom number generator for stochastic rounding. If None,
+                 a new generator is created and seeded from the main PyTorch generator.
+             state_dtype: Data type to use for the optimizer states.
+         """
+
+         if lr <= 0:
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if eps <= 0:
+             raise ValueError(f"Invalid epsilon value: {eps}")
+         if not 0.0 <= betas[0] < 1.0:
+             raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+         if not 0.0 <= betas[1] < 1.0:
+             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+         if weight_decay <= 0:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         if generator is None:
+             generator = torch.Generator(device="cpu")
+             # make the generator fork from pytorch's main generator
+             seed = cast(int, torch.randint(0, 2**32, (1,)).item())
+             generator.manual_seed(seed)
+
+         self._generator = generator
+
+         defaults = {
+             "lr": lr,
+             "betas": betas,
+             "eps": eps,
+             "weight_decay": weight_decay,
+             "state_dtype": state_dtype
+         }
+         super().__init__(params, defaults)
+
+     def state_dict(self) -> StateDict:
+         state_dict = super().state_dict()
+         state_dict[_GENERATOR_STATE_KEY] = self._generator.get_state()
+         return state_dict
+
+     def load_state_dict(self, state_dict: StateDict) -> None:
+         if _GENERATOR_STATE_KEY in state_dict:
+             self._generator.set_state(state_dict.pop(_GENERATOR_STATE_KEY))
+         super().load_state_dict(state_dict)
+
+     @torch.no_grad()
+     def step(self, closure: None = None) -> None:  # type: ignore[override]
+         if closure is not None:
+             raise ValueError("Closure is not supported")
+
+         for group in self.param_groups:
+             lr = group["lr"]
+             beta1, beta2 = group["betas"]
+             eps = group["eps"]
+             weight_decay = group["weight_decay"]
+             state_dtype = group["state_dtype"]
+
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad
+                 if grad.is_sparse:
+                     raise RuntimeError("StochasticAdamW does not support sparse gradients")
+
+                 state = self.state[p]
+
+                 # State Initialization
+                 if len(state) == 0:
+                     state["step"] = 0
+                     state["exp_avg"] = _new_buffer(p, dtype_override=state_dtype)
+                     state["exp_avg_sq"] = _new_buffer(p, dtype_override=state_dtype)
+
+                 state["step"] += 1
+                 exp_avg = state["exp_avg"]
+                 exp_avg_sq = state["exp_avg_sq"]
+
+                 adamw_stochastic_bf16_(
+                     params=_tensor_to_local(p),
+                     grads=_tensor_to_local(grad),
+                     exp_avg=_tensor_to_local(exp_avg),
+                     exp_avg_sq=_tensor_to_local(exp_avg_sq),
+                     lr=lr,
+                     beta1=beta1,
+                     beta2=beta2,
+                     eps=eps,
+                     weight_decay=weight_decay,
+                     step=state["step"],
+                     generator=self._generator
+                 )
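
Usage sketch (illustrative, not part of the package diff): a minimal sketch assuming a CUDA device and the bundled stochastic-rounding kernel are available. Parameters are created in BF16, as the class docstring requires.

    import torch
    from torch import nn

    from d9d.optim.stochastic import StochasticAdamW

    model = nn.Linear(128, 128, device="cuda", dtype=torch.bfloat16)
    optimizer = StochasticAdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

    x = torch.randn(4, 128, device="cuda", dtype=torch.bfloat16)
    loss = model(x).sum()
    loss.backward()

    optimizer.step()        # BF16 update with stochastic rounding
    optimizer.zero_grad()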
d9d/peft/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """
+ Provides core logic for PEFT (Parameter-Efficient Fine-Tuning) application and base definitions.
+ """
+
+ from .applicator import inject_peft_and_freeze, merge_peft
+ from .base import PeftInjectionResult, PeftMethod
+
+ __all__ = [
+     "PeftInjectionResult",
+     "PeftMethod",
+     "inject_peft_and_freeze",
+     "merge_peft"
+ ]
d9d/peft/all/__init__.py ADDED
@@ -0,0 +1,12 @@
+ """
+ Package for composing multiple PEFT methods into a stack.
+ """
+
+ from .config import PeftStackConfig
+ from .method import PeftStack, peft_method_from_config
+
+ __all__ = [
+     "PeftStack",
+     "PeftStackConfig",
+     "peft_method_from_config"
+ ]
d9d/peft/all/config.py ADDED
@@ -0,0 +1,31 @@
+ from typing import Annotated, Literal
+
+ from pydantic import BaseModel, Field
+
+ from d9d.peft.full_tune.config import FullTuneConfig
+ from d9d.peft.lora.config import LoRAConfig
+
+
+ class PeftStackConfig(BaseModel):
+     """
+     Configuration for applying a stack of multiple PEFT methods sequentially.
+
+     Attributes:
+         kind: Discriminator field, always "stack".
+         methods: A list of specific PEFT configurations (e.g., LoRA, FullTune) to apply in order.
+     """
+
+     kind: Literal["stack"] = "stack"
+
+     methods: list["AnyPeftConfig"]
+
+
+ AnyPeftConfig = Annotated[
+     LoRAConfig
+     | FullTuneConfig
+     | PeftStackConfig,
+     Field(discriminator="kind"),
+ ]
+ """
+ Union type representing any valid PEFT configuration, discriminated by the 'kind' field.
+ """
d9d/peft/all/method.py ADDED
@@ -0,0 +1,76 @@
+ from typing import Self, cast
+
+ from pydantic import BaseModel
+ from torch import nn
+
+ from ..all.config import PeftStackConfig
+ from ..base import PeftInjectionResult, PeftMethod, TConfig
+ from ..full_tune.config import FullTuneConfig
+ from ..full_tune.method import FullTune
+ from ..lora.config import LoRAConfig
+ from ..lora.method import LoRA
+
+
+ class PeftStack(PeftMethod[PeftStackConfig]):
+     """
+     A composite PEFT method that applies a list of methods sequentially.
+     """
+
+     def __init__(self, methods: list[PeftMethod]):
+         """
+         Constructs a PeftStack object.
+
+         Args:
+             methods: A list of instantiated PEFT methods to apply in order.
+         """
+
+         self._methods = methods
+
+     def inject(self, module: nn.Module) -> PeftInjectionResult:
+         params_to_train = []
+         state_mappers = []
+
+         for method in self._methods:
+             result = method.inject(module)
+             params_to_train.extend(result.parameters_to_train)
+             state_mappers.extend(result.load_state_mappers)
+
+         return PeftInjectionResult(
+             parameters_to_train=params_to_train,
+             load_state_mappers=state_mappers
+         )
+
+     def merge(self, module: nn.Module):
+         for method in self._methods[::-1]:
+             method.merge(module)
+
+     @classmethod
+     def from_config(cls, config: PeftStackConfig) -> Self:
+         methods = []
+
+         for method in config.methods:
+             methods.append(peft_method_from_config(method))
+
+         return cls(methods)
+
+
+ _PEFT_CONFIG_MAP: dict[type[BaseModel], type[PeftMethod]] = {
+     LoRAConfig: LoRA,
+     FullTuneConfig: FullTune,
+     PeftStackConfig: PeftStack
+ }
+
+
+ def peft_method_from_config(config: TConfig) -> PeftMethod[TConfig]:
+     """
+     Factory function to instantiate the correct PeftMethod based on the configuration type.
+
+     Args:
+         config: A specific PEFT configuration object (e.g., LoRAConfig).
+
+     Returns:
+         The corresponding method instance.
+     """
+
+     method_cls = cast(type[PeftMethod[TConfig]], _PEFT_CONFIG_MAP[type(config)])
+     return method_cls.from_config(config)
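
Usage sketch (illustrative, not part of the package diff): the factory dispatches on the config's concrete type. An empty stack is the smallest config that exercises it end to end.

    from d9d.peft.all.config import PeftStackConfig
    from d9d.peft.all.method import PeftStack, peft_method_from_config

    config = PeftStackConfig(methods=[])
    method = peft_method_from_config(config)
    assert isinstance(method, PeftStack)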
d9d/peft/applicator.py ADDED
@@ -0,0 +1,47 @@
+ from torch import nn
+
+ from d9d.model_state.mapper import ModelStateMapper
+ from d9d.model_state.mapper.compose import ModelStateMapperParallel
+
+ from .base import PeftMethod
+
+
+ def inject_peft_and_freeze(method: PeftMethod, module: nn.Module) -> ModelStateMapper:
+     """
+     Applies a PEFT method to a module, freezes non-trained parameters, and prepares state mapping.
+
+     This function performs three main steps:
+
+     1. Sets `requires_grad=False` for all parameters in the module.
+     2. Calls the method's `inject` to modify the model structure.
+     3. Sets `requires_grad=True` for the parameters returned by the injection result.
+
+     Args:
+         method: The PEFT method strategy to apply.
+         module: The PyTorch module to modify.
+
+     Returns:
+         A ModelStateMapper capable of loading checkpoint weights into the modified structure.
+     """
+
+     for param in module.parameters():
+         param.requires_grad = False
+
+     result = method.inject(module)
+
+     for param in result.parameters_to_train:
+         param.requires_grad = True
+
+     return ModelStateMapperParallel(result.load_state_mappers)
+
+
+ def merge_peft(method: PeftMethod, module: nn.Module):
+     """
+     Merges PEFT adaptations back into the base model weights.
+
+     Args:
+         method: The PEFT method strategy originally applied.
+         module: The PyTorch module whose adapters should be merged.
+     """
+
+     method.merge(module)
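
Usage sketch (illustrative, not part of the package diff): the full apply/merge cycle. The empty stack here trains nothing; a real config would carry LoRA or full-tune entries.

    from torch import nn

    from d9d.peft.all.config import PeftStackConfig
    from d9d.peft.all.method import peft_method_from_config
    from d9d.peft.applicator import inject_peft_and_freeze, merge_peft

    model = nn.Linear(16, 16)
    method = peft_method_from_config(PeftStackConfig(methods=[]))

    # Freeze everything, inject adapters, and get a mapper for loading base weights.
    state_mapper = inject_peft_and_freeze(method, model)

    # ... fine-tune the trainable parameters ...

    # Fold the adapters back into the base weights once training is done.
    merge_peft(method, model)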
d9d/peft/base.py ADDED
@@ -0,0 +1,70 @@
+ import abc
+ import dataclasses
+ from typing import Generic, Self, TypeVar
+
+ from pydantic import BaseModel
+ from torch import nn
+
+ from d9d.model_state.mapper import ModelStateMapper
+
+
+ @dataclasses.dataclass(slots=True)
+ class PeftInjectionResult:
+     """
+     Encapsulates the result of injecting a PEFT method into a model.
+
+     Attributes:
+         parameters_to_train: A list of parameters that should remain trainable.
+         load_state_mappers: A list of mappers required to load pre-trained weights into the modified structure.
+     """
+
+     parameters_to_train: list[nn.Parameter]
+     load_state_mappers: list[ModelStateMapper]
+
+
+ TConfig = TypeVar("TConfig", bound=BaseModel)
+
+
+ class PeftMethod(abc.ABC, Generic[TConfig]):
+     """
+     Abstract base class for all Parameter-Efficient Fine-Tuning methods.
+     """
+
+     @abc.abstractmethod
+     def inject(self, module: nn.Module) -> PeftInjectionResult:
+         """
+         Modifies the module in-place to apply the PEFT strategy.
+
+         Args:
+             module: The PyTorch module to modify.
+
+         Returns:
+             Result object containing trainable parameters and structure mappers.
+         """
+         ...
+
+     @abc.abstractmethod
+     def merge(self, module: nn.Module):
+         """
+         Merges the trained adapters back into the base model parameters.
+
+         Args:
+             module: The PyTorch module to update.
+         """
+
+         ...
+
+     @classmethod
+     @abc.abstractmethod
+     def from_config(cls, config: TConfig) -> Self:
+         """
+         Creates an instance of the method from a configuration object.
+
+         Args:
+             config: The configuration object.
+
+         Returns:
+             An instance of the PeftMethod.
+         """
+
+         ...
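
Usage sketch (illustrative, not part of the package diff): a minimal custom method implementing this interface. The ``BiasOnly`` name and its config are illustrative and not part of the package.

    from typing import Literal, Self

    from pydantic import BaseModel
    from torch import nn

    from d9d.peft.base import PeftInjectionResult, PeftMethod


    class BiasOnlyConfig(BaseModel):
        kind: Literal["bias_only"] = "bias_only"


    class BiasOnly(PeftMethod[BiasOnlyConfig]):
        """Illustrative method that trains only bias parameters."""

        def inject(self, module: nn.Module) -> PeftInjectionResult:
            biases = [
                param
                for name, param in module.named_parameters()
                if name.split(".")[-1] == "bias"
            ]
            return PeftInjectionResult(parameters_to_train=biases, load_state_mappers=[])

        def merge(self, module: nn.Module):
            pass  # biases already live in the base model; nothing to merge back

        @classmethod
        def from_config(cls, config: BiasOnlyConfig) -> Self:
            return cls()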
d9d/peft/full_tune/__init__.py ADDED
@@ -0,0 +1,11 @@
+ """
+ Package for Full Fine-Tuning functionality within the PEFT framework.
+ """
+
+ from .config import FullTuneConfig
+ from .method import FullTune
+
+ __all__ = [
+     "FullTune",
+     "FullTuneConfig"
+ ]