d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/kernel/moe/permute_with_probs.py
@@ -0,0 +1,1035 @@
+ # https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/permutation.py
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+ import torch
+ import triton
+
+
+ import triton.language as tl
+ from triton.language.standard import _log2
+
+ from d9d.kernel.general import get_int_dtype
+
+
+ @triton.jit
+ def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
+     n_outer: tl.constexpr = x.numel >> n_dims
+     shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
+     y = tl.reshape(x, shape)
+     z = tl.reshape(indices, shape)
+
+     mask = tl.arange(0, 2)[None, :, None]
+
+     l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
+         x.dtype
+     )
+     r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
+         x.dtype
+     )
+
+     l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
+     r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
+
+     idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+
+     il_value = l_value.to(idtype, bitcast=True)
+     ir_value = r_value.to(idtype, bitcast=True)
+     ix = x.to(idtype, bitcast=True)
+
+     flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
+     ret = ix ^ flag1
+     flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
+     ind = indices ^ flag2
+
+     return ret.to(x.dtype, bitcast=True), ind
+
+
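`_compare_and_swap` above performs the swap branchlessly: when a lane pair must be exchanged, both values are XOR-ed with `l ^ r` (on an integer bitcast view, so the same trick works for float keys), otherwise with `0`. A minimal sketch of the idea on plain Python integers (illustrative only; `xor_conditional_swap` is not part of the package):

def xor_conditional_swap(l: int, r: int, should_swap: bool) -> tuple[int, int]:
    # flag is (l ^ r) when swapping and 0 otherwise;
    # x ^ (l ^ r) maps l -> r and r -> l, while x ^ 0 is the identity
    flag = (l ^ r) if should_swap else 0
    return l ^ flag, r ^ flag

assert xor_conditional_swap(3, 7, True) == (7, 3)
assert xor_conditional_swap(3, 7, False) == (3, 7)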
+ @triton.jit
+ def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
+     n_outer: tl.constexpr = x.numel >> n_dims
+     tl.static_assert(stage <= n_dims)
+     """
+     order_type 0 == ascending
+     order_type 1 == descending
+     order_type 2 == alternating
+     """
+     if order == 2:
+         shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
+         flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
+     else:
+         flip = tl.full(x.shape, value=order, dtype=tl.int32)
+     for i in tl.static_range(stage):
+         x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
+     return x, indices
+
+
+ @triton.jit
+ def _argsort(x, indices, n_dims: tl.constexpr):
+     for i in tl.static_range(1, n_dims + 1):
+         x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
+     return x, indices
+
+
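`_argsort` is a bitonic sorting network, so each row must have a power-of-two width; the caller (`_row_id_map_pass_3_kernel` below) pads with `-1` via `other=-1` and `LOAD_SIZE = next_power_of_2(num_experts)`. Per the `order_type` note in `_bitonic_merge`, the final merge uses order 1, i.e. descending, so valid row ids land in front of the `-1` padding. What one row computes, expressed with plain `torch.sort` (an illustrative reference, not the kernel itself):

import torch

row = torch.tensor([0, 3, -1, -1], dtype=torch.int32)  # padded to a power of two
vals, order = torch.sort(row, descending=True)         # padding sinks to the end
print(vals.tolist())   # [3, 0, -1, -1]  -> sorted_map
print(order.tolist())  # [1, 0, 2, 3]    -> companion indices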
+ @triton.jit
+ def _row_id_map_pass_1_kernel(
+     # pointers
+     routing_map_ptr,
+     row_id_map_ptr,
+     workspace_ptr,
+     # sizes
+     num_tokens,
+     # strides
+     stride_routing_map_token,
+     stride_routing_map_expert,
+     stride_row_id_map_token,
+     stride_row_id_map_expert,
+     # metas
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     pid_m = tl.program_id(0)
+     pid_n = tl.program_id(1)
+     offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+     expert_token_mask = tl.load(
+         routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
+         mask=(offset < num_tokens),
+         other=0,
+     ).to(tl.int32)
+     row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
+     tl.store(
+         row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
+         row_id_within_token_block,
+         mask=offset < num_tokens,
+     )
+     n_tokens_per_block = tl.sum(expert_token_mask)
+     tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)
+
+
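Pass 1 computes, for each (expert, token-block) pair, the inclusive cumsum of the routing mask (zeroed where the mask is 0) and writes the per-block routed-token count into `workspace`. The same arithmetic for a single expert row in plain PyTorch (illustrative sketch, assuming all tokens fit in one block):

import torch

mask = torch.tensor([1, 1, 0, 1, 0], dtype=torch.int32)  # routing_map[:, e]
row_id_within_block = torch.cumsum(mask, dim=0) * mask   # tensor([1, 2, 0, 3, 0])
n_tokens_per_block = int(mask.sum())                     # 3, stored into workspace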
+ @triton.jit
+ def _row_id_map_pass_2_kernel(
+     # pointers
+     row_id_map_ptr,
+     workspace_ptr,
+     # sizes
+     num_tokens,
+     # strides
+     stride_row_id_map_token,
+     stride_row_id_map_expert,
+     # metas
+     WORKSPACE_LOAD_WIDTH: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     pid_m = tl.program_id(0)
+     pid_n = tl.program_id(1)
+     chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
+     offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+     row_id_within_token_block = tl.load(
+         row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
+         mask=(offset < num_tokens),
+         other=0,
+     )
+
+     workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
+     # other=0 keeps masked-out lanes from contributing to the tl.sum below
+     n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx, other=0)
+     row_id = tl.where(
+         row_id_within_token_block == 0,
+         -1,
+         row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
+     )
+     tl.store(
+         row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
+         row_id,
+         mask=(offset < num_tokens),
+     )
+
+
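Pass 2 turns the block-local ranks into global destination rows: each (expert, block) chunk is shifted by the total routed count of all chunks that precede it in expert-major order, and unrouted slots become `-1`. The same step for one chunk in plain PyTorch (illustrative sketch):

import torch

workspace = torch.tensor([2, 1, 2], dtype=torch.int32)  # routed counts per (expert, block) chunk
chunk_idx = 2                                           # this program's chunk, expert-major order
offset = int(workspace[:chunk_idx].sum())               # exclusive prefix sum -> 3
local = torch.tensor([1, 0, 2], dtype=torch.int32)      # pass-1 ranks within the chunk
row_id = torch.where(local == 0, torch.full_like(local, -1), local + offset - 1)
print(row_id.tolist())                                  # [3, -1, 4]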
+ @triton.jit
+ def _row_id_map_pass_3_kernel(
+     # pointers
+     row_id_map_ptr,
+     # sizes
+     num_experts: tl.constexpr,
+     # strides
+     stride_row_id_map_token,
+     stride_row_id_map_expert,
+     # metas
+     LOAD_SIZE: tl.constexpr,
+ ):
+     pid = tl.program_id(0)
+     n_dims: tl.constexpr = _log2(LOAD_SIZE)
+     off = tl.arange(0, LOAD_SIZE)
+     row_id_map = tl.load(
+         row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
+         mask=off < num_experts,
+         other=-1,
+     )
+     n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
+     indices = off
+     sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
+     tl.store(
+         row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
+         sorted_map,
+         mask=off < n_routed,
+     )
+     tl.store(
+         row_id_map_ptr
+         + pid * stride_row_id_map_token
+         + (num_experts + off) * stride_row_id_map_expert,
+         indices,
+         mask=off < n_routed,
+     )
+     tl.store(
+         row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
+         n_routed,
+     )
+
+
+ def make_row_id_map(
+     routing_map: torch.Tensor,
+     num_tokens: int,
+     num_experts: int,
+ ):
+     """
+     Prepare the row_id_map for the permutation.
+
+     Parameters
+     ----------
+     routing_map: torch.Tensor
+         Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
+         which experts are routed to which tokens. The values in it: 1 means the token is routed to
+         this expert and 0 means not.
+     num_tokens: int
+         Number of tokens in the input tensor.
+     num_experts: int
+         Number of experts in the input tensor.
+
+     Returns
+     -------
+     row_id_map: torch.Tensor
+         The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
+         For each token, the last item is the number of experts that are routed (n_routed).
+         The first n_routed items are the destination row indices in the permuted tokens.
+         The [num_experts, num_experts + n_routed) items are the indices of the experts
+         corresponding to the first n_routed row indices above.
+     """
+     row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device="cuda")
+     block_size = 1024
+     grid = (num_experts, triton.cdiv(num_tokens, block_size))
+     workspace_tensor = torch.empty(grid, dtype=torch.int32, device="cuda")
+
+     # supposing num_tokens == 5, num_experts == 3, block_size == 3
+     # and we have a routing_map like this:
+     # [[1, 1, 0],
+     #  [1, 0, 1],
+     #  [0, 0, 1],
+     #  [1, 1, 0],
+     #  [0, 0, 0]]
+
+     # pass 1: block cumsum
+     # for each expert, compute the cumsum of every block_size tokens
+     # the row_id_map will be like this after pass 1 (r means useless values):
+     # [[1, 1, 0, r, r, r, r],
+     #  [2, 0, 1, r, r, r, r],
+     #  [0, 0, 2, r, r, r, r],
+     #  [1, 1, 0, r, r, r, r],
+     #  [0, 0, 0, r, r, r, r]]
+     _row_id_map_pass_1_kernel[grid](
+         routing_map,
+         row_id_map,
+         workspace_tensor,
+         num_tokens,
+         routing_map.stride(0),
+         routing_map.stride(1),
+         row_id_map.stride(0),
+         row_id_map.stride(1),
+         block_size,
+     )
+
+     # pass 2: cumsum all and process the mask
+     # process the block cumsum into the global cumsum and then into the dst row indices
+     # the row_id_map will be like this after pass 2 (r means useless value):
+     # [[ 0,  3, -1, r, r, r, r],
+     #  [ 1, -1,  5, r, r, r, r],
+     #  [-1, -1,  6, r, r, r, r],
+     #  [ 2,  4, -1, r, r, r, r],
+     #  [-1, -1, -1, r, r, r, r]]
+     _row_id_map_pass_2_kernel[grid](
+         row_id_map,
+         workspace_tensor,
+         num_tokens,
+         row_id_map.stride(0),
+         row_id_map.stride(1),
+         triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)),
+         block_size,
+     )
+
+     # pass 3: make the row_id_map from the sparse structure to the dense structure
+     # the row_id_map will be like this after pass 3 (r means useless value):
+     # [[3, 0, r, 1, 0, r, 2],
+     #  [5, 1, r, 2, 0, r, 2],
+     #  [6, r, r, 2, r, r, 1],
+     #  [4, 2, r, 1, 0, r, 2],
+     #  [r, r, r, r, r, r, 0]]
+     grid = (num_tokens,)
+     _row_id_map_pass_3_kernel[grid](
+         row_id_map,
+         num_experts,
+         row_id_map.stride(0),
+         row_id_map.stride(1),
+         triton.next_power_of_2(num_experts),
+     )
+     return row_id_map
+
+
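A usage sketch matching the worked 5-token, 3-expert example in the comments above (requires a CUDA device and Triton):

import torch

routing_map = torch.tensor(
    [[1, 1, 0],
     [1, 0, 1],
     [0, 0, 1],
     [1, 1, 0],
     [0, 0, 0]],
    dtype=torch.int32,
    device="cuda",
)
row_id_map = make_row_id_map(routing_map, num_tokens=5, num_experts=3)
# row_id_map[:, -1] holds n_routed per token: [2, 2, 1, 2, 0].
# Token 0 goes to destination rows 3 and 0 (descending order), for experts 1 and 0.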
+ @triton.jit
+ def _permute_kernel(
+     # pointers
+     input_ptr,
+     output_ptr,
+     row_id_map_ptr,
+     probs_ptr,
+     scale_ptr,
+     permuted_probs_ptr,
+     permuted_scale_ptr,
+     # sizes
+     num_experts: tl.constexpr,
+     hidden_size: tl.constexpr,
+     scale_hidden_dim,
+     # strides
+     stride_row_id_map_token,
+     stride_row_id_map_expert,
+     stride_input_token,
+     stride_input_hidden,
+     stride_output_token,
+     stride_output_hidden,
+     stride_probs_token,
+     stride_probs_expert,
+     stride_scale_token,
+     stride_scale_hidden,
+     stride_permuted_probs_token,
+     stride_permuted_scale_token,
+     stride_permuted_scale_hidden,
+     # metas
+     PERMUTE_PROBS: tl.constexpr,
+     PERMUTE_SCALE: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     pid_t = tl.program_id(0)
+     pid_h = tl.program_id(1)
+     cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+     mask = cur_off < hidden_size
+     input_off = pid_t * stride_input_token + cur_off * stride_input_hidden
+     inp = tl.load(input_ptr + input_off, mask=mask)
+     if PERMUTE_SCALE:
+         mask_scale = cur_off < scale_hidden_dim
+         scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
+         scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
+     n_routed = tl.load(
+         row_id_map_ptr
+         + pid_t * stride_row_id_map_token
+         + num_experts * 2 * stride_row_id_map_expert
+     )
+     for idx in tl.range(n_routed):
+         dst_row = tl.load(
+             row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
+         )
+         output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
+         if PERMUTE_SCALE:
+             permuted_scale_off = (
+                 dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
+             )
+             tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
+         if PERMUTE_PROBS:
+             expert_idx = tl.load(
+                 row_id_map_ptr
+                 + pid_t * stride_row_id_map_token
+                 + (num_experts + idx) * stride_row_id_map_expert
+             )
+             prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
+             prob = tl.load(probs_ptr + prob_off)
+             if pid_h == 0:
+                 permuted_prob_off = dst_row * stride_permuted_probs_token
+                 tl.store(permuted_probs_ptr + permuted_prob_off, prob)
+             if prob == 0.0:
+                 # for routing_map padding
+                 # dst_row != -1 and prob == 0.0 means that this slot is padded
+                 tl.store(output_ptr + output_off, 0.0, mask=mask)
+             else:
+                 tl.store(output_ptr + output_off, inp, mask=mask)
+         else:
+             tl.store(output_ptr + output_off, inp, mask=mask)
+
+
+ try:
+     _permute_kernel = triton.autotune(
+         configs=[
+             triton.Config({"BLOCK_SIZE": 64}),
+             triton.Config({"BLOCK_SIZE": 128}),
+             triton.Config({"BLOCK_SIZE": 256}),
+             triton.Config({"BLOCK_SIZE": 512}),
+             triton.Config({"BLOCK_SIZE": 1024}),
+             triton.Config({"BLOCK_SIZE": 2048}),
+             triton.Config({"BLOCK_SIZE": 4096}),
+         ],
+         key=["hidden_size"],
+     )(_permute_kernel)
+ except RuntimeError:
+     pass
+
+
+ def permute_with_mask_map(
+     inp: torch.Tensor,
+     row_id_map: torch.Tensor,
+     probs: torch.Tensor,
+     scale: torch.Tensor,
+     num_tokens: int,
+     num_experts: int,
+     num_out_tokens: int,
+     hidden_size: int,
+     scale_hidden_dim: int,
+ ):
+     """
+     Permute the input tensor based on the row_id_map.
+
+     Parameters
+     ----------
+     inp: torch.Tensor
+         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
+     row_id_map: torch.Tensor
+         The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
+     probs: torch.Tensor
+         The probabilities of the input tensor. If it is not None, it will be permuted.
+     scale: torch.Tensor
+         The scale of the input tensor. If it is not None, it will be permuted.
+     num_tokens: int
+         Number of tokens in the input tensor.
+     num_experts: int
+         Number of experts in the input tensor.
+     num_out_tokens: int
+         Number of tokens in the permuted tensor.
+     hidden_size: int
+         Hidden size of the input tensor.
+     scale_hidden_dim: int
+         Hidden size of the scale tensor.
+     """
+     output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
+     if probs is not None:
+         permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
+     else:
+         permuted_probs = None
+
+     if scale is not None:
+         permuted_scale = torch.empty(
+             (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
+         )
+     else:
+         permuted_scale = None
+     # pylint: disable=unnecessary-lambda-assignment
+     grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
+     _permute_kernel[grid](
+         inp,
+         output,
+         row_id_map,
+         probs,
+         scale,
+         permuted_probs,
+         permuted_scale,
+         num_experts,
+         hidden_size,
+         scale_hidden_dim,
+         row_id_map.stride(0),
+         row_id_map.stride(1),
+         inp.stride(0),
+         inp.stride(1),
+         output.stride(0),
+         output.stride(1),
+         probs.stride(0) if probs is not None else None,
+         probs.stride(1) if probs is not None else None,
+         scale.stride(0) if scale is not None else None,
+         scale.stride(1) if scale is not None else None,
+         permuted_probs.stride(0) if permuted_probs is not None else None,
+         permuted_scale.stride(0) if permuted_scale is not None else None,
+         permuted_scale.stride(1) if permuted_scale is not None else None,
+         PERMUTE_PROBS=probs is not None,
+         PERMUTE_SCALE=scale is not None,
+     )
+     return output, permuted_scale, permuted_probs
+
+
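The destination layout is expert-major: all tokens routed to expert 0 come first (in token order), then expert 1, and so on, as the worked example above shows. A compact PyTorch reference for the data movement, ignoring the probs/scale side outputs (`permute_reference` is an illustrative helper, not part of the package):

import torch

def permute_reference(inp: torch.Tensor, routing_map: torch.Tensor) -> torch.Tensor:
    # nonzero on the transposed mask enumerates (expert, token) pairs
    # in expert-major order, matching the kernel's destination layout
    _expert_idx, token_idx = routing_map.t().nonzero(as_tuple=True)
    return inp[token_idx]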
+ @triton.jit
+ def _unpermute_kernel(
+     # pointers
+     input_ptr,
+     output_ptr,
+     row_id_map_ptr,
+     merging_probs_ptr,
+     permuted_probs_ptr,
+     unpermuted_probs_ptr,
+     # sizes
+     num_experts: tl.constexpr,
+     hidden_size: tl.constexpr,
+     # strides
+     stride_row_id_map_token,
+     stride_row_id_map_expert,
+     stride_input_token,
+     stride_input_hidden,
+     stride_output_token,
+     stride_output_hidden,
+     stride_merging_probs_token,
+     stride_merging_probs_expert,
+     stride_permuted_probs_token,
+     stride_unpermuted_probs_token,
+     stride_unpermuted_probs_expert,
+     # metas
+     PROBS_LOAD_WIDTH: tl.constexpr,
+     WITH_MERGING_PROBS: tl.constexpr,
+     PERMUTE_PROBS: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     data_type = input_ptr.dtype.element_ty
+     compute_type = tl.float32
+
+     pid_t = tl.program_id(0)
+     pid_h = tl.program_id(1)
+     current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+     mask = current_offset < hidden_size
+     if PERMUTE_PROBS:
+         # write 0.0 to probs_grad that are not routed
+         if pid_h == 0:
+             map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
+             unpermuted_prob_off = (
+                 pid_t * stride_unpermuted_probs_token
+                 + stride_unpermuted_probs_expert * map_load_off
+             )
+             tl.store(
+                 unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
+             )
+     accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
+     n_routed = tl.load(
+         row_id_map_ptr
+         + pid_t * stride_row_id_map_token
+         + num_experts * 2 * stride_row_id_map_expert
+     )
+     for idx in tl.range(n_routed):
+         src_row = tl.load(
+             row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
+         )
+         input_off = src_row * stride_input_token + current_offset * stride_input_hidden
+         inp = tl.load(input_ptr + input_off, mask=mask)
+         inp = inp.to(compute_type)
+         if WITH_MERGING_PROBS:
+             expert_idx = tl.load(
+                 row_id_map_ptr
+                 + pid_t * stride_row_id_map_token
+                 + (num_experts + idx) * stride_row_id_map_expert
+             )
+             merging_prob_off = (
+                 pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+             )
+             merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
+             inp *= merging_prob
+         accumulator += inp
+         if PERMUTE_PROBS:
+             if pid_h == 0:
+                 expert_idx = tl.load(
+                     row_id_map_ptr
+                     + pid_t * stride_row_id_map_token
+                     + (num_experts + idx) * stride_row_id_map_expert
+                 )
+                 unpermuted_prob_off = (
+                     pid_t * stride_unpermuted_probs_token
+                     + expert_idx * stride_unpermuted_probs_expert
+                 )
+                 permuted_prob_off = src_row * stride_permuted_probs_token
+                 prob = tl.load(permuted_probs_ptr + permuted_prob_off)
+                 tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
+     accumulator = accumulator.to(data_type)
+     output_off = pid_t * stride_output_token + current_offset * stride_output_hidden
+     tl.store(output_ptr + output_off, accumulator, mask=mask)
+
+
+ try:
+     _unpermute_kernel = triton.autotune(
+         configs=[
+             triton.Config({"BLOCK_SIZE": 64}),
+             triton.Config({"BLOCK_SIZE": 128}),
+             triton.Config({"BLOCK_SIZE": 256}),
+             triton.Config({"BLOCK_SIZE": 512}),
+             triton.Config({"BLOCK_SIZE": 1024}),
+             triton.Config({"BLOCK_SIZE": 2048}),
+             triton.Config({"BLOCK_SIZE": 4096}),
+         ],
+         key=["hidden_size"],
+     )(_unpermute_kernel)
+ except RuntimeError:
+     pass
+
+
+ def unpermute_with_mask_map(
+     inp: torch.Tensor,
+     row_id_map: torch.Tensor,
+     merging_probs: torch.Tensor | None,
+     permuted_probs: torch.Tensor | None,
+     num_tokens: int,
+     num_experts: int,
+     hidden_size: int,
+ ):
+     """
+     Unpermute the input tensor based on the row_id_map.
+
+     Parameters
+     ----------
+     inp: torch.Tensor
+         Input tensor of shape `[num_out_tokens, hidden_size]`.
+     row_id_map: torch.Tensor
+         The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
+     merging_probs: torch.Tensor
+         The merging probabilities of the input tensor. If it is not None, it will be used as
+         weights to reduce the unpermuted tokens.
+     permuted_probs: torch.Tensor
+         The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
+     num_tokens: int
+         Number of tokens in the permuted tensor.
+     num_experts: int
+         Number of experts in the permuted tensor.
+     hidden_size: int
+         Hidden size of the permuted tensor.
+     """
+     output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
+     if permuted_probs is not None:
+         unpermuted_probs = torch.empty(
+             (num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda"
+         )
+     else:
+         unpermuted_probs = None
+     # pylint: disable=unnecessary-lambda-assignment
+     grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
+     _unpermute_kernel[grid](
+         inp,
+         output,
+         row_id_map,
+         merging_probs,
+         permuted_probs,
+         unpermuted_probs,
+         num_experts,
+         hidden_size,
+         row_id_map.stride(0),
+         row_id_map.stride(1),
+         inp.stride(0),
+         inp.stride(1),
+         output.stride(0),
+         output.stride(1),
+         merging_probs.stride(0) if merging_probs is not None else None,
+         merging_probs.stride(1) if merging_probs is not None else None,
+         permuted_probs.stride(0) if permuted_probs is not None else None,
+         unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
+         unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
+         PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
+         WITH_MERGING_PROBS=merging_probs is not None,
+         PERMUTE_PROBS=permuted_probs is not None,
+     )
+     return output, unpermuted_probs
+
+
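Unpermutation is the matching reduction: each token accumulates its routed rows, optionally weighted by merging probabilities. A PyTorch reference (`unpermute_reference` is an illustrative helper; note the kernel accumulates in fp32, while this sketch stays in the input dtype):

import torch

def unpermute_reference(
    permuted: torch.Tensor,
    routing_map: torch.Tensor,
    merging_probs: torch.Tensor | None = None,
) -> torch.Tensor:
    expert_idx, token_idx = routing_map.t().nonzero(as_tuple=True)
    contrib = permuted
    if merging_probs is not None:
        contrib = permuted * merging_probs[token_idx, expert_idx].unsqueeze(-1)
    out = torch.zeros(
        routing_map.size(0), permuted.size(1), dtype=permuted.dtype, device=permuted.device
    )
    out.index_add_(0, token_idx, contrib)  # scatter-add each routed row back to its token
    return out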
+ class _moe_permute_mask_map(torch.autograd.Function):
+     """functional Permute with mask router map"""
+
+     @staticmethod
+     def forward(
+         ctx,
+         inp: torch.Tensor,
+         routing_map: torch.Tensor,
+         num_out_tokens: int,
+         probs: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         if not inp.numel():
+             ctx.probs = probs
+             return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device)
+
+         assert inp.is_cuda, "TransformerEngine needs CUDA."
+         assert routing_map.is_cuda, "TransformerEngine needs CUDA."
+         if probs is not None:
+             assert probs.is_cuda, "TransformerEngine needs CUDA."
+
+         assert inp.size(0) == routing_map.size(0), "Permute not possible"
+         num_tokens, hidden_size = inp.size()
+         num_experts = routing_map.size(1)
+         assert (
+             num_out_tokens is not None
+         ), "num_out_tokens must be provided to the fused permute function."
+
+         row_id_map = make_row_id_map(routing_map, num_tokens, num_experts)
+
+         # todo torchao fp8
+
+         output, permuted_scale, permuted_probs = permute_with_mask_map(
+             inp,
+             row_id_map,
+             probs,
+             None,
+             num_tokens,
+             num_experts,
+             num_out_tokens,
+             hidden_size,
+             None,
+         )
+
+         ctx.save_for_backward(row_id_map)
+         ctx.num_experts = num_experts
+         ctx.num_tokens = num_tokens
+         ctx.hidden_size = hidden_size
+         return output, row_id_map, permuted_probs
+
+     @staticmethod
+     def backward(
+         ctx,
+         permuted_act_grad: torch.Tensor,
+         _,
+         permuted_probs_grad: torch.Tensor,
+     ) -> tuple[torch.Tensor, ...]:
+         # pylint: disable=missing-function-docstring
+         if not permuted_act_grad.numel():
+             return permuted_act_grad, None, None, ctx.probs
+
+         act_grad = None
+         probs_grad = None
+         if ctx.needs_input_grad[0]:
+             (row_id_map,) = ctx.saved_tensors
+             act_grad, probs_grad = unpermute_with_mask_map(
+                 permuted_act_grad,
+                 row_id_map,
+                 None,
+                 permuted_probs_grad,
+                 ctx.num_tokens,
+                 ctx.num_experts,
+                 ctx.hidden_size,
+             )
+         if not ctx.needs_input_grad[3]:
+             probs_grad = None
+         return act_grad, None, None, probs_grad
+
+
+ def moe_permute_with_probs(
+     inp: torch.Tensor,
+     probs: torch.Tensor,
+     routing_map: torch.Tensor,
+     num_out_tokens: int = -1,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Permute the tokens and probs based on the routing_map.
+     Tokens with the same designated expert will be grouped together.
+     The routing_map indicates which experts were selected by each token.
+
+     Parameters
+     ----------
+     inp: torch.Tensor
+         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
+     probs: torch.Tensor
+         The tensor of probabilities corresponding to the permuted tokens and is
+         of shape `[num_tokens, num_experts]`. It will be permuted with the tokens
+         according to the routing_map.
+     routing_map: torch.Tensor
+         The token to expert mapping tensor of shape `[num_tokens, num_experts]` and dtype 'int32'.
+         The values in it: 1 means the token is routed to this expert and 0 means not.
+     num_out_tokens: int, default = -1
+         The effective output token count, representing the number of tokens not dropped.
+         By default, set to '-1', meaning no tokens are dropped.
+     """
+     output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
+         inp, routing_map, num_out_tokens, probs
+     )
+     return output, permuted_probs, row_id_map
+
+
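An end-to-end usage sketch with a top-k router (requires a CUDA device and Triton; everything except `moe_permute_with_probs` itself is illustrative):

import torch

num_tokens, hidden_size, num_experts, top_k = 8, 16, 4, 2
hidden = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
probs = torch.softmax(torch.randn(num_tokens, num_experts, device="cuda"), dim=-1)
topk_idx = probs.topk(top_k, dim=-1).indices
routing_map = torch.zeros_like(probs, dtype=torch.int32).scatter_(1, topk_idx, 1)

permuted, permuted_probs, row_id_map = moe_permute_with_probs(
    hidden, probs, routing_map, num_out_tokens=num_tokens * top_k
)
# permuted: [num_tokens * top_k, hidden_size], rows grouped by expert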
+ @triton.jit
+ def _unpermute_bwd_with_merging_probs_kernel(
+     # pointers
+     fwd_output_grad_ptr,
+     fwd_input_grad_ptr,
+     fwd_input_ptr,
+     merging_probs_ptr,
+     merging_probs_grad_ptr,
+     row_id_map_ptr,
+     # sizes
+     num_experts: tl.constexpr,
+     hidden_size: tl.constexpr,
+     # strides
+     stride_row_id_map_token,
+     stride_row_id_map_expert,
+     stride_fwd_output_grad_token,
+     stride_fwd_output_grad_hidden,
+     stride_fwd_input_grad_token,
+     stride_fwd_input_grad_hidden,
+     stride_fwd_input_token,
+     stride_fwd_input_hidden,
+     stride_merging_probs_token,
+     stride_merging_probs_expert,
+     stride_merging_probs_grad_token,
+     stride_merging_probs_grad_expert,
+     # metas
+     PROBS_LOAD_WIDTH: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     data_type = fwd_output_grad_ptr.dtype.element_ty
+     compute_type = tl.float32
+
+     pid = tl.program_id(0)
+     map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
+     token_probs_grad_off = (
+         pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
+     )
+     tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
+     n_routed = tl.load(
+         row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
+     )
+     for idx in tl.range(n_routed):
+         dst_row = tl.load(
+             row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
+         )
+         expert_idx = tl.load(
+             row_id_map_ptr
+             + pid * stride_row_id_map_token
+             + (num_experts + idx) * stride_row_id_map_expert
+         )
+         prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
+         current_start = 0
+         while current_start < hidden_size:
+             current_offset = current_start + tl.arange(0, BLOCK_SIZE)
+             mask = current_offset < hidden_size
+             input_off = (
+                 pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
+             )
+             inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
+             inp = inp.to(compute_type)
+             merging_prob_off = (
+                 pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+             )
+             merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
+             output = inp * merging_prob
+             output = output.to(data_type)
+             output_off = (
+                 dst_row * stride_fwd_input_grad_token
+                 + current_offset * stride_fwd_input_grad_hidden
+             )
+             tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
+
+             fwd_input_off = (
+                 dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
+             )
+             fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
+             prob_grad_accum += fwd_input.to(compute_type) * inp
+             current_start += BLOCK_SIZE
+         probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
+         probs_grad_off = (
+             pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
+         )
+         tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
+
+
+ try:
+     _unpermute_bwd_with_merging_probs_kernel = triton.autotune(
+         configs=[
+             triton.Config({"BLOCK_SIZE": 64}),
+             triton.Config({"BLOCK_SIZE": 128}),
+             triton.Config({"BLOCK_SIZE": 256}),
+             triton.Config({"BLOCK_SIZE": 512}),
+             triton.Config({"BLOCK_SIZE": 1024}),
+             triton.Config({"BLOCK_SIZE": 2048}),
+             triton.Config({"BLOCK_SIZE": 4096}),
+         ],
+         key=["hidden_size"],
+     )(_unpermute_bwd_with_merging_probs_kernel)
+ except RuntimeError:
+     pass
+
+
+ def unpermute_with_mask_map_bwd_with_merging_probs(
+     fwd_output_grad: torch.Tensor,
+     row_id_map: torch.Tensor,
+     fwd_input: torch.Tensor,
+     merging_probs: torch.Tensor,
+     num_tokens: int,
+     num_experts: int,
+     num_out_tokens: int,
+     hidden_size: int,
+ ):
+     """
+     Unpermute backward pass kernel with merging probs.
+
+     Parameters
+     ----------
+     fwd_output_grad: torch.Tensor
+         The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
+     row_id_map: torch.Tensor
+         The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
+     fwd_input: torch.Tensor
+         The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
+     merging_probs: torch.Tensor
+         The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
+     num_tokens: int
+         Number of tokens in the permuted tensor.
+     num_experts: int
+         Number of experts in the permuted tensor.
+     num_out_tokens: int
+         Number of tokens in the output tensor.
+     hidden_size: int
+         Hidden size of the output tensor.
+     """
+     act_grad = torch.empty(
+         (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
+     )
+     merging_probs_grad = torch.empty(
+         (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
+     )
+     grid = (num_tokens,)
+     _unpermute_bwd_with_merging_probs_kernel[grid](
+         fwd_output_grad,
+         act_grad,
+         fwd_input,
+         merging_probs,
+         merging_probs_grad,
+         row_id_map,
+         num_experts,
+         hidden_size,
+         row_id_map.stride(0),
+         row_id_map.stride(1),
+         fwd_output_grad.stride(0),
+         fwd_output_grad.stride(1),
+         act_grad.stride(0),
+         act_grad.stride(1),
+         fwd_input.stride(0),
+         fwd_input.stride(1),
+         merging_probs.stride(0),
+         merging_probs.stride(1),
+         merging_probs_grad.stride(0),
+         merging_probs_grad.stride(1),
+         PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
+     )
+     return act_grad, merging_probs_grad
+
+
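The backward kernel implements the two gradient identities of the weighted merge y_t = sum_e p[t, e] * x_row(t, e): the activation gradient scales the incoming gradient by the merging probability, and the probability gradient is the dot product of the saved forward input row with the incoming gradient, accumulated in fp32 over BLOCK_SIZE-wide tiles of the hidden dimension. A CPU-side autograd check of the per-pair identities (illustrative only):

import torch

# One routed (token, expert) pair: y = p * x, so
#   dL/dx = p * dL/dy   (activation gradient)
#   dL/dp = <x, dL/dy>  (probability gradient)
x = torch.randn(16, requires_grad=True)
p = torch.tensor(0.7, requires_grad=True)
grad_out = torch.randn(16)

(p * x).backward(grad_out)
assert torch.allclose(x.grad, p.detach() * grad_out)
assert torch.allclose(p.grad, torch.dot(x.detach(), grad_out))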
+ class _moe_unpermute_mask_map(torch.autograd.Function):
+     """functional Unpermute with mask router map"""
+
+     @staticmethod
+     def forward(
+         ctx,
+         inp: torch.Tensor,
+         row_id_map: torch.Tensor,
+         merging_probs: torch.Tensor | None,
+         restore_shape: torch.Size | None,
+     ) -> torch.Tensor:
+         # pylint: disable=missing-function-docstring
+         if not inp.numel():
+             ctx.merging_probs = merging_probs
+             return inp
+
+         if restore_shape is None:
+             restore_shape = inp.shape
+         num_tokens, hidden_size = restore_shape
+         num_experts = (row_id_map.size(1) - 1) // 2
+
+         with_probs = merging_probs is not None
+         if with_probs:
+             assert merging_probs.is_cuda, "TransformerEngine needs CUDA."
+
+         # Device check
+         assert inp.is_cuda, "TransformerEngine needs CUDA."
+         assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
+
+         unpermuted_output, _ = unpermute_with_mask_map(
+             inp,
+             row_id_map,
+             merging_probs,
+             None,
+             num_tokens,
+             num_experts,
+             hidden_size,
+         )
+
+         if with_probs:
+             ctx.save_for_backward(inp, row_id_map, merging_probs)
+         else:
+             ctx.save_for_backward(row_id_map)
+         ctx.num_experts = num_experts
+         ctx.num_tokens = num_tokens
+         ctx.num_permuted_tokens = inp.size(0)
+         ctx.hidden_size = hidden_size
+         ctx.with_probs = with_probs
+         return unpermuted_output
+
+     @staticmethod
+     def backward(ctx, unpermuted_act_grad):
+         # pylint: disable=missing-function-docstring
+         if not unpermuted_act_grad.numel():
+             return unpermuted_act_grad, None, ctx.merging_probs, None
+
+         act_grad = None
+         probs_grad = None
+         if ctx.needs_input_grad[0]:
+             if ctx.with_probs:
+                 fwd_input, row_id_map, merging_probs = ctx.saved_tensors
+                 act_grad, probs_grad = unpermute_with_mask_map_bwd_with_merging_probs(
+                     unpermuted_act_grad,
+                     row_id_map,
+                     fwd_input,
+                     merging_probs,
+                     ctx.num_tokens,
+                     ctx.num_experts,
+                     ctx.num_permuted_tokens,
+                     ctx.hidden_size,
+                 )
+             else:
+                 (row_id_map,) = ctx.saved_tensors
+                 act_grad, _, _ = permute_with_mask_map(
+                     unpermuted_act_grad,
+                     row_id_map,
+                     None,
+                     None,
+                     ctx.num_tokens,
+                     ctx.num_experts,
+                     ctx.num_permuted_tokens,
+                     ctx.hidden_size,
+                     None,
+                 )
+
+         if not ctx.needs_input_grad[2]:
+             probs_grad = None
+         return act_grad, None, probs_grad, None
+
+
+ def moe_unpermute_mask(
+     inp: torch.Tensor,
+     row_id_map: torch.Tensor,
+     merging_probs: torch.Tensor | None = None,
+     restore_shape: torch.Size | None = None,
+ ) -> torch.Tensor:
+     """
+     Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
+     corresponding probabilities.
+
+     Parameters
+     ----------
+     inp: torch.Tensor
+         Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
+     row_id_map: torch.Tensor
+         The tensor of a mapping table for sorted indices used to unpermute the tokens,
+         which is the third output tensor of `moe_permute_with_probs`.
+     merging_probs: torch.Tensor, default = None
+         The tensor of probabilities corresponding to the permuted tokens. If provided,
+         the unpermuted tokens will be merged with their respective probabilities.
+         By default, set to None, meaning the tokens are merged by plain accumulation.
+     restore_shape: torch.Size, default = None
+         The output shape after the unpermute operation. By default, set to None, meaning
+         the shape of `inp` is used.
+     """
+     return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
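A round-trip sketch tying the two public entry points together (requires a CUDA device and Triton; the router setup is illustrative). With the router probabilities passed as `merging_probs`, the restored tensor is each token's probability-weighted mixture of its expert outputs:

import torch

num_tokens, hidden_size, num_experts, top_k = 8, 16, 4, 2
hidden = torch.randn(num_tokens, hidden_size, device="cuda")
probs = torch.softmax(torch.randn(num_tokens, num_experts, device="cuda"), dim=-1)
topk_idx = probs.topk(top_k, dim=-1).indices
routing_map = torch.zeros_like(probs, dtype=torch.int32).scatter_(1, topk_idx, 1)

permuted, permuted_probs, row_id_map = moe_permute_with_probs(
    hidden, probs, routing_map, num_out_tokens=num_tokens * top_k
)
expert_out = permuted  # stand-in for the expert MLPs, which run on the permuted layout
restored = moe_unpermute_mask(
    expert_out, row_id_map, merging_probs=probs, restore_shape=hidden.shape
)
# restored[t] == sum over t's top-k experts e of probs[t, e] * expert_out[row(t, e)]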