tpu-inference 0.11.1.dev202511150811 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/kernels/fused_moe/v1/kernel.py
@@ -0,0 +1,1035 @@
1
+ """TPU-Friendly Fused Mixture of Experts (MoE) kernel."""
2
+
3
+ import functools
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from jax import lax
8
+ from jax._src import dtypes
9
+ from jax.experimental import pallas as pl
10
+ from jax.experimental import shard_map
11
+ from jax.experimental.pallas import tpu as pltpu
12
+
13
+ P = jax.sharding.PartitionSpec
14
+
15
+ cdiv = pl.cdiv
16
+
17
+
18
+ def align_to(x, a):
19
+ return cdiv(x, a) * a
20
+
21
+
22
+ def get_dtype_packing(dtype):
23
+ bits = dtypes.bit_width(dtype)
24
+ return 32 // bits
25
+
26
+
27
+ def broadcast_minor(src, shape):
28
+ if src.shape == shape:
29
+ return src
30
+ assert src.shape[:-1] == shape[:-1]
31
+ assert src.shape[-1] % 128 == 0
32
+ target_minor = align_to(shape[-1], src.shape[-1])
33
+ # no-op concatenation.
34
+ return jnp.concatenate([src for _ in range(target_minor // src.shape[-1])],
35
+ axis=-1)[..., :shape[-1]]
36
+
37
+
38
+ def ref_moe(
39
+ tokens: jax.Array, # (num_tokens, hidden_size)
40
+ w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
41
+ w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
42
+ gating_output: jax.Array, # (num_tokens, num_experts)
43
+ top_k: int,
44
+ activation="silu",
45
+ ):
46
+ n_tokens = tokens.shape[0] # num_tokens
47
+
48
+ # Compute gating scores for all experts
49
+ gating_logits = jax.nn.softmax(gating_output,
50
+ axis=-1) # [num_tokens, n_experts]
51
+
52
+ # Select top-k experts per token
53
+ top_k_logits, top_k_indices = lax.top_k(
54
+ gating_logits, top_k) # [num_tokens, top_k], [num_tokens, top_k]
55
+
56
+ t_outputs = []
57
+
58
+ # Process each token individually
59
+ for i in range(n_tokens):
60
+ curr_token = jnp.expand_dims(tokens[i], axis=0) # [1, d_model]
61
+ assigned_expert_ids = top_k_indices[
62
+ i] # [top_k] - indices of selected experts for token i
63
+ tok_expert_act = []
64
+
65
+ # Process each selected expert for the current token
66
+ for expert_id in assigned_expert_ids:
67
+ # Get expert weights
68
+ expert_weight_1 = jnp.concat(
69
+ [w1[expert_id, 0], w1[expert_id, 1]],
70
+ axis=-1) # [d_model, 2 * intermediate_size]
71
+ expert_weight_2 = w2[expert_id] # [intermediate_size, d_model]
72
+
73
+ # First linear layer with SwiGLU activation
74
+ gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size]
75
+
76
+ # Split into gate and up projections for SwiGLU
77
+ gmm1_w1_proj, gmm1_w3_proj = jnp.split(
78
+ gmm_1_out, 2,
79
+ axis=-1) # [1, intermediate_size], [1, intermediate_size]
80
+
81
+ # Apply gated activation: activation(gate) * up
82
+ if activation == "silu":
83
+ act = jax.nn.silu(
84
+ gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size]
85
+ elif activation == "gelu":
86
+ act = jax.nn.gelu(
87
+ gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size]
88
+ else:
89
+ raise ValueError(
90
+ f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
91
+ )
92
+
93
+ # Second linear layer (down projection)
94
+ gmm_2_out = act @ expert_weight_2 # [1, d_model]
95
+ tok_expert_act.append(gmm_2_out)
96
+
97
+ # Combine outputs from all selected experts
98
+ experts_act = jnp.concatenate(tok_expert_act,
99
+ axis=0) # [top_k, d_model]
100
+
101
+ # Weighted sum using top-k gating weights
102
+ top_k_weights = top_k_logits[i] # [top_k]
103
+ top_k_weights = jnp.expand_dims(top_k_weights, axis=1) # [top_k, 1]
104
+ weighted_output = jnp.sum(experts_act * top_k_weights,
105
+ axis=0,
106
+ keepdims=True) # [1, d_model]
107
+
108
+ t_outputs.append(weighted_output)
109
+
110
+ return jnp.concatenate(t_outputs, axis=0) # [num_tokens, d_model]
111
+
112
+
113
+ def _fused_ep_moe_kernel(
114
+ # Input
115
+ tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing)
116
+ w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size)
117
+ w2_hbm, # (local_num_experts, intermediate_size, hidden_size)
118
+ gating_hbm, # (local_num_tokens, padded_num_experts)
119
+ a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing)
120
+ # Output
121
+ output_hbm, # (local_num_tokens, hidden_size)
122
+ # Scratch
123
+ t2e_routing_x2_smem, # <bt_sem_id> (2, bt, padded_num_experts)
124
+ d2e_count_x2_smem, # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
125
+ expert_offsets_x2_smem, # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
126
+ expert_starts_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
127
+ expert_sizes_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
128
+ a2a_s_sends_x2_smem, # <e_sem_id> (2,)
129
+ a2a_s_x2_vmem, # <e_sem_id> (2, bt * num_devices, t_packing, hidden_size // t_packing)
130
+ a2a_s_acc_x2_vmem, # <e_sem_id> (2, bt * num_devices, t_packing, hidden_size // t_packing)
131
+ ### Accumulation for gathered tokens:
132
+ a2a_g_acc_vmem, # (top_k, bt, t_packing, hidden_size // t_packing)
133
+ ### Expert weight double buffering:
134
+ b_gating_x2_vmem, # <bt_sem_id> (2, bt, padded_num_experts)
135
+ b_output_x2_vmem, # <bt_sem_id> (2, bt, hidden_size)
136
+ b_w1_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
137
+ b_w3_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
138
+ b_w2_x2_vmem, # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
139
+ b_acc_vmem, # F32(bt * num_devices, 1, bf * 2)
140
+ ### Semaphores:
141
+ local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
142
+ send_sems, # <e_sem_id> (2,)
143
+ recv_sems, # <e_sem_id> (2,)
144
+ a2a_gather_sem,
145
+ a2a_acc_sem,
146
+ *,
147
+ top_k: int,
148
+ ep_axis_name: str,
149
+ # Kernel tuning params.
150
+ bt: int, # Block size of local_num_tokens.
151
+ bf: int, # Block size of intermediate_size.
152
+ bd1: int, # Block size of hidden_size in w1.
153
+ bd2: int, # Block size of hidden_size in w2.
154
+ btc: int, # Compute size of block tokens for active expert.
155
+ bfc: int, # Compute size of block intermediate_size.
156
+ bd1c: int, # Compute size of block hidden_size.
157
+ bd2c: int, # Compute size of block hidden_size.
158
+ ):
159
+ my_id = lax.axis_index(ep_axis_name)
160
+ num_devices = lax.axis_size(ep_axis_name)
161
+ local_num_tokens = tokens_hbm.shape[0]
162
+ local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
163
+ # num_experts = local_num_experts * num_devices
164
+ # padded_num_experts = expert_starts_x2_smem.shape[-1]
165
+ right_id = (my_id + 1) % num_devices
166
+
167
+ t_dtype = tokens_hbm.dtype
168
+ t_packing = get_dtype_packing(t_dtype)
169
+ t_bitwidth = 32 // t_packing
170
+ assert a2a_g_hbm.dtype == t_dtype
171
+ assert w1_hbm.dtype == t_dtype
172
+ assert w2_hbm.dtype == t_dtype
173
+
174
+ h_per_packing = hidden_size // t_packing
175
+ assert tokens_hbm.shape[-1] == h_per_packing
176
+ bd1_per_packing = bd1 // t_packing
177
+ bd2_per_packing = bd2 // t_packing
178
+ bd1c_per_packing = bd1c // t_packing
179
+ bd2c_per_packing = bd2c // t_packing
180
+
181
+ num_bt = cdiv(local_num_tokens, bt)
182
+ num_bf = cdiv(intermediate_size, bf)
183
+ num_bd1 = cdiv(hidden_size, bd1)
184
+ num_bd2 = cdiv(hidden_size, bd2)
185
+
186
+ def sync_barrier():
187
+ barrier_sem = pltpu.get_barrier_semaphore()
188
+ pltpu.semaphore_signal(
189
+ barrier_sem,
190
+ device_id=(0, right_id),
191
+ device_id_type=pltpu.DeviceIdType.MESH,
192
+ )
193
+ pltpu.semaphore_wait(barrier_sem, 1)
194
+
195
+ def start_fetch_b_gating(bt_id, priority=0):
196
+ is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
197
+ sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
198
+ bt_sem_id = (bt_id + 2) % 2
199
+ b_gating_sem = local_sems.at[bt_sem_id, 0]
200
+ pltpu.make_async_copy(
201
+ src_ref=gating_hbm.at[pl.ds(bt_id * bt, sz)],
202
+ dst_ref=b_gating_x2_vmem.at[bt_sem_id, pl.ds(0, sz)],
203
+ sem=b_gating_sem,
204
+ ).start(priority=priority)
205
+
206
+ def wait_fetch_b_gating(bt_id):
207
+ bt_sem_id = bt_id % 2
208
+ b_gating_sem = local_sems.at[bt_sem_id, 0]
209
+ pltpu.make_async_copy(
210
+ src_ref=b_gating_x2_vmem.at[bt_sem_id],
211
+ dst_ref=b_gating_x2_vmem.at[bt_sem_id],
212
+ sem=b_gating_sem,
213
+ ).wait()
214
+
215
+ def get_top_k(input, top_k):
216
+ assert len(input.shape) == 2, input.shape
217
+ input = input.astype(jnp.float32)
218
+ top_k_logits_lst = []
219
+ top_k_indices_lst = []
220
+ t2e = jnp.zeros(input.shape, dtype=jnp.int32)
221
+ t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
222
+ iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
223
+ for k_id in range(top_k):
224
+ # TODO(jevinjiang): return both top_k values and indices in op in Mosaic
225
+ top_k_logits = jnp.broadcast_to(
226
+ jnp.max(input, axis=1, keepdims=True),
227
+ (input.shape[0], 128)).astype(input.dtype)
228
+ top_k_logits_lst.append(top_k_logits)
229
+ # TODO(jevinjiang): support bf16 argmax in Mosaic
230
+ top_k_indices = jnp.broadcast_to(
231
+ jnp.argmax(input, axis=1, keepdims=True), input.shape)
232
+ top_k_indices_lst.append(top_k_indices)
233
+ t2e_routing = jnp.where(iota == k_id, top_k_indices, t2e_routing)
234
+ mask = iota == top_k_indices
235
+ t2e += mask.astype(jnp.int32)
236
+ if k_id != top_k - 1:
237
+ input = jnp.where(mask, -jnp.inf, input)
238
+
239
+ expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
240
+ expert_starts = jnp.zeros_like(expert_sizes)
241
+ return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
242
+
243
+ def all_reduce_metadata(bt_sem_id, t2e_routing, starts, sizes):
244
+ send_sem = send_sems.at[0]
245
+ recv_sem = recv_sems.at[0]
246
+
247
+ # All-reduce to accumulate starts and sizes and transfer to SMEM.
248
+ def _all_reduce_metadata(
249
+ t2e_routing_vmem,
250
+ d2e_count_vmem,
251
+ offsets_vmem,
252
+ starts_vmem,
253
+ sizes_vmem,
254
+ ):
255
+ offsets_vmem[...] = jnp.zeros_like(offsets_vmem)
256
+ # TODO(jevinjiang): check how slow is VMEM -> SMEM.
257
+ offsets_copy = pltpu.async_copy(
258
+ src_ref=offsets_vmem,
259
+ dst_ref=expert_offsets_x2_smem.at[bt_sem_id],
260
+ sem=send_sem,
261
+ )
262
+ t2e_routing_vmem[...] = t2e_routing
263
+ t2e_routing_copy = pltpu.async_copy(
264
+ src_ref=t2e_routing_vmem,
265
+ dst_ref=t2e_routing_x2_smem.at[bt_sem_id],
266
+ sem=send_sem,
267
+ )
268
+ reduced_sizes = sizes
269
+ reduced_starts = starts
270
+ row_id = my_id
271
+ d2e_count_vmem[row_id] = sizes
272
+ for i in range(num_devices - 1):
273
+ sync_barrier()
274
+ # TODO(jevinjiang): we can use double buffering to improve AR if needed.
275
+ pltpu.async_remote_copy(
276
+ src_ref=d2e_count_vmem.at[row_id],
277
+ dst_ref=d2e_count_vmem.at[row_id],
278
+ send_sem=send_sem,
279
+ recv_sem=recv_sem,
280
+ device_id=(0, right_id),
281
+ device_id_type=pltpu.DeviceIdType.MESH,
282
+ ).wait()
283
+ row_id = (row_id + num_devices - 1) % num_devices
284
+ new_sizes = d2e_count_vmem[row_id]
285
+ reduced_sizes += new_sizes
286
+ reduced_starts += lax.select(my_id > i, new_sizes,
287
+ jnp.zeros_like(new_sizes))
288
+ starts_vmem[...] = reduced_starts
289
+ sizes_vmem[...] = reduced_sizes
290
+
291
+ starts_copy = pltpu.async_copy(
292
+ src_ref=starts_vmem,
293
+ dst_ref=expert_starts_x2_smem.at[bt_sem_id],
294
+ sem=send_sem,
295
+ )
296
+ sizes_copy = pltpu.async_copy(
297
+ src_ref=sizes_vmem,
298
+ dst_ref=expert_sizes_x2_smem.at[bt_sem_id],
299
+ sem=send_sem,
300
+ )
301
+
302
+ # TODO(jevinjiang): if d2e_count is too big, we can store in HBM and fetch
303
+ # to SMEM partially.
304
+ d2e_count_copy = pltpu.async_copy(
305
+ src_ref=d2e_count_vmem,
306
+ dst_ref=d2e_count_x2_smem.at[bt_sem_id],
307
+ sem=send_sem,
308
+ )
309
+
310
+ t2e_routing_copy.wait()
311
+ d2e_count_copy.wait()
312
+ offsets_copy.wait()
313
+ starts_copy.wait()
314
+ sizes_copy.wait()
315
+
316
+ pl.run_scoped(
317
+ _all_reduce_metadata,
318
+ pltpu.VMEM(t2e_routing_x2_smem.shape[1:],
319
+ t2e_routing_x2_smem.dtype),
320
+ pltpu.VMEM(d2e_count_x2_smem.shape[1:], d2e_count_x2_smem.dtype),
321
+ pltpu.VMEM(expert_offsets_x2_smem.shape[1:],
322
+ expert_offsets_x2_smem.dtype),
323
+ pltpu.VMEM(expert_starts_x2_smem.shape[1:],
324
+ expert_starts_x2_smem.dtype),
325
+ pltpu.VMEM(expert_sizes_x2_smem.shape[1:],
326
+ expert_sizes_x2_smem.dtype),
327
+ )
328
+
329
+ def start_a2a_scatter(bt_id, e_sem_id, local_e_id):
330
+ bt_sem_id = bt_id % 2
331
+
332
+ # Counting the number of remote sends from the current device.
333
+ send_sz = 0
334
+ for bt_t_id in range(bt):
335
+ for k_id in range(top_k):
336
+ e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id]
337
+ is_active_expert = e_id % local_num_experts == local_e_id
338
+ recv_id = e_id // local_num_experts
339
+ offset = expert_offsets_x2_smem[bt_sem_id, 0, e_id]
340
+ sz = lax.select(is_active_expert, 1, 0)
341
+ is_local = recv_id == my_id
342
+ local_sz = lax.select(is_local, sz, 0)
343
+ remote_sz = lax.select(is_local, 0, sz)
344
+ send_sz += remote_sz
345
+ expert_offsets_x2_smem[bt_sem_id, 0,
346
+ e_id] = (offset + local_sz + remote_sz)
347
+ start = expert_starts_x2_smem[bt_sem_id, 0, e_id] + offset
348
+ t_id = bt * bt_id + bt_t_id
349
+ # TODO(jevinjiang): compare the perf when using branches.
350
+ pltpu.make_async_copy(
351
+ src_ref=tokens_hbm.at[pl.ds(t_id, local_sz)],
352
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id,
353
+ pl.ds(start, local_sz)],
354
+ sem=recv_sems.at[e_sem_id],
355
+ ).start()
356
+ pltpu.make_async_remote_copy(
357
+ src_ref=tokens_hbm.at[pl.ds(t_id, remote_sz)],
358
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id,
359
+ pl.ds(start, remote_sz)],
360
+ send_sem=send_sems.at[e_sem_id],
361
+ recv_sem=recv_sems.at[e_sem_id],
362
+ device_id=(
363
+ 0,
364
+ recv_id,
365
+ ),
366
+ ).start()
367
+ a2a_s_sends_x2_smem[e_sem_id] = send_sz
368
+
369
+ def wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id):
370
+ bt_sem_id = bt_id % 2
371
+ e_id = my_id * local_num_experts + local_e_id
372
+ sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
373
+ pltpu.make_async_copy(
374
+ src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
375
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
376
+ sem=recv_sems.at[e_sem_id],
377
+ ).wait()
378
+
379
+ def wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id):
380
+ del bt_id, local_e_id
381
+ sz = a2a_s_sends_x2_smem[e_sem_id]
382
+ pltpu.make_async_copy(
383
+ src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
384
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
385
+ sem=send_sems.at[e_sem_id],
386
+ ).wait()
387
+
388
+ def start_a2a_gather(bt_id, e_sem_id, local_e_id):
389
+ my_e_id = my_id * local_num_experts + local_e_id
390
+ bt_sem_id = bt_id % 2
391
+ start = 0
392
+ for recv_id in range(num_devices):
393
+ sz = d2e_count_x2_smem[bt_sem_id, recv_id, 0, my_e_id]
394
+ is_local = recv_id == my_id
395
+ local_sz = lax.select(is_local, sz, 0)
396
+ remote_sz = lax.select(is_local, 0, sz)
397
+ pltpu.make_async_copy(
398
+ src_ref=a2a_s_acc_x2_vmem.at[e_sem_id,
399
+ pl.ds(start, local_sz)],
400
+ dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, local_sz)],
401
+ sem=a2a_gather_sem,
402
+ ).start()
403
+ pltpu.make_async_remote_copy(
404
+ src_ref=a2a_s_acc_x2_vmem.at[e_sem_id,
405
+ pl.ds(start, remote_sz)],
406
+ dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
407
+ send_sem=send_sems.at[e_sem_id],
408
+ recv_sem=a2a_gather_sem,
409
+ device_id=(0, recv_id),
410
+ ).start()
411
+ start += sz
412
+
413
+ def wait_a2a_gather_send(bt_id, e_sem_id, local_e_id):
414
+ my_e_id = my_id * local_num_experts + local_e_id
415
+ bt_sem_id = bt_id % 2
416
+ sz = expert_sizes_x2_smem[bt_sem_id, 0, my_e_id]
417
+ local_sz = d2e_count_x2_smem[bt_sem_id, my_id, 0, my_e_id]
418
+ remote_sz = sz - local_sz
419
+ is_valid = jnp.logical_and(0 <= local_e_id, local_e_id
420
+ < local_num_experts)
421
+ remote_sz = lax.select(is_valid, remote_sz, 0)
422
+ pltpu.make_async_copy(
423
+ src_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)],
424
+ dst_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)],
425
+ sem=send_sems.at[e_sem_id],
426
+ ).wait()
427
+
428
+ def wait_a2a_gather_recv_all():
429
+ sz = top_k * bt
430
+ pltpu.make_async_copy(
431
+ src_ref=a2a_g_hbm.at[0, pl.ds(0, sz)],
432
+ dst_ref=a2a_g_hbm.at[0, pl.ds(0, sz)],
433
+ sem=a2a_gather_sem,
434
+ ).wait()
435
+
436
+ def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
437
+ for p in range(t_packing):
438
+ offset = p * h_per_packing + bd1_id * bd1_per_packing
439
+ pltpu.make_async_copy(
440
+ src_ref=w1_hbm.at[
441
+ local_e_id,
442
+ 0,
443
+ pl.ds(offset, bd1_per_packing),
444
+ pl.ds(bf_id * bf, bf),
445
+ ],
446
+ dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
447
+ sem=local_sems.at[bw1_sem_id, 1],
448
+ ).start()
449
+
450
+ def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
451
+ for p in range(t_packing):
452
+ offset = p * h_per_packing + bd2_id * bd2_per_packing
453
+ pltpu.make_async_copy(
454
+ src_ref=w2_hbm.at[
455
+ local_e_id,
456
+ pl.ds(bf_id * bf, bf),
457
+ pl.ds(offset, bd2_per_packing),
458
+ ],
459
+ dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
460
+ sem=local_sems.at[bw2_sem_id, 2],
461
+ ).start()
462
+
463
+ def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
464
+ for p in range(t_packing):
465
+ offset = p * h_per_packing + bd3_id * bd1_per_packing
466
+ pltpu.make_async_copy(
467
+ src_ref=w1_hbm.at[
468
+ local_e_id,
469
+ 1,
470
+ pl.ds(offset, bd1_per_packing),
471
+ pl.ds(bf_id * bf, bf),
472
+ ],
473
+ dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
474
+ sem=local_sems.at[bw3_sem_id, 3],
475
+ ).start()
476
+
477
+ def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
478
+ del local_e_id, bf_id, bd1_id
479
+ pltpu.make_async_copy(
480
+ src_ref=b_w1_x2_vmem.at[bw1_sem_id],
481
+ dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
482
+ sem=local_sems.at[bw1_sem_id, 1],
483
+ ).wait()
484
+
485
+ def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
486
+ del local_e_id, bf_id, bd2_id
487
+ pltpu.make_async_copy(
488
+ src_ref=b_w2_x2_vmem.at[bw2_sem_id],
489
+ dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
490
+ sem=local_sems.at[bw2_sem_id, 2],
491
+ ).wait()
492
+
493
+ def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
494
+ del local_e_id, bf_id, bd3_id
495
+ pltpu.make_async_copy(
496
+ src_ref=b_w3_x2_vmem.at[bw3_sem_id],
497
+ dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
498
+ sem=local_sems.at[bw3_sem_id, 3],
499
+ ).wait()
500
+
501
+ def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
502
+ next_bd1_id = bd1_id + 1
503
+ next_bd2_id = bd2_id + 1
504
+ next_sem_id = (bw_sem_id + 1) % 2
505
+
506
+ if bf_id >= num_bf:
507
+ return
508
+ if next_bd1_id < num_bd1:
509
+ start_fetch_bw1(local_e_id, next_sem_id, bf_id, next_bd1_id)
510
+ start_fetch_bw3(local_e_id, next_sem_id, bf_id, next_bd1_id)
511
+ elif next_bd1_id == num_bd1:
512
+ start_fetch_bw2(local_e_id, next_sem_id, bf_id, 0)
513
+ elif next_bd2_id < num_bd2:
514
+ start_fetch_bw2(local_e_id, next_sem_id, bf_id, next_bd2_id)
515
+ elif next_bd2_id == num_bd2:
516
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id + 1, -1, -1)
517
+ else:
518
+ raise RuntimeError("Unreachable")
519
+
520
+ def dynamic_ffn1(
521
+ t_b32_vmem,
522
+ w1_vmem,
523
+ w3_vmem,
524
+ acc1_vmem,
525
+ acc3_vmem,
526
+ dyn_sz,
527
+ should_init,
528
+ ):
529
+ assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
530
+ assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_packing,
531
+ bf)
532
+ assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
533
+ assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
534
+ assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
535
+
536
+ num_loops = cdiv(dyn_sz, btc)
537
+ repack_ty = jnp.dtype(f"int{t_bitwidth}")
538
+
539
+ def body(btc_id, _):
540
+ for bd1c_id in range(cdiv(bd1, bd1c)):
541
+ t_b32 = t_b32_vmem[
542
+ pl.ds(btc_id * btc, btc),
543
+ pl.ds(bd1c_id * bd1c_per_packing, bd1c_per_packing),
544
+ ]
545
+ for p_id in range(t_packing):
546
+ t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
547
+ t_b32 = t_b32 >> t_bitwidth
548
+ for bfc_id in range(cdiv(bf, bfc)):
549
+ w_slices = (
550
+ p_id,
551
+ pl.ds(bd1c_id * bd1c_per_packing,
552
+ bd1c_per_packing),
553
+ pl.ds(bfc_id * bfc, bfc),
554
+ )
555
+ w1 = w1_vmem[*w_slices]
556
+ acc1 = jnp.dot(t,
557
+ w1,
558
+ preferred_element_type=jnp.float32)
559
+ w3 = w3_vmem[*w_slices]
560
+ acc3 = jnp.dot(t,
561
+ w3,
562
+ preferred_element_type=jnp.float32)
563
+ acc_slices = (pl.ds(btc_id * btc,
564
+ btc), pl.ds(bfc_id * bfc, bfc))
565
+ if should_init and p_id == bd1c_id == 0:
566
+ acc1_vmem[*acc_slices] = acc1
567
+ acc3_vmem[*acc_slices] = acc3
568
+ else:
569
+ acc1_vmem[*acc_slices] += acc1
570
+ acc3_vmem[*acc_slices] += acc3
571
+
572
+ lax.fori_loop(0, num_loops, body, None)
573
+
574
+ def dynamic_ffn2(
575
+ acc1_vmem,
576
+ acc3_vmem,
577
+ w2_vmem,
578
+ res_b32_vmem,
579
+ dyn_sz,
580
+ should_init,
581
+ ):
582
+ assert res_b32_vmem.shape == (bt * num_devices, bd2_per_packing)
583
+ assert w2_vmem.shape == (t_packing, bf, bd2_per_packing), (
584
+ w2_vmem.shape,
585
+ t_packing,
586
+ bf,
587
+ bd2_per_packing,
588
+ )
589
+ assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
590
+ assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
591
+ assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
592
+ assert t_dtype in (jnp.float32, jnp.bfloat16)
593
+
594
+ num_loops = cdiv(dyn_sz, btc)
595
+ assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
596
+
597
+ def body(btc_id, _):
598
+ for bd2c_id in range(cdiv(bd2, bd2c)):
599
+ res_lst = []
600
+ for p_id in range(t_packing):
601
+ res = jnp.zeros((btc, bd2c_per_packing), dtype=jnp.float32)
602
+ for bfc_id in range(cdiv(bf, bfc)):
603
+ acc_slices = (pl.ds(btc_id * btc,
604
+ btc), pl.ds(bfc_id * bfc, bfc))
605
+ acc1 = acc1_vmem[*acc_slices]
606
+ acc3 = acc3_vmem[*acc_slices]
607
+ act = jax.nn.silu(acc1) * acc3
608
+ w2 = w2_vmem[
609
+ p_id,
610
+ pl.ds(bfc_id * bfc, bfc),
611
+ pl.ds(bd2c_id *
612
+ bd2c_per_packing, bd2c_per_packing),
613
+ ]
614
+ res += jnp.dot(act,
615
+ w2,
616
+ preferred_element_type=jnp.float32)
617
+ res = pltpu.bitcast(res, jnp.uint32)
618
+ if t_packing == 2:
619
+ res = res >> 16 << (16 * p_id)
620
+ else:
621
+ assert t_packing == 1
622
+ res_lst.append(res)
623
+ res = res_lst[0]
624
+ # TODO(jevinjiang): use interleaved packing when it is exposed to Pallas
625
+ for i in range(1, t_packing):
626
+ res |= res_lst[i]
627
+ sliced_res_vmem = res_b32_vmem.at[
628
+ pl.ds(btc_id * btc, btc),
629
+ pl.ds(bd2c_id * bd2c_per_packing, bd2c_per_packing),
630
+ ]
631
+ if should_init:
632
+ sliced_res_vmem[...] = res
633
+ else:
634
+ sliced_res_vmem[...] = pltpu.bitcast(
635
+ sliced_res_vmem.bitcast(t_dtype)[...] +
636
+ pltpu.bitcast(res, t_dtype),
637
+ sliced_res_vmem.dtype,
638
+ )
639
+
640
+ lax.fori_loop(0, num_loops, body, None)
641
+
642
+ def expert_ffn(bt_id, e_sem_id, local_e_id):
643
+ bt_sem_id = bt_id % 2
644
+ bw_sem_id = 0
645
+ # start_fetch_bw1(local_e_id, bw_sem_id, 0, 0)
646
+ # start_fetch_bw3(local_e_id, bw_sem_id, 0, 0)
647
+ a2a_s_b32_vmem = (a2a_s_x2_vmem.bitcast(jnp.uint32).reshape(
648
+ 2, bt * num_devices, hidden_size // t_packing).at[e_sem_id])
649
+ a2a_s_acc_b32_vmem = (a2a_s_acc_x2_vmem.bitcast(jnp.uint32).reshape(
650
+ 2, bt * num_devices, hidden_size // t_packing).at[e_sem_id])
651
+ b_acc_vmem_2d = b_acc_vmem.reshape(bt * num_devices, bf * 2)
652
+ b_acc1_vmem = b_acc_vmem_2d.at[:, :bf]
653
+ b_acc3_vmem = b_acc_vmem_2d.at[:, bf:]
654
+
655
+ e_id = my_id * local_num_experts + local_e_id
656
+ dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
657
+
658
+ bd1_per_packing = bd1 // t_packing
659
+ bd2_per_packing = bd2 // t_packing
660
+
661
+ for bf_id in range(num_bf):
662
+ for bd1_id in range(num_bd1):
663
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
664
+ wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
665
+ wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
666
+
667
+ dynamic_ffn1(
668
+ t_b32_vmem=a2a_s_b32_vmem.at[
669
+ ...,
670
+ pl.ds(bd1_id * bd1_per_packing, bd1_per_packing)],
671
+ w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
672
+ w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
673
+ acc1_vmem=b_acc1_vmem,
674
+ acc3_vmem=b_acc3_vmem,
675
+ dyn_sz=dyn_sz,
676
+ should_init=(bd1_id == 0),
677
+ )
678
+ bw_sem_id = (bw_sem_id + 1) % 2
679
+
680
+ for bd2_id in range(num_bd2):
681
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, num_bd1,
682
+ bd2_id)
683
+ wait_fetch_bw2(local_e_id, bw_sem_id, bf_id, bd2_id)
684
+ if bf_id == bd2_id == 0:
685
+ wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
686
+
687
+ dynamic_ffn2(
688
+ acc1_vmem=b_acc1_vmem,
689
+ acc3_vmem=b_acc3_vmem,
690
+ w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
691
+ res_b32_vmem=a2a_s_acc_b32_vmem.at[
692
+ ...,
693
+ pl.ds(bd2_id * bd2_per_packing, bd2_per_packing)],
694
+ dyn_sz=dyn_sz,
695
+ should_init=(bf_id == 0),
696
+ )
697
+ bw_sem_id = (bw_sem_id + 1) % 2
698
+
699
+ def bt_acc(bt_id, top_k_logits_lst):
700
+ bt_sem_id = bt_id % 2
701
+ for bt_t_id in range(bt):
702
+ for k_id in range(top_k):
703
+ e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id]
704
+ offset = expert_offsets_x2_smem[bt_sem_id, 1, e_id]
705
+ expert_offsets_x2_smem[bt_sem_id, 1, e_id] = offset + 1
706
+ pltpu.make_async_copy(
707
+ src_ref=a2a_g_hbm.at[e_id, pl.ds(offset, 1)],
708
+ dst_ref=a2a_g_acc_vmem.at[k_id, pl.ds(bt_t_id, 1)],
709
+ sem=a2a_acc_sem,
710
+ ).start()
711
+ pltpu.make_async_copy(
712
+ src_ref=a2a_g_acc_vmem,
713
+ dst_ref=a2a_g_acc_vmem,
714
+ sem=a2a_acc_sem,
715
+ ).wait()
716
+ output = None
717
+ for k_id in range(top_k):
718
+ acc = a2a_g_acc_vmem[k_id].reshape(bt, hidden_size)
719
+ logits = broadcast_minor(top_k_logits_lst[k_id], acc.shape)
720
+ acc *= logits
721
+ if output is None:
722
+ output = acc
723
+ else:
724
+ output += acc
725
+ assert output is not None
726
+ return output.astype(output_hbm.dtype)
727
+
728
+ def start_send_bo(bt_id, priority=0):
729
+ bt_sem_id = bt_id % 2
730
+ b_output_sem = local_sems.at[bt_sem_id, 4]
731
+ pltpu.make_async_copy(
732
+ src_ref=b_output_x2_vmem.at[bt_sem_id],
733
+ dst_ref=output_hbm.at[pl.ds(bt_id * bt, bt)],
734
+ sem=b_output_sem,
735
+ ).start(priority=priority)
736
+
737
+ def wait_send_bo(bt_id):
738
+ is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
739
+ sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
740
+ bt_sem_id = (bt_id + 2) % 2
741
+ b_output_sem = local_sems.at[bt_sem_id, 4]
742
+ pltpu.make_async_copy(
743
+ src_ref=output_hbm.at[pl.ds(0, sz)],
744
+ dst_ref=output_hbm.at[pl.ds(0, sz)],
745
+ sem=b_output_sem,
746
+ ).wait()
747
+
748
+ ### ------- Kernel start ------- ###
749
+ start_fetch_b_gating(bt_id=0)
750
+
751
+ def run_per_bt(bt_id, e_sem_id):
752
+ bt_sem_id = bt_id % 2
753
+ next_bt_id = bt_id + 1
754
+ start_fetch_b_gating(next_bt_id)
755
+ wait_fetch_b_gating(bt_id)
756
+
757
+ b_gating = b_gating_x2_vmem[bt_sem_id]
758
+ b_gating_score = jax.nn.softmax(b_gating, axis=-1)
759
+ top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
760
+ b_gating_score, top_k)
761
+
762
+ all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
763
+ expert_sizes)
764
+
765
+ start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)
766
+
767
+ def run_per_expert(local_e_id, e_sem_id):
768
+ sync_barrier()
769
+ next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
770
+ next_local_e_id = local_e_id + 1
771
+
772
+ @pl.when(next_local_e_id < local_num_experts)
773
+ def _():
774
+ start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)
775
+
776
+ # Prefetch weights for active expert.
777
+ start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
778
+ start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
779
+
780
+ # Wait for a2a scatter and perform FFN for active expert.
781
+ wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
782
+ expert_ffn(bt_id, e_sem_id, local_e_id)
783
+
784
+ # Wait for a2a gather to send back tokens for active expert.
785
+ start_a2a_gather(bt_id, e_sem_id, local_e_id)
786
+
787
+ # A must-wait before next sync_barrier.
788
+ wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id)
789
+ return next_e_sem_id
790
+
791
+ e_sem_id = lax.fori_loop(0,
792
+ local_num_experts,
793
+ run_per_expert,
794
+ e_sem_id,
795
+ unroll=False)
796
+
797
+ wait_a2a_gather_recv_all()
798
+ output = bt_acc(bt_id, top_k_logits_lst)
799
+
800
+ # Make sure it is safe to overwrite output buffer.
801
+ wait_send_bo(bt_id=bt_id - 2)
802
+ b_output_x2_vmem[bt_sem_id] = output
803
+
804
+ start_send_bo(bt_id)
805
+
806
+ wait_a2a_gather_send(
807
+ bt_id,
808
+ e_sem_id=e_sem_id,
809
+ local_e_id=local_num_experts - 2,
810
+ )
811
+ wait_a2a_gather_send(
812
+ bt_id,
813
+ e_sem_id=lax.select(e_sem_id == 0, 1, 0),
814
+ local_e_id=local_num_experts - 1,
815
+ )
816
+ return e_sem_id
817
+
818
+ lax.fori_loop(0, num_bt, run_per_bt, 0, unroll=False)
819
+ wait_send_bo(bt_id=num_bt - 2)
820
+ wait_send_bo(bt_id=num_bt - 1)
821
+
822
+ ### ------- Kernel end ------- ###
823
+
824
+
825
+ @functools.partial(
826
+ jax.jit,
827
+ static_argnames=[
828
+ "mesh",
829
+ "top_k",
830
+ "bt",
831
+ "bf",
832
+ "bd1",
833
+ "bd2",
834
+ "btc",
835
+ "bfc",
836
+ "bd1c",
837
+ "bd2c",
838
+ "ep_axis_name",
839
+ ],
840
+ )
841
+ def fused_ep_moe(
842
+ mesh: jax.sharding.Mesh,
843
+ tokens: jax.Array, # (num_tokens, hidden_size)
844
+ w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
845
+ w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
846
+ gating_output: jax.Array, # (num_tokens, num_experts)
847
+ top_k: int,
848
+ *,
849
+ # Kernel tuning parameters.
850
+ bt: int,
851
+ bf: int,
852
+ bd1: int,
853
+ bd2: int,
854
+ btc: int,
855
+ bfc: int,
856
+ bd1c: int,
857
+ bd2c: int,
858
+ ep_axis_name: str = 'model',
859
+ ):
860
+ # Assert all other axes have length of 1
861
+ assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
862
+ assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
863
+ "Expect data axis size of 1 in tpu-inference"
864
+
865
+ ep_size = mesh.shape[ep_axis_name]
866
+ num_devices = ep_size
867
+
868
+ num_tokens, actual_hidden_size = tokens.shape
869
+ num_experts, intermediate_size, _ = w2.shape
870
+
871
+ assert num_tokens % ep_size == 0
872
+ assert num_experts % ep_size == 0
873
+
874
+ local_num_tokens = num_tokens // ep_size
875
+ # local_num_experts = num_experts // ep_size
876
+ padded_num_experts = align_to(num_experts, 128)
877
+
878
+ t_dtype = tokens.dtype
879
+ t_packing = get_dtype_packing(t_dtype)
880
+ hidden_size = align_to(actual_hidden_size, 128 * t_packing)
881
+ if hidden_size != actual_hidden_size:
882
+ tokens = jnp.pad(
883
+ tokens,
884
+ ((0, 0), (0, hidden_size - actual_hidden_size)),
885
+ constant_values=0,
886
+ )
887
+ tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
888
+ bt = min(bt, local_num_tokens)
889
+ bf = min(bf, intermediate_size)
890
+ bd1 = min(bd1, hidden_size)
891
+ bd2 = min(bd2, hidden_size)
892
+
893
+ btc = min(btc, bt * num_devices)
894
+ bfc = min(bfc, bf)
895
+ bd1c = min(bd1c, bd1)
896
+ bd2c = min(bd2c, bd2)
897
+ assert bfc % 128 == 0
898
+ assert bd1c % (t_packing * 128) == 0
899
+ assert bd2c % (t_packing * 128) == 0
900
+ assert bf % bfc == 0
901
+ assert bd1 % bd1c == 0
902
+ assert bd2 % bd2c == 0
903
+
904
+ if padded_num_experts != gating_output.shape[-1]:
905
+ gating_output = jnp.pad(
906
+ gating_output,
907
+ ((0, 0), (0, padded_num_experts - gating_output.shape[-1])),
908
+ constant_values=-jnp.inf,
909
+ )
910
+
911
+ scope_name = f"fused_moe_k-{top_k}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
912
+ fused_moe = jax.named_scope(scope_name)(
913
+ pl.pallas_call(
914
+ functools.partial(
915
+ _fused_ep_moe_kernel,
916
+ top_k=top_k,
917
+ ep_axis_name=ep_axis_name,
918
+ bt=bt,
919
+ bf=bf,
920
+ bd1=bd1,
921
+ bd2=bd2,
922
+ btc=btc,
923
+ bfc=bfc,
924
+ bd1c=bd1c,
925
+ bd2c=bd2c,
926
+ ),
927
+ out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
928
+ t_dtype),
929
+ grid_spec=pltpu.PrefetchScalarGridSpec(
930
+ num_scalar_prefetch=0,
931
+ in_specs=[
932
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
933
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
934
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
935
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
936
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
937
+ ],
938
+ out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
939
+ scratch_shapes=([
940
+ # t2e_routing_x2_smem
941
+ pltpu.SMEM((2, bt, padded_num_experts), jnp.int32),
942
+ # d2e_count_x2_smem
943
+ pltpu.SMEM((2, num_devices, 1, padded_num_experts),
944
+ jnp.int32),
945
+ # expert_offsets_x2_smem
946
+ pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
947
+ # expert_starts_x2_smem
948
+ pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
949
+ # expert_sizes_x2_smem
950
+ pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
951
+ # a2a_s_sends_x2_smem
952
+ pltpu.SMEM((2, ), jnp.int32),
953
+ # a2a_s_x2_vmem
954
+ pltpu.VMEM(
955
+ (
956
+ 2,
957
+ bt * num_devices,
958
+ t_packing,
959
+ hidden_size // t_packing,
960
+ ),
961
+ t_dtype,
962
+ ),
963
+ # a2a_s_acc_x2_vmem
964
+ pltpu.VMEM(
965
+ (
966
+ 2,
967
+ bt * num_devices,
968
+ t_packing,
969
+ hidden_size // t_packing,
970
+ ),
971
+ t_dtype,
972
+ ),
973
+ # a2a_g_acc_vmem
974
+ pltpu.VMEM(
975
+ (top_k, bt, t_packing, hidden_size // t_packing),
976
+ t_dtype),
977
+ # b_gating_x2_vmem
978
+ pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
979
+ # b_output_x2_vmem
980
+ pltpu.VMEM((2, bt, hidden_size), t_dtype),
981
+ # b_w1_x2_vmem
982
+ pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
983
+ # b_w3_x2_vmem
984
+ pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
985
+ # b_w2_x2_vmem
986
+ pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
987
+ # b_acc_vmem
988
+ pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
989
+ # local_sems
990
+ pltpu.SemaphoreType.DMA((2, 5)),
991
+ # send_sems
992
+ pltpu.SemaphoreType.DMA((2, )),
993
+ # recv_sems
994
+ pltpu.SemaphoreType.DMA((2, )),
995
+ # a2a_gather_sem
996
+ pltpu.SemaphoreType.DMA,
997
+ # a2a_acc_sem
998
+ pltpu.SemaphoreType.DMA,
999
+ ]),
1000
+ ),
1001
+ compiler_params=pltpu.CompilerParams(
1002
+ collective_id=0,
1003
+ vmem_limit_bytes=100 * 1024 * 1024,
1004
+ ),
1005
+ name=scope_name,
1006
+ ))
1007
+
1008
+ @jax.jit
1009
+ @functools.partial(
1010
+ shard_map.shard_map,
1011
+ mesh=mesh,
1012
+ in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
1013
+ P(ep_axis_name), P()),
1014
+ out_specs=P(ep_axis_name),
1015
+ check_rep=False,
1016
+ )
1017
+ def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
1018
+ return fused_moe(
1019
+ pltpu.with_memory_space_constraint(tokens, pltpu.HBM),
1020
+ pltpu.with_memory_space_constraint(w1, pltpu.HBM),
1021
+ pltpu.with_memory_space_constraint(w2, pltpu.HBM),
1022
+ pltpu.with_memory_space_constraint(gating_output, pltpu.HBM),
1023
+ pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM),
1024
+ )
1025
+
1026
+ a2a_g_hbm_scratch = pl.empty(
1027
+ (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
1028
+ results = kernel(
1029
+ tokens,
1030
+ w1,
1031
+ w2,
1032
+ gating_output,
1033
+ a2a_g_hbm_scratch,
1034
+ )
1035
+ return results[:, :actual_hidden_size]
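
The public entry point added by this file is fused_ep_moe, with ref_moe as a pure-JAX reference implementation. Below is a minimal invocation sketch; it is not part of the package. It assumes bfloat16 inputs, a 2D ("data", "model") mesh with a data axis of size 1 (which the function asserts), shapes matching the signature comments, and illustrative tuning parameters (bt, bf, bd1, bd2, btc, bfc, bd1c, bd2c) chosen only to satisfy the kernel's alignment asserts, not tuned values.

# Hypothetical usage sketch (not from the package). Assumes a TPU slice whose
# device count divides num_tokens and num_experts, and bfloat16 weights.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe

devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("data", "model"))  # the data axis must have size 1

num_tokens, hidden, inter, num_experts, top_k = 128, 1024, 4096, 16, 2
tokens = jnp.zeros((num_tokens, hidden), jnp.bfloat16)
w1 = jnp.zeros((num_experts, 2, hidden, inter), jnp.bfloat16)
w2 = jnp.zeros((num_experts, inter, hidden), jnp.bfloat16)
gating = jnp.zeros((num_tokens, num_experts), jnp.bfloat16)

out = fused_ep_moe(
    mesh, tokens, w1, w2, gating, top_k,
    # Illustrative block sizes; they satisfy the kernel's %128 / packing
    # asserts for bfloat16 but are not tuned values.
    bt=32, bf=1024, bd1=512, bd2=512,
    btc=32, bfc=256, bd1c=256, bd2c=256,
)
# Pure-JAX reference for numerical comparison (loops over tokens, so slow).
expected = ref_moe(tokens, w1, w2, gating, top_k)

Inside fused_ep_moe, the shard_map wrapper partitions tokens, weights, and gating output along the expert-parallel axis ("model" by default), so each device holds num_tokens // ep_size tokens and num_experts // ep_size experts and exchanges activations via the kernel's all-to-all scatter/gather.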