d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/module/model/qwen3_moe/model.py
@@ -0,0 +1,373 @@
+ from collections.abc import Mapping
+ from typing import cast
+
+ import torch
+ from torch import nn
+ from torch.utils.checkpoint import checkpoint
+
+ from d9d.module.base import ModuleLateInit
+ from d9d.module.block.embedding import SplitTokenEmbeddings
+ from d9d.module.block.head import SplitLanguageModellingHead
+ from d9d.module.block.hidden_states_aggregator import HiddenStatesAggregationMode, create_hidden_states_aggregator
+ from d9d.module.block.positional import RotaryEmbeddingProvider
+ from d9d.pipelining.api import (
+     ModuleSupportsPipelining,
+     PipelineStageInfo,
+     distribute_layers_for_pipeline_stage,
+ )
+
+ from .decoder_layer import Qwen3MoELayer
+ from .params import Qwen3MoEForCausalLMParameters, Qwen3MoEParameters
+
+
+ class Qwen3MoEModel(nn.Module, ModuleLateInit, ModuleSupportsPipelining):
+     """
+     The Qwen3 Mixture-of-Experts (MoE) Transformer Decoder backbone.
+
+     It is designed to be split across multiple pipeline stages.
+     """
+
+     def __init__(
+         self,
+         params: Qwen3MoEParameters,
+         stage: PipelineStageInfo,
+         hidden_states_snapshot_mode: HiddenStatesAggregationMode,
+         enable_checkpointing: bool
+     ):
+         """
+         Constructs the Qwen3MoEModel object.
+
+         Args:
+             params: Configuration parameters for the full model.
+             stage: Information about the pipeline stage this instance belongs to.
+             hidden_states_snapshot_mode: Configures the intermediate hidden state aggregation & snapshotting mode.
+             enable_checkpointing: If True, enables activation checkpointing for transformer layers to save memory.
+         """
+
+         super().__init__()
+
+         if stage.is_current_stage_first:
+             self.embed_tokens = SplitTokenEmbeddings(
+                 hidden_size=params.layer.hidden_size,
+                 split_vocab_size=params.split_vocab_size,
+                 split_order=params.split_vocab_order
+             )
+
+         # we use ModuleDict here to properly handle pipelining and loading weights after the model
+         # was pipelined
+         layer_start, layer_end = distribute_layers_for_pipeline_stage(
+             num_layers=params.num_hidden_layers,
+             num_virtual_layers_pre=params.pipeline_num_virtual_layers_pre,  # embeddings
+             num_virtual_layers_post=params.pipeline_num_virtual_layers_post,  # LM head
+             stage=stage
+         )
+
+         self._num_layers_before = layer_start
+         self._layers_iter = list(map(str, range(layer_start, layer_end)))
+         layers = nn.ModuleDict({
+             str(layer_idx): Qwen3MoELayer(params=params.layer) for layer_idx in self._layers_iter
+         })
+         self.layers: Mapping[str, Qwen3MoELayer] = cast(Mapping[str, Qwen3MoELayer], layers)
+
+         self.rope_provider = RotaryEmbeddingProvider(
+             max_position_ids=params.max_position_ids,
+             rope_base=params.rope_base,
+             head_dim=params.layer.head_dim
+         )
+
+         if stage.is_current_stage_last:
+             self.norm = nn.RMSNorm(
+                 normalized_shape=params.layer.hidden_size,
+                 eps=params.layer.rms_norm_eps
+             )
+
+         self._stage = stage
+         self._hidden_states_snapshot_mode = hidden_states_snapshot_mode
+         self._hidden_size = params.layer.hidden_size
+         self._enable_checkpointing = enable_checkpointing
+
+     def output_dtype(self) -> torch.dtype:
+         """
+         Returns the data type of the model output hidden states.
+         """
+         return self.layers[self._layers_iter[0]].input_layernorm.weight.dtype
+
+     def forward(
+         self,
+         input_ids: torch.Tensor | None = None,
+         hidden_states: torch.Tensor | None = None,
+         position_ids: torch.Tensor | None = None,
+         hidden_states_snapshot: torch.Tensor | None = None,
+         hidden_states_agg_mask: torch.Tensor | None = None,
+     ) -> dict[str, torch.Tensor]:
+         """
+         Executes the forward pass for the current pipeline stage.
+
+         Args:
+             input_ids: Indices of input sequence tokens. Required if this is the
+                 first pipeline stage.
+             hidden_states: Hidden states from the previous pipeline stage. Required
+                 if this is not the first pipeline stage.
+             position_ids: Indices of positions of each input sequence token in the
+                 position embeddings.
+             hidden_states_snapshot: Accumulated tensor of aggregated hidden states
+                 from previous stages. Used if snapshotting is enabled.
+             hidden_states_agg_mask: Mask used to aggregate hidden states for
+                 snapshots.
+
+         Returns:
+             A dictionary containing:
+                 * 'hidden_states': The output of the last layer in this stage.
+                 * 'hidden_states_snapshot': (Optional) The updated snapshot tensor.
+         """
+         state_aggregator = create_hidden_states_aggregator(self._hidden_states_snapshot_mode, hidden_states_agg_mask)
+
+         if input_ids is not None:
+             last_hidden_states = self.embed_tokens(input_ids)
+             state_aggregator.add_hidden_states(last_hidden_states)
+         else:
+             last_hidden_states = hidden_states
+
+         rope_params = self.rope_provider(position_ids)
+
+         for decoder_layer_name in self._layers_iter:
+             decoder_layer = self.layers[decoder_layer_name]
+
+             if self._enable_checkpointing:
+                 last_hidden_states = checkpoint(
+                     decoder_layer, last_hidden_states, rope_params,
+                     use_reentrant=False
+                 )
+             else:
+                 last_hidden_states = decoder_layer(last_hidden_states, rope_params)
+
+             state_aggregator.add_hidden_states(last_hidden_states)
+
+         if self._stage.is_current_stage_last:
+             last_hidden_states = self.norm(last_hidden_states)
+
+         return {
+             "hidden_states": last_hidden_states,
+             "hidden_states_snapshot": state_aggregator.pack_with_snapshot(hidden_states_snapshot)
+         }
+
+     def reset_moe_stats(self):
+         """
+         Resets routing statistics for all MoE layers in this stage.
+         """
+
+         for layer_name in self._layers_iter:
+             self.layers[layer_name].reset_moe_stats()
+
+     @property
+     def moe_tokens_per_expert(self) -> torch.Tensor:
+         """
+         Retrieves the number of tokens routed to each expert across all layers.
+
+         Returns:
+             A tensor of shape (num_local_layers, num_experts) containing counts.
+         """
+
+         return torch.stack(
+             [self.layers[layer_name].moe_tokens_per_expert for layer_name in self._layers_iter],
+             dim=0
+         )
+
+     def reset_parameters(self):
+         """Resets module parameters."""
+
+         if self._stage.is_current_stage_first:
+             self.embed_tokens.reset_parameters()
+
+         self.rope_provider.reset_parameters()
+
+         for decoder_layer_name in self._layers_iter:
+             decoder_layer = self.layers[decoder_layer_name]
+             decoder_layer.reset_parameters()
+
+         if self._stage.is_current_stage_last:
+             self.norm.reset_parameters()
+
+     def infer_stage_inputs_from_pipeline_inputs(
+         self, inputs: dict[str, torch.Tensor], n_microbatches: int
+     ) -> dict[str, torch.Tensor]:
+         input_ids = inputs["input_ids"]
+
+         pp_inputs = {}
+
+         # for calculation - input ids or prev hidden state
+         if self._stage.is_current_stage_first:
+             pp_inputs["input_ids"] = torch.empty(
+                 (input_ids.shape[0] // n_microbatches, input_ids.shape[1]),
+                 dtype=torch.long,
+                 device=input_ids.device
+             )
+         else:
+             pp_inputs["hidden_states"] = torch.empty(
+                 (input_ids.shape[0] // n_microbatches, input_ids.shape[1], self._hidden_size),
+                 dtype=self.output_dtype(),
+                 device=input_ids.device
+             )
+         if self._hidden_states_snapshot_mode != HiddenStatesAggregationMode.no:
+             num_layers_before = self._num_layers_before + 1  # 1 for embedding
+             pp_inputs["hidden_states_snapshot"] = torch.empty(
+                 (num_layers_before, input_ids.shape[0] // n_microbatches, self._hidden_size),
+                 dtype=self.output_dtype(),
+                 device=input_ids.device
+             )
+
+         return pp_inputs
+
+     def infer_stage_outputs_from_pipeline_inputs(
+         self, inputs: dict[str, torch.Tensor], n_microbatches: int
+     ) -> dict[str, torch.Tensor]:
+         input_ids = inputs["input_ids"]
+
+         # for calculation - last hidden state
+         pp_outputs = {
+             "hidden_states": torch.empty(
+                 (input_ids.shape[0] // n_microbatches, input_ids.shape[1], self._hidden_size),
+                 dtype=self.output_dtype(),
+                 device=input_ids.device
+             )
+         }
+
+         # for state caching
+         if self._hidden_states_snapshot_mode != HiddenStatesAggregationMode.no:
+             num_layers_before = self._num_layers_before + 1
+             num_layers_current = len(self.layers)
+             num_layers_after = num_layers_before + num_layers_current
+             pp_outputs["hidden_states_snapshot"] = torch.empty(
+                 (num_layers_after, input_ids.shape[0] // n_microbatches, self._hidden_size),
+                 dtype=self.output_dtype(),
+                 device=input_ids.device
+             )
+
+         return pp_outputs
+
+
+ class Qwen3MoEForCausalLM(nn.Module, ModuleLateInit, ModuleSupportsPipelining):
+     """
+     A Qwen3 MoE model wrapped with a Causal Language Modeling head.
+
+     It is designed to be split across multiple pipeline stages.
+     """
+
+     def __init__(
+         self,
+         params: Qwen3MoEForCausalLMParameters,
+         stage: PipelineStageInfo,
+         hidden_states_snapshot_mode: HiddenStatesAggregationMode,
+         enable_checkpointing: bool
+     ):
+         """
+         Constructs the Qwen3MoEForCausalLM object.
+
+         Args:
+             params: Full model configuration parameters.
+             stage: Pipeline stage information for this instance.
+             hidden_states_snapshot_mode: Configures the intermediate hidden state aggregation & snapshotting mode.
+             enable_checkpointing: Whether to enable activation checkpointing.
+         """
+
+         super().__init__()
+
+         self.model = Qwen3MoEModel(
+             params.model,
+             stage,
+             hidden_states_snapshot_mode=hidden_states_snapshot_mode,
+             enable_checkpointing=enable_checkpointing
+         )
+
+         if stage.is_current_stage_last:
+             self.lm_head = SplitLanguageModellingHead(
+                 split_vocab_size=params.model.split_vocab_size,
+                 split_order=params.model.split_vocab_order,
+                 hidden_size=params.model.layer.hidden_size
+             )
+
+         self._stage = stage
+         self._hidden_size = params.model.layer.hidden_size
+
+     def forward(
+         self,
+         input_ids: torch.Tensor | None = None,
+         hidden_states: torch.Tensor | None = None,
+         position_ids: torch.Tensor | None = None,
+         hidden_states_snapshot: torch.Tensor | None = None,
+         hidden_states_agg_mask: torch.Tensor | None = None,
+         labels: torch.Tensor | None = None
+     ) -> dict[str, torch.Tensor]:
+         """
+         Executes the model forward pass.
+
+         If this is the last stage, it expects `labels` to be provided and computes
+         the cross-entropy loss (returned as 'logps', typically representing the per-token loss).
+
+         Args:
+             input_ids: Input token IDs (for Stage 0).
+             hidden_states: Hidden states from the previous stage (for Stage > 0).
+             position_ids: Positional indices for RoPE.
+             hidden_states_snapshot: Intermediate state collector.
+             hidden_states_agg_mask: Mask for state aggregation.
+             labels: Target tokens for loss computation (last stage).
+
+         Returns:
+             Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot',
+             and per-token 'logps' if on the last stage.
+         """
+
+         model_outputs = self.model(
+             input_ids=input_ids,
+             hidden_states=hidden_states,
+             position_ids=position_ids,
+             hidden_states_snapshot=hidden_states_snapshot,
+             hidden_states_agg_mask=hidden_states_agg_mask
+         )
+         if self._stage.is_current_stage_last:
+             lm_out = self.lm_head(
+                 hidden_states=model_outputs["hidden_states"],
+                 labels=labels
+             )
+             model_outputs["logps"] = lm_out
+         return model_outputs
+
+     def reset_parameters(self):
+         """
+         Resets module parameters.
+         """
+
+         self.model.reset_parameters()
+
+         if self._stage.is_current_stage_last:
+             self.lm_head.reset_parameters()
+
+     def reset_moe_stats(self):
+         """
+         Resets MoE routing statistics in the backbone.
+         """
+
+         self.model.reset_moe_stats()
+
+     @property
+     def moe_tokens_per_expert(self) -> torch.Tensor:
+         """
+         Accesses MoE routing statistics from the backbone.
+         """
+
+         return self.model.moe_tokens_per_expert
+
+     def infer_stage_inputs_from_pipeline_inputs(
+         self, inputs: dict[str, torch.Tensor], n_microbatches: int
+     ) -> dict[str, torch.Tensor]:
+         return self.model.infer_stage_inputs_from_pipeline_inputs(inputs, n_microbatches)
+
+     def infer_stage_outputs_from_pipeline_inputs(
+         self, inputs: dict[str, torch.Tensor], n_microbatches: int
+     ) -> dict[str, torch.Tensor]:
+         pp_outputs = self.model.infer_stage_outputs_from_pipeline_inputs(inputs, n_microbatches)
+
+         if self._stage.is_current_stage_last:
+             pp_outputs["logps"] = torch.empty(inputs["input_ids"].shape, dtype=torch.float32)
+
+         return pp_outputs
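The calling convention documented above differs per pipeline stage: only the first stage consumes `input_ids`, every later stage consumes the `hidden_states` emitted by its predecessor, and only the last stage turns `labels` into per-token 'logps'. A minimal, hypothetical sketch (the `stage0`/`stage1` modules and the input tensors below are assumptions, not part of the package):

    # `stage0` / `stage1` are assumed Qwen3MoEForCausalLM instances built with first-/later-stage PipelineStageInfo
    out0 = stage0(input_ids=input_ids, position_ids=position_ids)
    out1 = stage1(hidden_states=out0["hidden_states"], position_ids=position_ids, labels=labels)
    # if stage1 is the last stage, out1 additionally contains per-token 'logps'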
d9d/module/model/qwen3_moe/params.py
@@ -0,0 +1,69 @@
+ from pydantic import BaseModel
+
+
+ class Qwen3MoELayerParameters(BaseModel):
+     """
+     Configuration parameters for a single Qwen3 MoE layer.
+
+     Attributes:
+         hidden_size: Dimension of the model's hidden states.
+         intermediate_size: Dimension of the feed-forward hidden state.
+         num_experts: Total number of experts in the MoE layer.
+         experts_top_k: Number of experts to route each token to.
+         num_attention_heads: Number of attention heads for the query.
+         num_key_value_heads: Number of attention heads for key and value.
+         rms_norm_eps: Epsilon value used in the RMSNorm layers.
+         head_dim: Dimension of a single attention head.
+     """
+
+     hidden_size: int
+     intermediate_size: int
+     num_experts: int
+     experts_top_k: int
+     num_attention_heads: int
+     num_key_value_heads: int
+     rms_norm_eps: float
+     head_dim: int
+
+
+ class Qwen3MoEParameters(BaseModel):
+     """
+     Configuration parameters for the Qwen3 Mixture-of-Experts model backbone.
+
+     Attributes:
+         layer: Configuration shared across all transformer layers.
+         num_hidden_layers: The total number of transformer layers.
+         rope_base: Base value for RoPE frequency calculation.
+         max_position_ids: Maximum sequence length.
+         split_vocab_size: A dictionary mapping vocabulary segment names to their sizes.
+         split_vocab_order: The order in which the vocabulary splits are arranged.
+         pipeline_num_virtual_layers_pre: The number of 'virtual' layers representing the
+             computational cost of modules on the *first* stage, before the main
+             layers (e.g., token and positional embeddings).
+         pipeline_num_virtual_layers_post: The number of 'virtual' layers representing the
+             computational cost of modules on the *last* stage, after the main
+             layers (e.g., the final layer normalization and LM head).
+     """
+
+     layer: Qwen3MoELayerParameters
+
+     num_hidden_layers: int
+     rope_base: int
+     max_position_ids: int
+
+     split_vocab_size: dict[str, int]
+     split_vocab_order: list[str]
+
+     pipeline_num_virtual_layers_pre: int = 0
+     pipeline_num_virtual_layers_post: int = 0
+
+
+ class Qwen3MoEForCausalLMParameters(BaseModel):
+     """
+     Configuration parameters for a Qwen3 Mixture-of-Experts model with a Causal Language Modeling head.
+
+     Attributes:
+         model: The configuration for the underlying Qwen3 MoE model.
+     """
+
+     model: Qwen3MoEParameters
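For reference, a hypothetical instantiation of the configuration models above (every hyperparameter value below is made up for illustration, not taken from the package):

    layer = Qwen3MoELayerParameters(
        hidden_size=2048, intermediate_size=768,
        num_experts=128, experts_top_k=8,
        num_attention_heads=32, num_key_value_heads=4,
        rms_norm_eps=1e-6, head_dim=128
    )
    params = Qwen3MoEForCausalLMParameters(model=Qwen3MoEParameters(
        layer=layer, num_hidden_layers=48,
        rope_base=1_000_000, max_position_ids=40_960,
        split_vocab_size={"main": 151_936}, split_vocab_order=["main"]
    ))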
d9d/module/parallelism/__init__.py - File without changes
d9d/module/parallelism/api/__init__.py
@@ -0,0 +1,18 @@
+ """
+ Horizontal parallelism strategies and utilities for d9d modules.
+
+ This package provides high-level helper functions to apply specific distributed
+ parallelism strategies to PyTorch modules compatible with the d9d ecosystem.
+ """
+
+ from .expert_parallel import parallelize_expert_parallel
+ from .fully_sharded import parallelize_fsdp
+ from .hybrid_sharded import parallelize_hsdp
+ from .replicate_parallel import parallelize_replicate
+
+ __all__ = [
+     "parallelize_expert_parallel",
+     "parallelize_fsdp",
+     "parallelize_hsdp",
+     "parallelize_replicate"
+ ]
d9d/module/parallelism/api/expert_parallel.py
@@ -0,0 +1,36 @@
+ from torch.distributed import DeviceMesh
+ from torch.distributed.tensor import Replicate
+ from torch.distributed.tensor.parallel import parallelize_module
+
+ from d9d.module.block.moe import MoELayer
+ from d9d.module.parallelism.style import ShardMoESparseExpertsParallel, ToLocalParallel
+
+
+ def parallelize_expert_parallel(
+     module: MoELayer,
+     mesh_experts: DeviceMesh,
+     expert_shard_dim: str = "ep_shard"
+ ):
+     """
+     Applies Expert Parallelism to a MoE layer.
+
+     This function configures the provided Mixture of Experts layer for distributed
+     execution.
+
+     It partitions the sparse experts across the specified dimension
+     of the device mesh (Expert Parallelism) and replicates them along the other dimensions.
+
+     Simultaneously, it configures the router to be fully replicated across
+     the mesh.
+
+     Args:
+         module: The MoE layer instance to parallelize.
+         mesh_experts: The device mesh containing the expert parallel resources.
+         expert_shard_dim: The name of the mesh dimension where experts should be sharded.
+     """
+
+     parallelize_module(module, mesh_experts, ShardMoESparseExpertsParallel(shard_dim_name=expert_shard_dim))
+     parallelize_module(module.router, mesh_experts, ToLocalParallel(
+         param_placement=tuple(Replicate() for _ in range(mesh_experts.ndim)),
+         grad_placement=tuple(Replicate() for _ in range(mesh_experts.ndim))
+     ))
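A hedged usage sketch for the helper above (the mesh construction and the `moe_layer` instance are assumptions; only `parallelize_expert_parallel` comes from this file):

    from torch.distributed.device_mesh import init_device_mesh
    from d9d.module.parallelism.api import parallelize_expert_parallel

    # hypothetical 1-D expert-parallel mesh over 8 ranks; requires an initialised process group
    mesh_experts = init_device_mesh("cuda", (8,), mesh_dim_names=("ep_shard",))
    # `moe_layer` is an existing MoELayer taken from the model; its construction is omitted here
    parallelize_expert_parallel(moe_layer, mesh_experts, expert_shard_dim="ep_shard")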
d9d/module/parallelism/api/fully_sharded.py
@@ -0,0 +1,43 @@
+ from typing import Any
+
+ from torch import nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.fsdp import FSDPModule, fully_shard
+
+
+ def _force_fsdp_grad_reduction_policy(module: FSDPModule):
+     module.set_force_sum_reduction_for_comms(enable=True)
+     module.set_gradient_divide_factor(1.0)
+     module.set_requires_all_reduce(False)
+
+
+ def parallelize_fsdp(
+     module: nn.Module,
+     mesh: DeviceMesh,
+     *args: Any,
+     **kwargs: Any
+ ):
+     """
+     Applies Fully Sharded Data Parallel (FSDP) with forced gradient summation.
+
+     This function wraps the provided module with PyTorch's ``fully_shard`` API using
+     the specified device mesh. Unlike standard FSDP usage, this function explicitly
+     configures the module to sum gradients across the mesh instead of
+     averaging them and disables the internal all-reduce across replicas, since d9d
+     handles gradient normalization and reduction across replicas externally.
+
+     Args:
+         module: The module to shard.
+         mesh: The device mesh over which to shard the module.
+         *args: Additional positional arguments passed to ``fully_shard``.
+         **kwargs: Additional keyword arguments passed to ``fully_shard``.
+     """
+
+     if mesh.ndim != 1:
+         raise ValueError("FSDP mesh should contain exactly one dimension - for HSDP, please apply "
+                          "parallelize_replicate(...) first!")
+
+     fully_shard(module, *args, mesh=mesh, **kwargs)
+     if not isinstance(module, FSDPModule):
+         raise RuntimeError("Torch FSDP did not convert the module into FSDPModule")
+     _force_fsdp_grad_reduction_policy(module)
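A hedged sketch of the one-dimensional-mesh requirement described above (the mesh shape and `stage_module` are assumptions):

    from torch.distributed.device_mesh import init_device_mesh
    from d9d.module.parallelism.api import parallelize_fsdp

    # the mesh must have exactly one dimension, otherwise parallelize_fsdp raises ValueError
    mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_cp_shard",))
    parallelize_fsdp(stage_module, mesh)  # `stage_module` is an assumed nn.Module built elsewhere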
d9d/module/parallelism/api/hybrid_sharded.py
@@ -0,0 +1,49 @@
+ from typing import Any
+
+ from torch import nn
+ from torch.distributed import DeviceMesh
+
+ from .fully_sharded import parallelize_fsdp
+ from .replicate_parallel import parallelize_replicate
+
+
+ def parallelize_hsdp(
+     module: nn.Module,
+     mesh: DeviceMesh,
+     shard_dim: str = "dp_cp_shard",
+     *fsdp_args: Any,
+     **fsdp_kwargs: Any
+ ):
+     """
+     Applies Hybrid Sharded Data Parallelism (HSDP) to a module.
+
+     This function decomposes the provided device mesh into sharding dimensions
+     and replication dimensions. It applies replication parallelism
+     across the replication dimensions and Fully Sharded Data Parallelism (FSDP)
+     across the specified shard dimension.
+
+     Args:
+         module: The module to parallelize.
+         mesh: The device mesh over which to distribute the module.
+         shard_dim: The name of the mesh dimension used for FSDP sharding. Any
+             dimension in the mesh not matching this name will be treated as a
+             replication dimension.
+         *fsdp_args: Positional arguments passed to the underlying FSDP parallelizer.
+         **fsdp_kwargs: Keyword arguments passed to the underlying FSDP parallelizer.
+
+     Raises:
+         ValueError: If the device mesh does not have named dimensions.
+     """
+
+     replicate_dims = mesh.mesh_dim_names
+
+     if replicate_dims is None:
+         raise ValueError("Cannot use with unnamed device meshes")
+
+     replicate_dims = tuple(x for x in replicate_dims if x != shard_dim and mesh[x].size() > 1)
+
+     if len(replicate_dims) > 0:
+         parallelize_replicate(module, mesh[replicate_dims])
+
+     if mesh[shard_dim].size() != 1:
+         parallelize_fsdp(module, mesh[shard_dim], *fsdp_args, **fsdp_kwargs)
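A hedged sketch of the mesh decomposition described above (the 2 x 4 shape, the 'dp_replicate' dimension name and `stage_module` are assumptions; 'dp_cp_shard' is the documented default shard dimension):

    from torch.distributed.device_mesh import init_device_mesh
    from d9d.module.parallelism.api import parallelize_hsdp

    # parameters are replicated across 'dp_replicate' and sharded (FSDP) across 'dp_cp_shard'
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_cp_shard"))
    parallelize_hsdp(stage_module, mesh, shard_dim="dp_cp_shard")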
d9d/module/parallelism/api/replicate_parallel.py
@@ -0,0 +1,33 @@
+ from torch import nn
+ from torch.distributed import DeviceMesh
+ from torch.distributed.tensor import Replicate
+ from torch.distributed.tensor.parallel import parallelize_module
+
+ from d9d.module.parallelism.style import ToLocalParallel
+
+
+ def parallelize_replicate(
+     module: nn.Module,
+     mesh: DeviceMesh,
+ ):
+     """
+     Applies replicated parallelism to the module.
+
+     This function configures the provided module to be fully replicated across the
+     given device mesh. It utilizes the ``ToLocalParallel`` style, which manages
+     ``DTensor`` wrapping for parameters and gradients (via ``Replicate`` placements)
+     while ensuring that the underlying computation sees standard local tensors during the forward pass.
+
+     This approach is effectively Data Parallelism managed via the DTensor
+     APIs, allowing seamless integration of modules that require local tensor inputs
+     into a broader distributed mesh context.
+
+     Args:
+         module: The module to parallelize.
+         mesh: The device mesh over which to replicate the module.
+     """
+
+     parallelize_module(module, mesh, ToLocalParallel(
+         param_placement=tuple(Replicate() for _ in range(mesh.ndim)),
+         grad_placement=tuple(Replicate() for _ in range(mesh.ndim))
+     ))
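A hedged sketch mirroring the helpers above (the mesh name and `stage_module` are assumptions):

    from torch.distributed.device_mesh import init_device_mesh
    from d9d.module.parallelism.api import parallelize_replicate

    # parameters become Replicate-placed DTensors, while the forward pass still sees plain local tensors
    mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("dp_replicate",))
    parallelize_replicate(stage_module, mesh)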
d9d/module/parallelism/model/__init__.py - File without changes