d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/kernel/cce/main.py ADDED
@@ -0,0 +1,282 @@
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+# TODO: currently this implementation diverges only in out_grad contiguity fix
+# TODO: proposed in cce.py (grep FIX) - we should contribute this to main repo
+import platform
+import warnings
+from typing import TYPE_CHECKING, Literal, overload
+
+import torch
+import torch.nn as nn
+
+from cut_cross_entropy.cce_utils import CCEPreset, CCEPresets, LinearCrossEntropyImpl
+from cut_cross_entropy.constants import IGNORE_INDEX
+from cut_cross_entropy.doc import (
+    CCE_OPTS_DOC,
+    DTENSOR_NOTE,
+    IMPL_DOC,
+    LINEAR_CROSS_ENTROPY_DOC,
+    add_doc_end,
+    add_doc_start,
+)
+from cut_cross_entropy.torch_compile import torch_compile_linear_cross_entropy
+from cut_cross_entropy.utils import (
+    CCEWarning,
+    is_torch_greater_or_equal_2_5,
+    is_triton_3_2,
+    maybe_type_as,
+    to_full_tensor,
+)
+from cut_cross_entropy.vocab_parallel import VocabParallelOptions
+
+warnings.filterwarnings("once", category=CCEWarning, module="cut_cross_entropy")
+
+PLATFORM_SYSTEM = platform.system()
+
+if TYPE_CHECKING or PLATFORM_SYSTEM != "Darwin":
+    from .cce import cce_linear_cross_entropy
+
+    LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.CCE
+else:
+    cce_linear_cross_entropy = None
+    LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.TORCH_COMPILE
+
+if TYPE_CHECKING or is_torch_greater_or_equal_2_5():
+    import torch.distributed.tensor
+
+
+is_d_tensor_error_message = (
+    "Received {name} as a torch.distributed.tensor.DTensor. This is not supported. "
+)
+
+
+@overload
+def linear_cross_entropy(
+    e: torch.Tensor,
+    c: torch.Tensor,
+    targets: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    ignore_index: int = IGNORE_INDEX,
+    softcap: float | None = None,
+    reduction: str = "mean",
+    shift: bool | int = 0,
+    return_lse: Literal[False] = False,
+    filter_eps: float | str | None = "auto",
+    accum_e_fp32: bool = False,
+    accum_c_fp32: bool = False,
+    filter_e_grad: bool = True,
+    filter_c_grad: bool = True,
+    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
+    vocab_parallel_options: VocabParallelOptions | None = None,
+) -> torch.Tensor: ...
+
+
+@overload
+def linear_cross_entropy(
+    e: torch.Tensor,
+    c: torch.Tensor,
+    targets: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    ignore_index: int = IGNORE_INDEX,
+    softcap: float | None = None,
+    reduction: str = "mean",
+    shift: bool | int = 0,
+    return_lse: Literal[True] = True,
+    filter_eps: float | str | None = "auto",
+    accum_e_fp32: bool = False,
+    accum_c_fp32: bool = False,
+    filter_e_grad: bool = True,
+    filter_c_grad: bool = True,
+    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
+    vocab_parallel_options: VocabParallelOptions | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]: ...
+
+
+@overload
+def linear_cross_entropy(
+    e: torch.Tensor,
+    c: torch.Tensor,
+    targets: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    ignore_index: int = IGNORE_INDEX,
+    softcap: float | None = None,
+    reduction: str = "mean",
+    shift: bool | int = 0,
+    return_lse: bool = False,
+    filter_eps: float | str | None = "auto",
+    accum_e_fp32: bool = False,
+    accum_c_fp32: bool = False,
+    filter_e_grad: bool = True,
+    filter_c_grad: bool = True,
+    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
+    vocab_parallel_options: VocabParallelOptions | None = None,
+) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ...
+
+
+@add_doc_start(LINEAR_CROSS_ENTROPY_DOC)
+@add_doc_start(*(doc_str + " Only valid for the cce implementation." for doc_str in CCE_OPTS_DOC))
+@add_doc_start(IMPL_DOC)
+@add_doc_end(DTENSOR_NOTE)
+def linear_cross_entropy(
+    e: torch.Tensor,
+    c: torch.Tensor,
+    targets: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    ignore_index: int = IGNORE_INDEX,
+    softcap: float | None = None,
+    reduction: str = "mean",
+    shift: bool | int = 0,
+    return_lse: bool = False,
+    filter_eps: float | str | None = "auto",
+    accum_e_fp32: bool = False,
+    accum_c_fp32: bool = False,
+    filter_e_grad: bool = True,
+    filter_c_grad: bool = True,
+    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
+    vocab_parallel_options: VocabParallelOptions | None = None,
+) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    """
+    :param vocab_parallel_options: Used to enable vocab parallelism."""
+
+    if is_torch_greater_or_equal_2_5():
+        maybe_tensor_inputs = dict(e=e, targets=targets)
+        for k, v in maybe_tensor_inputs.items():
+            if isinstance(v, torch.distributed.tensor.DTensor):
+                raise ValueError(is_d_tensor_error_message.format(name=k))
+
+        c = maybe_type_as(to_full_tensor(c), e)
+        bias = maybe_type_as(to_full_tensor(bias), e)
+
+    if isinstance(impl, LinearCrossEntropyImpl):
+        impl = impl.name.lower()
+
+    if isinstance(shift, int) and (shift < 0 or shift >= targets.size(-1)):
+        raise ValueError(f"Shift must be in the range [0, {targets.size(-1)}). Got {shift}.")
+
+    if vocab_parallel_options is not None:
+        expected_v_dim_size = vocab_parallel_options.stop - vocab_parallel_options.start
+        if c.size(0) != expected_v_dim_size:
+            raise ValueError(f"Expected c.size(0) to be {expected_v_dim_size}, got {c.size(0)}.")
+
+    if bias is not None and bias.size(0) != c.size(0):
+        raise ValueError(
+            f"Bias has a different number of elements than c. {bias.size(0)} vs. {c.size(0)}."
+        )
+
+    if impl in CCEPresets.names:
+        if platform.system() == "Darwin":
+            raise RuntimeError(
+                "CCE does not support MacOS. Please use torch_compile when running on MacOS instead."
+            )
+
+        if is_triton_3_2():
+            warnings.warn(
+                "There is a known issue with CCE and Triton 3.2 (the version that ships with PyTorch 2.6)"
+                " that can result in incorrect gradients. If possible, please verify that you"
+                " are not impacted by this bug by trying a newer triton version (i.e. by installing PyTorch>2.6).",
+                CCEWarning,
+                stacklevel=2,
+            )
+
+        cce_opts = CCEPresets.build_for_impl(
+            impl,
+            CCEPreset(
+                filter_eps=filter_eps,
+                accum_e_fp32=accum_e_fp32,
+                accum_c_fp32=accum_c_fp32,
+                filter_e_grad=filter_e_grad,
+                filter_c_grad=filter_c_grad,
+            ),
+        )
+
+        assert cce_linear_cross_entropy is not None
+        loss, lse = cce_linear_cross_entropy(
+            e,
+            c,
+            targets,
+            bias,
+            ignore_index,
+            softcap,
+            reduction,
+            shift,
+            **cce_opts,
+            vocab_parallel_options=vocab_parallel_options,
+            return_lse=return_lse,
+        )
+    elif impl == "torch_compile":
+        loss, lse = torch_compile_linear_cross_entropy(
+            e,
+            c,
+            targets,
+            bias,
+            ignore_index,
+            softcap,
+            reduction,
+            shift,
+            vocab_parallel_options=vocab_parallel_options,
+            return_lse=return_lse,
+        )
+    else:
+        raise NotImplementedError(f"{impl} is not implemented.")
+
+    if return_lse:
+        assert lse is not None
+        return loss, lse
+    else:
+        return loss
+
+
+class LinearCrossEntropy(nn.Module):
+    def __init__(
+        self,
+        ignore_index: int = IGNORE_INDEX,
+        softcap: float | None = None,
+        reduction: str = "mean",
+        shift: bool | int = 0,
+        filter_eps: float | str | None = "auto",
+        accum_e_fp32: bool = False,
+        accum_c_fp32: bool = False,
+        filter_e_grad: bool = True,
+        filter_c_grad: bool = True,
+        impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
+        return_lse: bool = False,
+    ):
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.softcap = softcap
+        self.reduction = reduction
+        self.filter_eps = filter_eps
+        self.shift = shift
+
+        self.accum_e_fp32 = accum_e_fp32
+        self.accum_c_fp32 = accum_c_fp32
+
+        self.filter_e_grad = filter_e_grad
+        self.filter_c_grad = filter_c_grad
+
+        self.impl = impl
+        self.return_lse = return_lse
+
+    def forward(
+        self,
+        e: torch.Tensor,
+        c: torch.Tensor,
+        targets: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        return linear_cross_entropy(
+            e,
+            c,
+            targets,
+            bias=bias,
+            ignore_index=self.ignore_index,
+            softcap=self.softcap,
+            reduction=self.reduction,
+            shift=self.shift,
+            filter_eps=self.filter_eps,
+            accum_e_fp32=self.accum_e_fp32,
+            accum_c_fp32=self.accum_c_fp32,
+            filter_e_grad=self.filter_e_grad,
+            filter_c_grad=self.filter_c_grad,
+            impl=self.impl,
+            return_lse=self.return_lse,
+        )
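For orientation, a minimal usage sketch of the linear_cross_entropy entry point defined above. The shapes, the bfloat16/CUDA placement and the import path d9d.kernel.cce.main are illustrative assumptions; the diff itself only shows the module source.

import torch
from d9d.kernel.cce.main import linear_cross_entropy

# Assumed toy shapes: 8 tokens, hidden size 128, vocabulary of 1000 classes.
e = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)     # embeddings
c = torch.randn(1000, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)  # classifier weight
targets = torch.randint(0, 1000, (8,), device="cuda")

# Fused linear + cross-entropy loss, without materialising the full logits matrix.
loss = linear_cross_entropy(e, c, targets, reduction="mean")
loss.backward()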
d9d/kernel/general/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .get_int_dtype import get_int_dtype
+
+__all__ = [
+    "get_int_dtype"
+]
d9d/kernel/general/get_int_dtype.py ADDED
@@ -0,0 +1,7 @@
+import triton
+import triton.language as tl
+
+
+@triton.constexpr_function
+def get_int_dtype(bitwidth: int, signed: bool) -> tl.dtype:
+    return tl.core.get_int_dtype(bitwidth, signed)
d9d/kernel/gmm/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .function import gmm
+
+__all__ = [
+    "gmm"
+]
d9d/kernel/gmm/function.py ADDED
@@ -0,0 +1,78 @@
+from typing import Any
+
+import torch
+from grouped_gemm import backend
+from torch.autograd import Function
+
+from d9d.core.autograd import GLOBAL_GRAD_CONTEXT, GradDirection
+
+
+class GroupedGemm(Function):
+    """
+    Autograd function for Grouped GEMM (Generalized Matrix Multiplication) with explicit gradient control.
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        batch_sizes: torch.Tensor,
+        a_grad_direction: GradDirection | None,
+        b_grad_direction: GradDirection | None,
+        trans_b: bool
+    ) -> torch.Tensor:
+        ctx.save_for_backward(a, b, batch_sizes)
+        ctx.a_grad_direction = a_grad_direction
+        ctx.b_grad_direction = b_grad_direction
+        ctx.trans_b = trans_b
+        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
+
+    @staticmethod
+    def backward(
+        ctx: Any, grad: torch.Tensor
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None, None, None, None, None]:
+        grad = grad.contiguous()
+        a, b, batch_sizes = ctx.saved_tensors
+        trans_b = ctx.trans_b
+
+        compute_a = GLOBAL_GRAD_CONTEXT.check_direction(ctx.a_grad_direction)
+        compute_b = GLOBAL_GRAD_CONTEXT.check_direction(ctx.b_grad_direction)
+
+        a_grad = None
+        if ctx.needs_input_grad[0] and compute_a:
+            a_grad = backend.gmm(
+                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
+
+        b_grad = None
+        if ctx.needs_input_grad[1] and compute_b:
+            lhs, rhs = (grad, a) if trans_b else (a, grad)
+            b_grad = backend.gmm(
+                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
+        return a_grad, b_grad, None, None, None, None
+
+
+def gmm(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    batch_sizes: torch.Tensor,
+    a_grad_direction: GradDirection | None,
+    b_grad_direction: GradDirection | None,
+    trans_b: bool = False
+) -> torch.Tensor:
+    """
+    The Grouped GEMM (Generalized Matrix Multiplication) function with explicit gradient control.
+
+    Args:
+        a: Left-hand side tensor.
+        b: Right-hand side tensor.
+        batch_sizes: Sizes of batches/groups.
+        a_grad_direction: Gradient category for `a` (e.g., `GradDirection.inputs`).
+        b_grad_direction: Gradient category for `b` (e.g., `GradDirection.weight`).
+        trans_b: Whether to transpose `b`.
+
+    Returns:
+        Result of matrix multiplication.
+    """
+
+    return GroupedGemm.apply(a, b, batch_sizes, a_grad_direction, b_grad_direction, trans_b)
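For orientation, a sketch of how the gmm wrapper above might be invoked, assuming the usual grouped_gemm conventions (2D activations against a 3D stacked weight, CPU int64 batch sizes). GradDirection.inputs and GradDirection.weight are the member names suggested by the docstring example; the exact enum and the GLOBAL_GRAD_CONTEXT defaults are not shown in this diff.

import torch

from d9d.core.autograd import GradDirection
from d9d.kernel.gmm import gmm

# Assumed toy setup: two expert groups holding 3 and 5 tokens, hidden 16 -> 32.
a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)      # (tokens, k)
b = torch.randn(2, 16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)  # (groups, k, n)
batch_sizes = torch.tensor([3, 5], dtype=torch.int64)  # assumed to stay on CPU, per grouped_gemm convention

out = gmm(
    a, b, batch_sizes,
    a_grad_direction=GradDirection.inputs,   # assumed member name, mirroring the docstring
    b_grad_direction=GradDirection.weight,   # assumed member name, mirroring the docstring
)
out.sum().backward()  # gradients flow only for directions enabled in GLOBAL_GRAD_CONTEXT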
d9d/kernel/moe/__init__.py ADDED
@@ -0,0 +1,8 @@
+from .indices_to_multihot import fused_indices_to_multihot
+from .permute_with_probs import moe_permute_with_probs, moe_unpermute_mask
+
+__all__ = [
+    "fused_indices_to_multihot",
+    "moe_permute_with_probs",
+    "moe_unpermute_mask"
+]
d9d/kernel/moe/indices_to_multihot.py ADDED
@@ -0,0 +1,268 @@
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/fusions/fused_indices_converter.py
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+import math
+
+import torch
+import triton
+import triton.language as tl
+
+
+# Assign a block to a row([1,topk]), generate a local routing map([1,num_of_local_experts])
+@triton.jit
+def _indices_to_multihot_kernel(
+    indices_ptr,
+    probs_in_indices_ptr,
+    multihot_indices_ptr,  # bool
+    probs_in_multihot_ptr,
+    position_map_ptr,
+    num_of_local_experts: tl.constexpr,
+    num_of_local_experts_next_power_of_2: tl.constexpr,
+    topk: tl.constexpr,
+    topk_next_power_of_2: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    '''
+    Triton kernel for converting indices to multihot representation.
+
+    Input:
+        indices: [num_of_tokens, topk]
+        probs_in_indices: [num_of_tokens, topk]
+    Output:
+        multihot_indices: [num_of_tokens, num_of_local_experts]
+        probs_in_multihot: [num_of_tokens, num_of_local_experts]
+
+    Assume that topk = 2, num_of_local_experts = 4, num_of_tokens = 2,
+    then the kernel can process the following conversion:
+
+    Input Example:
+        indices = [
+            [0, 1],
+            [1, 2]
+        ]
+        probs_in_indices = [
+            [0.1, 0.2],
+            [0.3, 0.4]
+        ]
+    Output Example:
+        multihot_indices = [
+            [1, 1, -1, -1],
+            [-1, 1, 1, -1]
+        ]
+        probs_in_multihot = [
+            [0.1, 0.2, 0.0, 0.0],
+            [0.0, 0.3, 0.4, 0.0]
+        ]
+    '''
+    # Prepare the [0, topk) row
+    topk_row = tl.arange(0, topk_next_power_of_2)
+    topk_row = tl.where(topk_row < topk, topk_row, -1)
+    topk_row_mask = topk_row != -1
+    # Prepare the [0, num_of_local_experts) row
+    num_exp_row = tl.arange(0, num_of_local_experts_next_power_of_2)
+    num_exp_row = tl.where(num_exp_row < num_of_local_experts, num_exp_row, -1)
+    num_exp_row_mask = num_exp_row != -1
+
+    # Load a [1, topk] row from the indices buffer
+    row_idx = tl.program_id(0)
+    indices_row = tl.load(indices_ptr + row_idx * topk + topk_row, mask=topk_row_mask)
+    indices_row = tl.where(topk_row_mask, indices_row, -1)
+    probs_row = tl.load(probs_in_indices_ptr + row_idx * topk + topk_row, mask=topk_row_mask)
+
+    # Get the position of each index in the indices_row, which is saved for backwards
+    position_row = tl.where(indices_row != -1, topk_row, -1)
+    # Mask of the valid indices
+    mask = (indices_row != -1) & (indices_row < num_of_local_experts)
+
+    row_idx_offset = row_idx * num_of_local_experts
+    # Store to initialize
+    tl.store(multihot_indices_ptr + row_idx_offset + num_exp_row, 0, mask=num_exp_row_mask)
+    tl.store(probs_in_multihot_ptr + row_idx_offset + num_exp_row, 0, mask=num_exp_row_mask)
+    tl.store(position_map_ptr + row_idx_offset + num_exp_row, -1, mask=num_exp_row_mask)
+    # Use barrier to make sure the initialization is done
+    tl.debug_barrier()
+    # Store the indices and probs_in_indices
+    tl.store(multihot_indices_ptr + row_idx_offset + indices_row, 1, mask)
+    tl.store(probs_in_multihot_ptr + row_idx_offset + indices_row, probs_row, mask)
+    # Store the position of the position_row for backwards
+    tl.store(position_map_ptr + row_idx_offset + indices_row, position_row, mask)
+
+
+# Assign a block to a row([1,topk]), generate a probs_indices([1,topk])
+@triton.jit
+def _multihot_to_indices_kernel(
+    probs_in_multihot_ptr,
+    position_map_ptr,
+    probs_indices_ptr,
+    num_of_local_experts: tl.constexpr,
+    num_of_local_experts_next_power_of_2: tl.constexpr,
+    topk: tl.constexpr,
+    topk_next_power_of_2: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    '''
+    Triton kernel for converting multihot representation to indices.
+
+    Input:
+        probs_in_multihot: [num_of_tokens, num_of_local_experts]
+        position_map: [num_of_tokens, num_of_local_experts]
+    Output:
+        probs_indices: [num_of_tokens, topk]
+
+    Assume that topk = 2, num_of_local_experts = 4, num_of_tokens = 2,
+    then the kernel can process the following conversion:
+
+    Input Example:
+        probs_in_multihot = [
+            [0.7, 0.8, 0.0, 0.0],
+            [0.0, 0.1, 0.9, 0.0]
+        ]
+        position_map = [
+            [1, 1, -1, -1],
+            [-1, 1, 1, -1]
+        ]
+    Output Example:
+        probs_indices = [
+            [0.7, 0.8],
+            [0.1, 0.9]
+        ]
+    '''
+    # Prepare the [0, topk) row
+    topk_row = tl.arange(0, topk_next_power_of_2)
+    topk_row = tl.where(topk_row < topk, topk_row, -1)
+    topk_row_mask = topk_row != -1
+    # Prepare the [0, num_of_local_experts) row
+    num_exp_row = tl.arange(0, num_of_local_experts_next_power_of_2)
+    num_exp_row = tl.where(num_exp_row < num_of_local_experts, num_exp_row, -1)
+    num_exp_row_mask = num_exp_row != -1
+
+    # Load a [1, num_of_local_experts] row from the local routing map
+    row_idx = tl.program_id(0)
+    ptr_offset = row_idx * num_of_local_experts + num_exp_row
+    probs_in_multihot_row = tl.load(probs_in_multihot_ptr + ptr_offset, mask=num_exp_row_mask)
+
+    # Get the original position of the valid value in the indices
+    position_map_row = tl.load(position_map_ptr + ptr_offset, mask=num_exp_row_mask)
+    position_map_row = tl.where(num_exp_row_mask, position_map_row, -1)
+    mask = position_map_row != -1
+
+    # Store to initialize
+    tl.store(probs_indices_ptr + row_idx * topk + topk_row, 0, mask=topk_row_mask)
+    # Use barrier to make sure the initialization is done
+    tl.debug_barrier()
+    # Restore the indices and probs_indices
+    tl.store(probs_indices_ptr + row_idx * topk + position_map_row, probs_in_multihot_row, mask)
+
+
+class IndicesToMultihot(torch.autograd.Function):
+    """Convert moe topk indices to multihot representation.
+
+    This class implements a custom forward and backward propagation
+    operation for efficiently converting indices to multihot
+    representation.
+    It is an experimental feature and may change in future versions.
+    """
+
+    @staticmethod
+    def forward(ctx, indices, probs_indices, num_of_local_experts):
+        '''Forward function for IndicesToMultihot
+
+        Convert indices to multihot representation.
+
+        Args:
+            indices: [num_of_tokens, topk]
+            probs_indices: [num_of_tokens, topk]
+            num_of_local_experts: int
+
+        Returns:
+            multihot_indices: [num_of_tokens, num_of_local_experts]
+            probs_in_multihot: [num_of_tokens, num_of_local_experts]
+        '''
+        num_of_tokens = indices.shape[0]
+        assert (
+            indices.shape == probs_indices.shape
+        ), "indices and probs_indices must have the same shape"
+        topk = indices.shape[1]
+        multihot_indices = torch.empty(
+            (num_of_tokens, num_of_local_experts), dtype=torch.bool, device="cuda"
+        )
+        probs_in_multihot = torch.empty(
+            (num_of_tokens, num_of_local_experts), dtype=probs_indices.dtype, device="cuda"
+        )
+        position_map = torch.empty(
+            (num_of_tokens, num_of_local_experts), dtype=torch.int32, device="cuda"
+        )
+        # Compute the next power of 2 for the topk and num_of_local_experts
+        topk_next_power_of_2 = 2 ** int(math.ceil(math.log2(topk)))
+        num_of_local_experts_next_power_of_2 = 2 ** int(math.ceil(math.log2(num_of_local_experts)))
+        grid = (num_of_tokens,)
+        _indices_to_multihot_kernel[grid](
+            indices,
+            probs_indices,
+            multihot_indices,
+            probs_in_multihot,
+            position_map,
+            num_of_local_experts,
+            num_of_local_experts_next_power_of_2,
+            topk,
+            topk_next_power_of_2,
+            BLOCK_SIZE=32,  # use only 1 warp per block
+            num_warps=1,
+        )
+
+        ctx.save_for_backward(position_map)
+        ctx.num_of_tokens = num_of_tokens
+        ctx.num_of_local_experts = num_of_local_experts
+        ctx.topk = topk
+        return multihot_indices, probs_in_multihot
+
+    @staticmethod
+    def backward(ctx, grad_multihot_indices, grad_probs_in_multihot):
+        '''Backward function for IndicesToMultihot
+
+        Convert multihot probs representation to indices.
+        indices is ignored in the backward function.
+
+        Args:
+            grad_multihot_indices: [num_of_tokens, num_of_local_experts]
+            grad_probs_in_multihot: [num_of_tokens, num_of_local_experts]
+
+        Returns:
+            grad_probs_indices: [num_of_tokens, topk]
+        '''
+        position_map = ctx.saved_tensors[0]
+        num_of_tokens = ctx.num_of_tokens
+        num_of_local_experts = ctx.num_of_local_experts
+        topk = ctx.topk
+
+        # Initialize the gradient of the indices and probs_indices
+        grad_probs_indices = torch.empty(
+            (num_of_tokens, topk), dtype=grad_probs_in_multihot.dtype, device="cuda"
+        )
+        # Compute the next power of 2 for the topk and num_of_local_experts
+        topk_next_power_of_2 = 2 ** int(math.ceil(math.log2(topk)))
+        num_of_local_experts_next_power_of_2 = 2 ** int(math.ceil(math.log2(num_of_local_experts)))
+
+        grid = (num_of_tokens,)
+        _multihot_to_indices_kernel[grid](
+            # if the grad_probs_in_multihot is all-one/all-zero,
+            # overlapping stride will cause error without contiguous()
+            grad_probs_in_multihot.contiguous(),
+            position_map,
+            grad_probs_indices,
+            num_of_local_experts,
+            num_of_local_experts_next_power_of_2,
+            topk,
+            topk_next_power_of_2,
+            BLOCK_SIZE=32,  # use only 1 warp per block
+            num_warps=1,
+        )
+        return None, grad_probs_indices, None, None
+
+
+def fused_indices_to_multihot(indices, probs_indices, num_of_local_experts):
+    """Convert moe topk indices to multihot representation.
+
+    This function is an experimental feature and may change in future versions.
+    """
+    return IndicesToMultihot.apply(indices, probs_indices, num_of_local_experts)
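For orientation, a usage sketch of fused_indices_to_multihot that reuses the worked example from the kernel docstring (two tokens, topk = 2, four local experts). CUDA placement is assumed because the autograd function allocates its outputs with device="cuda"; the int32 index dtype is likewise an assumption.

import torch

from d9d.kernel.moe import fused_indices_to_multihot

# Same routing as the docstring example: two tokens, topk = 2, 4 local experts.
indices = torch.tensor([[0, 1], [1, 2]], dtype=torch.int32, device="cuda")
probs = torch.tensor([[0.1, 0.2], [0.3, 0.4]], device="cuda", requires_grad=True)

multihot, probs_multihot = fused_indices_to_multihot(indices, probs, num_of_local_experts=4)
# multihot       -> per-token mask over the 4 local experts
# probs_multihot -> routing probabilities scattered into the matching expert slots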