d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/peft/full_tune/config.py ADDED
@@ -0,0 +1,20 @@
+ from re import Pattern
+ from typing import Literal
+
+ from pydantic import BaseModel
+
+
+ class FullTuneConfig(BaseModel):
+     """
+     Configuration for Full Fine-Tuning.
+
+     Allows specifying which modules should be fully fine-tuned using regex patterns.
+
+     Attributes:
+         kind: Discriminator field, always "full_tune".
+         module_name_pattern: Regular expression matching module names to unfreeze.
+     """
+
+     kind: Literal["full_tune"] = "full_tune"
+
+     module_name_pattern: Pattern
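
Since the field is typed as re.Pattern, pydantic compiles a plain string into a pattern during validation. A minimal usage sketch, assuming the import path mirrors the package layout above; the regex itself is illustrative:

from d9d.peft.full_tune import FullTuneConfig  # assumed export, per the package layout

# pydantic compiles the string into an re.Pattern during validation
cfg = FullTuneConfig(module_name_pattern=r".*\.mlp")
assert cfg.module_name_pattern.fullmatch("model.layers.0.mlp")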
d9d/peft/full_tune/method.py ADDED
@@ -0,0 +1,46 @@
+ from typing import Self
+
+ from torch import nn
+
+ from ..base import PeftInjectionResult, PeftMethod
+ from .config import FullTuneConfig
+
+
+ class FullTune(PeftMethod[FullTuneConfig]):
+     """
+     Implements Full Fine-Tuning as a 'PEFT' method.
+
+     Instead of injecting adapters, this method simply identifies existing parameters
+     that match the configuration pattern and marks them for training.
+     """
+
+     def __init__(self, config: FullTuneConfig):
+         """
+         Constructs a FullTune object.
+
+         Args:
+             config: Configuration defining the module name patterns to fine-tune.
+         """
+
+         self._config = config
+
+     def inject(self, module: nn.Module) -> PeftInjectionResult:
+         params_to_train = []
+
+         for mod_name, mod in module.named_modules():
+             is_applicable = self._config.module_name_pattern.fullmatch(mod_name)
+
+             if is_applicable:
+                 params_to_train.extend(mod.parameters())
+
+         return PeftInjectionResult(
+             parameters_to_train=params_to_train,
+             load_state_mappers=[]
+         )
+
+     def merge(self, module: nn.Module):
+         pass  # do nothing here
+
+     @classmethod
+     def from_config(cls, config: FullTuneConfig) -> Self:
+         return cls(config)
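
A short sketch of the inject flow on a toy module (the model and pattern are invented for illustration, and the imports assume the exports above):

import re
from torch import nn

from d9d.peft.full_tune import FullTune, FullTuneConfig  # assumed exports

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
cfg = FullTuneConfig(module_name_pattern=re.compile(r"0"))  # match only the first child
result = FullTune.from_config(cfg).inject(model)

# named_modules() yields "", "0", "1"; only "0" fullmatches, so its
# weight and bias are the two parameters marked for training
assert len(result.parameters_to_train) == 2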
d9d/peft/lora/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """
+ Package for Low-Rank Adaptation (LoRA) implementation.
+ """
+
+ from .config import LoRAConfig, LoRAParameters
+ from .layer import LoRAGroupedLinear, LoRALinear
+ from .method import LoRA
+
+ __all__ = [
+     "LoRA",
+     "LoRAConfig",
+     "LoRAGroupedLinear",
+     "LoRALinear",
+     "LoRAParameters"
+ ]
d9d/peft/lora/config.py ADDED
@@ -0,0 +1,35 @@
+ from re import Pattern
+ from typing import Literal
+
+ from pydantic import BaseModel
+
+
+ class LoRAParameters(BaseModel):
+     """
+     Hyperparameters for LoRA layers.
+
+     Attributes:
+         r: Rank of the low-rank adaptation matrices.
+         alpha: Scaling factor for the learned weights.
+         dropout: Dropout probability for the input to LoRA layers.
+     """
+
+     r: int
+     alpha: int
+     dropout: float
+
+
+ class LoRAConfig(BaseModel):
+     """
+     Configuration for LoRA application.
+
+     Attributes:
+         kind: Discriminator field, always "lora".
+         module_name_pattern: Regular expression matching module names to wrap with LoRA.
+         params: Hyperparameters for the LoRA layers.
+     """
+
+     kind: Literal["lora"] = "lora"
+
+     module_name_pattern: Pattern
+     params: LoRAParameters
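
For orientation, a config sketch with illustrative values; the layers defined below scale the adapter output by alpha / r:

from d9d.peft.lora import LoRAConfig, LoRAParameters  # assumed exports

cfg = LoRAConfig(
    module_name_pattern=r".*(q_proj|v_proj)",  # hypothetical module names
    params=LoRAParameters(r=16, alpha=32, dropout=0.05),
)
# adapter scale used in the layers: alpha / r = 2.0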
d9d/peft/lora/layer.py ADDED
@@ -0,0 +1,177 @@
+ import torch
+ from torch import nn
+
+ from d9d.module.block.moe import GroupedLinear
+
+ from .config import LoRAParameters
+
+
+ class LoRALinear(nn.Module):
+     """
+     A LoRA wrapper around a standard PyTorch Linear layer.
+
+     Wraps a base linear layer and adds low-rank adaptation matrices A and B.
+
+     Attributes:
+         lora_A: The A matrix (in_features -> r).
+         lora_B: The B matrix (r -> out_features).
+         base: The original base Linear layer.
+         dropout: Dropout layer applied to the LoRA input.
+     """
+
+     def __init__(
+         self,
+         base_layer: nn.Linear,
+         params: LoRAParameters
+     ):
+         """
+         Constructs a LoRALinear layer.
+
+         Args:
+             base_layer: The original Linear layer to wrap.
+             params: LoRA hyperparameters (r, alpha, dropout).
+
+         Raises:
+             ValueError: If the base layer has a bias (currently unsupported).
+         """
+
+         super().__init__()
+         self.lora_A = nn.Linear(
+             base_layer.in_features, params.r, bias=False,
+             device=base_layer.weight.device,
+             dtype=base_layer.weight.dtype
+         )
+         self.lora_B = nn.Linear(
+             params.r, base_layer.out_features, bias=False,
+             device=base_layer.weight.device,
+             dtype=base_layer.weight.dtype
+         )
+         self.base = base_layer
+
+         if base_layer.bias is not None:
+             raise ValueError("LoRA is unsupported with biased linear layers")
+
+         self.dropout: nn.Dropout = nn.Dropout(params.dropout)
+
+         self._scale: float = params.alpha / params.r
+
+         self.reset_parameters()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Takes an input tensor, computes the base output and the LoRA adaptation, and returns their sum.
+
+         Args:
+             x: Input tensor.
+
+         Returns:
+             The output of base(x) + scale * (B @ A @ dropout(x)).
+         """
+
+         base_x = self.base(x)
+         adapt_x = self._scale * self.lora_B(self.lora_A(self.dropout(x)))
+         return base_x + adapt_x
+
+     @torch.no_grad()
+     def merge_with_base_(self) -> nn.Linear:
+         """
+         Collapse the LoRA weights into the base linear layer.
+
+         Returns:
+             The modified base linear layer with updated weights.
+         """
+
+         mod = self.base
+         mod.weight.data += (self.lora_B.weight.data @ self.lora_A.weight.data) * self._scale
+         return mod
+
+     def reset_parameters(self):
+         """
+         Resets LoRA parameters. A is random, B is zeroed.
+         """
+
+         self.lora_A.reset_parameters()
+         nn.init.zeros_(self.lora_B.weight)
+
+
+ class LoRAGroupedLinear(nn.Module):
+     """
+     A LoRA wrapper around a GroupedLinear layer (commonly used in MoE or grouped query attention).
+
+     Attributes:
+         lora_A: The A matrix (grouped linear).
+         lora_B: The B matrix (grouped linear).
+         base: The original base GroupedLinear layer.
+         dropout: Dropout layer applied to the LoRA input.
+     """
+
+     def __init__(
+         self,
+         base_layer: GroupedLinear,
+         params: LoRAParameters
+     ):
+         """
+         Constructs a LoRAGroupedLinear layer.
+
+         Args:
+             base_layer: The original GroupedLinear layer to wrap.
+             params: LoRA hyperparameters.
+         """
+
+         super().__init__()
+         self.lora_A = GroupedLinear(
+             base_layer.n_groups, base_layer.in_features, params.r,
+             device=base_layer.weight.device,
+             dtype=base_layer.weight.dtype
+         )
+         self.lora_B = GroupedLinear(
+             base_layer.n_groups,
+             params.r,
+             base_layer.out_features,
+             device=base_layer.weight.device,
+             dtype=base_layer.weight.dtype
+         )
+         self.base = base_layer
+
+         self.dropout = nn.Dropout(params.dropout)
+
+         self._scale = params.alpha / params.r
+
+         self.reset_parameters()
+
+     def forward(self, x: torch.Tensor, x_groups: torch.Tensor) -> torch.Tensor:
+         """
+         Computes the forward pass for grouped inputs.
+
+         Args:
+             x: Input tensor.
+             x_groups: A tensor indicating group indices for each input.
+
+         Returns:
+             Combined output of the base and LoRA paths.
+         """
+
+         base_x = self.base(x, x_groups)
+         adapt_x = self._scale * self.lora_B(self.lora_A(self.dropout(x), x_groups), x_groups)
+         return base_x + adapt_x
+
+     @torch.no_grad()
+     def merge_with_base_(self) -> GroupedLinear:
+         """
+         Collapse the LoRA weights into the base GroupedLinear layer.
+
+         Returns:
+             The modified GroupedLinear layer.
+         """
+
+         mod = self.base
+         mod.weight.data += torch.bmm(self.lora_A.weight.data, self.lora_B.weight.data) * self._scale
+         return mod
+
+     def reset_parameters(self):
+         """
+         Resets LoRA parameters. A is random, B is zeroed.
+         """
+
+         self.lora_A.reset_parameters()
+         nn.init.zeros_(self.lora_B.weight)
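
Two properties follow directly from reset_parameters and merge_with_base_: a freshly wrapped layer is numerically identical to its base (B starts at zero), and merging preserves the adapted function. A quick sketch for the plain Linear case, assuming the exports above; shapes and the stand-in "training" are illustrative:

import torch
from torch import nn

from d9d.peft.lora import LoRALinear, LoRAParameters  # assumed exports

base = nn.Linear(32, 16, bias=False)                   # bias must be absent, see __init__
lora = LoRALinear(base, LoRAParameters(r=4, alpha=8, dropout=0.0))

x = torch.randn(2, 32)
assert torch.allclose(lora(x), base(x))                # B is zeroed, adapter path is a no-op

nn.init.normal_(lora.lora_A.weight)                    # stand-in for training updates
nn.init.normal_(lora.lora_B.weight)
y_adapted = lora(x)
merged = lora.merge_with_base_()                       # folds scale * (B @ A) into base.weight
assert torch.allclose(merged(x), y_adapted, atol=1e-5)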
d9d/peft/lora/method.py ADDED
@@ -0,0 +1,132 @@
+ from typing import Self
+
+ import torch
+ from torch import nn
+
+ from d9d.model_state.mapper import ModelStateMapper
+ from d9d.model_state.mapper.leaf import ModelStateMapperRename
+ from d9d.module.block.moe import GroupedLinear
+
+ from ..base import PeftInjectionResult, PeftMethod
+ from .config import LoRAConfig
+ from .layer import LoRAGroupedLinear, LoRALinear
+
+ _CAN_APPLY_MODULES = (nn.Linear, GroupedLinear)
+ _LORA_MODULES = (LoRALinear, LoRAGroupedLinear)
+
+
+ def named_modules_without_lora(
+     module: nn.Module,
+     memo: set[nn.Module] | None = None,
+     prefix: str = "",
+     remove_duplicate: bool = True
+ ):
+     """
+     Yields named modules, skipping submodules that are already LoRA layers.
+
+     This prevents recursively re-injecting LoRA into an already wrapped layer during
+     traversal.
+
+     Args:
+         module: The root module to traverse.
+         memo: Set of processed modules to avoid duplicates.
+         prefix: Current namespace prefix.
+         remove_duplicate: Whether to skip modules seen in memo.
+
+     Yields:
+         Tuples of (name, module).
+     """
+
+     if isinstance(module, _LORA_MODULES):
+         return
+
+     if memo is None:
+         memo = set()
+     if module in memo:
+         return
+
+     if remove_duplicate:
+         memo.add(module)
+
+     yield prefix, module
+
+     for name, submodule in module.named_children():
+         if submodule is None:
+             continue
+
+         submodule_prefix = prefix + ("." if prefix else "") + name
+         yield from named_modules_without_lora(submodule, memo, submodule_prefix, remove_duplicate)
+
+
+ class LoRA(PeftMethod[LoRAConfig]):
+     """
+     Implements the Low-Rank Adaptation (LoRA) injection strategy.
+
+     It scans the module structure for `nn.Linear` or `GroupedLinear` layers matching
+     the configured name pattern. Matched layers are replaced with LoRA wrappers.
+
+     It also generates `ModelStateMapperRename` objects. Since the original weight
+     `layer.weight` is now at `layer.base.weight` inside the wrapper, the mapper
+     ensures that loading a standard checkpoint still works by redirecting the key.
+     """
+
+     def __init__(self, config: LoRAConfig):
+         """
+         Constructs a LoRA method.
+
+         Args:
+             config: LoRA configuration containing patterns and hyperparameters.
+         """
+
+         self._config = config
+
+     def inject(self, module: nn.Module) -> PeftInjectionResult:
+         params_to_train: list[nn.Parameter] = []
+         state_mappers: list[ModelStateMapper] = []
+
+         for mod_name, mod in named_modules_without_lora(module):
+             if not isinstance(mod, _CAN_APPLY_MODULES):
+                 continue
+
+             if not self._config.module_name_pattern.fullmatch(mod_name):
+                 continue
+
+             lora_mod: LoRALinear | LoRAGroupedLinear
+             if isinstance(mod, nn.Linear):
+                 lora_mod = LoRALinear(mod, self._config.params)
+             elif isinstance(mod, GroupedLinear):
+                 lora_mod = LoRAGroupedLinear(mod, self._config.params)
+             else:
+                 raise ValueError(f"Unknown layer {type(mod)} for LoRA")
+
+             params_to_train.extend(lora_mod.lora_A.parameters())
+             params_to_train.extend(lora_mod.lora_B.parameters())
+
+             state_mappers.append(ModelStateMapperRename(
+                 name_from=f"{mod_name}.weight",
+                 name_to=f"{mod_name}.base.weight"
+             ))
+
+             module.set_submodule(mod_name, lora_mod)
+
+         return PeftInjectionResult(
+             parameters_to_train=params_to_train,
+             load_state_mappers=state_mappers
+         )
+
+     def merge(self, module: nn.Module):
+         for mod_name, mod in module.named_modules():
+             if not isinstance(mod, _LORA_MODULES):
+                 continue
+
+             if not self._config.module_name_pattern.fullmatch(mod_name):
+                 continue
+
+             with torch.no_grad():
+                 orig_mod = mod.merge_with_base_()
+
+             module.set_submodule(mod_name, orig_mod)
+
+     @classmethod
+     def from_config(cls, config: LoRAConfig) -> Self:
+         return cls(config)
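
An end-to-end sketch of inject and merge on a toy module (the model, names, and pattern are invented; imports assume the exports above):

import re
from torch import nn

from d9d.peft.lora import LoRA, LoRAConfig, LoRALinear, LoRAParameters

model = nn.ModuleDict({"attn": nn.ModuleDict({"q_proj": nn.Linear(64, 64, bias=False)})})
cfg = LoRAConfig(
    module_name_pattern=re.compile(r"attn\.q_proj"),
    params=LoRAParameters(r=8, alpha=16, dropout=0.0),
)
method = LoRA.from_config(cfg)
result = method.inject(model)

assert isinstance(model["attn"]["q_proj"], LoRALinear)  # wrapped in place
assert len(result.parameters_to_train) == 2             # lora_A.weight, lora_B.weight
assert len(result.load_state_mappers) == 1              # attn.q_proj.weight -> attn.q_proj.base.weight

method.merge(model)                                     # fold the adapter back in
assert isinstance(model["attn"]["q_proj"], nn.Linear)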
d9d/pipelining/__init__.py (file without changes)
d9d/pipelining/api/__init__.py ADDED
@@ -0,0 +1,19 @@
+ """
+ Pipelining API that is intended to be accessible by the end user.
+ """
+
+ from .module import (
+     ModuleSupportsPipelining,
+     PipelineStageInfo,
+     distribute_layers_for_pipeline_stage,
+ )
+ from .schedule import PipelineSchedule
+ from .sharding import PipelineShardingSpec
+
+ __all__ = [
+     "ModuleSupportsPipelining",
+     "PipelineSchedule",
+     "PipelineShardingSpec",
+     "PipelineStageInfo",
+     "distribute_layers_for_pipeline_stage"
+ ]
d9d/pipelining/api/module.py ADDED
@@ -0,0 +1,149 @@
+ import dataclasses
+ import typing
+
+ import torch
+
+
+ @dataclasses.dataclass
+ class PipelineStageInfo:
+     """
+     Holds information about the current position within the distributed pipeline.
+
+     Attributes:
+         current_stage: The 0-based index of the current pipeline stage.
+         num_stages: The total number of stages in the pipeline.
+     """
+
+     current_stage: int
+     num_stages: int
+
+     @property
+     def is_current_stage_first(self) -> bool:
+         """
+         Determines whether this is the first stage in the pipeline.
+
+         Returns:
+             True if current_stage is 0.
+         """
+
+         return self.current_stage == 0
+
+     @property
+     def is_current_stage_last(self) -> bool:
+         """
+         Determines whether this is the last stage in the pipeline.
+
+         Returns:
+             True if current_stage is the last index.
+         """
+
+         return self.current_stage == self.num_stages - 1
+
+
+ def distribute_layers_for_pipeline_stage(
+     num_layers: int,
+     num_virtual_layers_pre: int,
+     num_virtual_layers_post: int,
+     stage: PipelineStageInfo
+ ) -> tuple[int, int]:
+     """
+     Calculates the layer index range for a specific pipeline stage.
+
+     This function distributes a given number of layers across multiple pipeline
+     stages as evenly as possible. It accounts for additional, non-layer
+     computational load on the first and last stages (e.g., embeddings and the
+     LM head) by using the concept of 'virtual layers' to reserve capacity.
+
+     Args:
+         num_layers: The total number of primary model layers to be distributed
+             (e.g., the transformer blocks).
+         num_virtual_layers_pre: The number of 'virtual' layers representing the
+             computational cost of modules on the *first* stage, before the main
+             layers (e.g., token and positional embeddings).
+         num_virtual_layers_post: The number of 'virtual' layers representing the
+             computational cost of modules on the *last* stage, after the main
+             layers (e.g., the final layer normalization and LM head).
+         stage: An object containing total stages and current stage index.
+
+     Returns:
+         A tuple (start_index, end_index), representing the slice of layers for
+         the given stage. The start_index is inclusive and the end_index is
+         exclusive.
+
+     Raises:
+         ValueError: If the pipeline configuration results in a stage having zero
+             or negative layers assigned (pipeline too long for the model size).
+     """
+
+     num_layers_virtual = num_layers + num_virtual_layers_pre + num_virtual_layers_post
+
+     base_layers_per_stage = num_layers_virtual // stage.num_stages
+     extra_layers = num_layers_virtual % stage.num_stages
+
+     layer_count_per_stage = []
+
+     for proposed_stage_i in range(stage.num_stages):
+         proposed_stage = PipelineStageInfo(num_stages=stage.num_stages, current_stage=proposed_stage_i)
+         layers = base_layers_per_stage + 1 if proposed_stage_i < extra_layers else base_layers_per_stage
+
+         adjustment = 0
+         if proposed_stage.is_current_stage_first:
+             adjustment += num_virtual_layers_pre
+         if proposed_stage.is_current_stage_last:
+             adjustment += num_virtual_layers_post
+
+         actual_layers = layers - adjustment
+
+         if actual_layers <= 0:
+             raise ValueError(f"Tried to distribute layers, but got {actual_layers} on "
+                              f"stage {proposed_stage.current_stage}. Perhaps the pipeline is too long for this model?")
+
+         layer_count_per_stage.append(actual_layers)
+
+     start_layer_id = sum(layer_count_per_stage[:stage.current_stage])
+     num_layers_in_stage = layer_count_per_stage[stage.current_stage]
+
+     return start_layer_id, start_layer_id + num_layers_in_stage
+
+
+ @typing.runtime_checkable
+ class ModuleSupportsPipelining(typing.Protocol):
+     """
+     Protocol for modules that support pipeline parallelism metadata inference.
+
+     Classes implementing this protocol enable the framework to pre-calculate
+     tensor shapes and types required for inter-stage communication (p2p)
+     without executing the full forward pass.
+     """
+
+     def infer_stage_inputs_from_pipeline_inputs(
+         self, inputs: dict[str, torch.Tensor], n_microbatches: int
+     ) -> dict[str, torch.Tensor]:
+         """
+         Infers the input tensor metadata for the current pipeline stage from the global batch inputs.
+
+         Args:
+             inputs: Global inputs for the pipeline.
+             n_microbatches: Number of microbatches the global batch is split into.
+
+         Returns:
+             Dictionary of input tensors expected by this specific stage locally.
+         """
+
+         ...
+
+     def infer_stage_outputs_from_pipeline_inputs(
+         self, inputs: dict[str, torch.Tensor], n_microbatches: int
+     ) -> dict[str, torch.Tensor]:
+         """
+         Infers the output tensor metadata for the current pipeline stage from the global batch inputs.
+
+         Args:
+             inputs: Global inputs for the pipeline (typically a batch).
+             n_microbatches: Number of microbatches the global batch is split into.
+
+         Returns:
+             Dictionary of output tensors produced by this specific stage locally.
+         """
+
+         ...
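
To make the virtual-layer accounting concrete, a worked example with invented numbers: 14 transformer blocks over 4 stages, with the embeddings counted as one virtual layer on the first stage and the head as one on the last. That gives 16 virtual layers, 4 per stage, and each boundary stage gives one slot back:

for i in range(4):
    info = PipelineStageInfo(current_stage=i, num_stages=4)
    print(distribute_layers_for_pipeline_stage(14, 1, 1, info))
# (0, 3)    stage 0: 3 blocks plus the embeddings
# (3, 7)    stage 1: 4 blocks
# (7, 11)   stage 2: 4 blocks
# (11, 14)  stage 3: 3 blocks plus the head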
d9d/pipelining/api/schedule.py ADDED
@@ -0,0 +1,50 @@
+ import abc
+ from typing import Any
+
+ import torch
+
+ from .sharding import PipelineShardingSpec
+
+ # TODO: feature - support any PyTrees as pipeline parameters
+
+
+ class PipelineSchedule(abc.ABC):
+     """Abstract base class defining the interface for pipeline execution schedules."""
+
+     @abc.abstractmethod
+     def configure_buffers(
+         self,
+         inputs: dict[str, torch.Tensor],
+         kwargs: dict[str, Any],
+         sharding_spec: PipelineShardingSpec | None
+     ):
+         """
+         Configures internal state and buffers based on input shapes.
+
+         This method allows the schedule to pre-allocate memory or set up sharding
+         specifications based on the structure of the input data before execution begins.
+
+         Args:
+             inputs: A dictionary of input tensors.
+             kwargs: A dictionary of keyword arguments.
+             sharding_spec: A specification defining how inputs and kwargs should be split
+                 into micro-batches. If None, assumes standard split-by-zero-dim behavior.
+         """
+
+         ...
+
+     @abc.abstractmethod
+     def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
+         """
+         Executes a single pipeline step using the provided inputs.
+
+         This typically involves distributing inputs across microbatches,
+         executing forward and backward passes according to the specific schedule logic,
+         and handling communications between stages.
+
+         Args:
+             inputs: A dictionary of global input tensors.
+             kwargs: A dictionary of global keyword arguments.
+         """
+
+         ...
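
To illustrate the contract, a deliberately trivial subclass sketch (not a schedule shipped by the package; a single-stage, no-microbatching stand-in):

from typing import Any

import torch


class EagerSchedule(PipelineSchedule):
    """Toy schedule: pushes the whole batch through one stage function at once."""

    def __init__(self, stage_fn):
        self._stage_fn = stage_fn  # hypothetical callable standing in for a stage

    def configure_buffers(
        self,
        inputs: dict[str, torch.Tensor],
        kwargs: dict[str, Any],
        sharding_spec: PipelineShardingSpec | None
    ):
        pass  # nothing to pre-allocate for a single eager stage

    def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
        return self._stage_fn(**inputs, **kwargs)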
d9d/pipelining/api/sharding.py ADDED
@@ -0,0 +1,9 @@
+ import dataclasses
+
+ from d9d.core.sharding import ShardingSpec
+
+
+ @dataclasses.dataclass
+ class PipelineShardingSpec:
+     input_data: ShardingSpec | None = None
+     input_kwargs: ShardingSpec | None = None
d9d/pipelining/factory/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from .config import (
+     AnyPipelineScheduleConfig,
+     PipelineSchedule1F1BConfig,
+     PipelineScheduleDualPipeVConfig,
+     PipelineScheduleGPipeConfig,
+     PipelineScheduleInferenceConfig,
+     PipelineScheduleLoopedBFSConfig,
+     PipelineScheduleZeroBubbleVConfig,
+ )
+ from .factory import build_schedule
+
+ __all__ = [
+     "AnyPipelineScheduleConfig",
+     "PipelineSchedule1F1BConfig",
+     "PipelineScheduleDualPipeVConfig",
+     "PipelineScheduleGPipeConfig",
+     "PipelineScheduleInferenceConfig",
+     "PipelineScheduleLoopedBFSConfig",
+     "PipelineScheduleZeroBubbleVConfig",
+     "build_schedule"
+ ]