tpu-inference 0.11.1__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
--- /dev/null
+++ b/tpu_inference/kernels/collectives/all_gather_matmul.py
@@ -0,0 +1,735 @@
+# SPDX-License-Identifier: Apache-2.0
+"""All-gather matmul kernel."""
+
+import functools
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax._src import dtypes
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+from tpu_inference.kernels.collectives import (
+    all_gather_matmul_tuned_block_sizes, util)
+
+P = jax.sharding.PartitionSpec
+
+
+def _cdiv(x, y):
+    return (x + y - 1) // y
+
+
+# TODO(chengjiyao): try unrolling the loop instead of using pallas_call grid
+# TODO(chengjiyao): try m tiling
+# TODO(chengjiyao): try using [bm, bk] and [bk, bn] scratches memory shape for
+# large bm
+# TODO(chengjiyao): try splitting to two parts when n_per_device is large:
+# output_0, gathered_x = ag-matmul(x, y_0)
+# output_1 = matmul(gathered_x, y_1)
+# output = concat(output_0, output_1)
+# TODO(chengjiyao): investigate the register spilling
+def _all_gather_kernel(
+    # Inputs
+    x_hbm_ref,  # [m_per_device, k]
+    y_hbm_ref,  # [k, n_per_device]
+    # Outputs
+    o_hbm_ref,  # [m, n_per_device]
+    x_hbm_scratch_ref,  # [num_devices - 1, m_per_device, k]
+    # Scratches
+    x_local_copy_sem,  # []
+    y_local_copy_sem,  # []
+    o_local_copy_sem,  # []
+    send_sems,  # [2, num_devices - 1] for left and right
+    recv_sems,  # [2, num_devices - 1] for left and right
+    x_vmem_scratch_ref,  # [2, m_per_device, k]
+    y_vmem_scratch_ref,  # [k, n_per_device]
+    o_vmem_scratch_ref,  # [2, m_per_device, bn]
+    acc_vmem_scratch_ref,  # [m_per_device, bn] of jnp.float32
+    axis_name: str,
+    bn: int,
+    bk: int,
+    debug_mode=False,
+    rhs_transpose: bool = False,
+):
+    """Pallas kernel for the fused all-gather matmul.
+
+    Args:
+      x_hbm_ref: LHS of the matmul before all-gather.
+      y_hbm_ref: RHS of the matmul.
+      o_hbm_ref: Output of the matmul.
+      x_hbm_scratch_ref: Scratch memory for LHS of the matmul.
+      x_local_copy_sem: DMA semaphore for a local HBM-VMEM copy.
+      y_local_copy_sem: DMA semaphore for a local HBM-VMEM copy.
+      o_local_copy_sem: DMA semaphore for a local HBM-VMEM copy.
+      send_sems: DMA semaphores for the remote sends (left and right).
+      recv_sems: DMA semaphores for the remote receives (left and right).
+      x_vmem_scratch_ref: Scratch memory for LHS of the matmul.
+      y_vmem_scratch_ref: Scratch memory for RHS of the matmul.
+      o_vmem_scratch_ref: Scratch memory for output of the matmul.
+      acc_vmem_scratch_ref: float32 accumulator scratch used when k is blocked.
+    """
+    num_devices = pl.num_programs(0) - 2
+    grid_n = pl.num_programs(1)
+    grid_k = pl.num_programs(2)
+    outer_step = pl.program_id(0)
+    bn_i = pl.program_id(1)
+    bk_i = pl.program_id(2)
+    global_step_id = outer_step * grid_n * grid_k + bn_i * grid_k + bk_i
+    mxu_total_steps = num_devices * grid_n * grid_k
+    gn_by_gk = grid_n * grid_k
+    my_id = lax.axis_index(axis_name)
+    left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices))
+    right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
+    x_hbm_receiving_slot = outer_step
+    x_hbm_working_slot = outer_step - 1
+    x_vmem_receiving_slot = outer_step % 2
+    x_vmem_working_slot = (global_step_id - 1) // gn_by_gk % 2
+    o_receiving_slot = lax.rem((global_step_id + grid_k - 1) // grid_k, 2)
+    o_working_slot = 1 - o_receiving_slot
+    m_per_device, _ = x_hbm_ref.shape
+    m_per_device_per_direction = m_per_device // 2
+
+    def debug_print(msg, *args):
+        if debug_mode:
+
+            @pl.when(my_id == 0)
+            def _debug_print():
+                pl.debug_print(msg, *args)
+
+    def _start_or_wait_copy(
+        op: jax._src.pallas.mosaic.primitives.AsyncCopyDescriptor,
+        wait: bool = False,
+    ):
+        if wait:
+            op.wait()
+        else:
+            op.start()
+
+    def _do_first_x_local_copy(wait: bool = False):
+        debug_print(
+            "[AGMM debug, wait={}] do first x local copy, x_vmem_receiving_slot={},"
+            " bk_i={}",
+            int(wait),
+            x_vmem_receiving_slot,
+            bk_i,
+        )
+        k_slice = pl.ds(bk_i * bk, bk)
+        x_local_copy_op = pltpu.make_async_copy(
+            src_ref=x_hbm_ref.at[:, k_slice],
+            dst_ref=x_vmem_scratch_ref.at[x_vmem_receiving_slot, :, k_slice],
+            sem=x_local_copy_sem,
+        )
+        _start_or_wait_copy(x_local_copy_op, wait)
+
+    def _do_subsequent_x_left_local_copy(wait: bool = False):
+        debug_print(
+            "[AGMM debug, wait={}] do subsequent x left local copy,"
+            " x_hbm_working_slot={}, x_vmem_receiving_slot={}, bk_i={}",
+            int(wait),
+            x_hbm_working_slot,
+            x_vmem_receiving_slot,
+            bk_i,
+        )
+        k_slice = pl.ds(bk_i * bk, bk)
+        x_local_copy_op = pltpu.make_async_copy(
+            src_ref=x_hbm_scratch_ref.at[
+                x_hbm_working_slot,
+                :m_per_device_per_direction,
+                k_slice,
+            ],
+            dst_ref=x_vmem_scratch_ref.at[
+                x_vmem_receiving_slot,
+                :m_per_device_per_direction,
+                k_slice,
+            ],
+            sem=x_local_copy_sem,
+        )
+        _start_or_wait_copy(x_local_copy_op, wait)
+
+    def _do_subsequent_x_right_local_copy(wait: bool = False):
+        debug_print(
+            "[AGMM debug, wait={}] do subsequent x right local copy,"
+            " x_hbm_working_slot={}, x_vmem_receiving_slot={}, bk_i={}",
+            int(wait),
+            x_hbm_working_slot,
+            x_vmem_receiving_slot,
+            bk_i,
+        )
+        x_local_copy_op = pltpu.make_async_copy(
+            src_ref=x_hbm_scratch_ref.at[
+                x_hbm_working_slot,
+                m_per_device_per_direction:,
+                pl.ds(bk_i * bk, bk),
+            ],
+            dst_ref=x_vmem_scratch_ref.at[
+                x_vmem_receiving_slot,
+                m_per_device_per_direction:,
+                pl.ds(bk_i * bk, bk),
+            ],
+            sem=x_local_copy_sem,
+        )
+        _start_or_wait_copy(x_local_copy_op, wait)
+
+    def _do_y_local_copy(wait: bool = False):
+        debug_print(
+            "[AGMM debug, wait={}] do y local copy, bk_i={}, bn_i={}",
+            int(wait),
+            bk_i,
+            bn_i,
+        )
+        k_slice = pl.ds(bk_i * bk, bk)
+        n_slice = pl.ds(bn_i * bn, bn)
+        if rhs_transpose:
+            y_local_copy_op = pltpu.make_async_copy(
+                src_ref=y_hbm_ref.at[n_slice, k_slice],
+                dst_ref=y_vmem_scratch_ref.at[n_slice, k_slice],
+                sem=y_local_copy_sem,
+            )
+        else:
+            y_local_copy_op = pltpu.make_async_copy(
+                src_ref=y_hbm_ref.at[k_slice, n_slice],
+                dst_ref=y_vmem_scratch_ref.at[k_slice, n_slice],
+                sem=y_local_copy_sem,
+            )
+        _start_or_wait_copy(y_local_copy_op, wait)
+
+    def _do_first_left_remote_copy(wait: bool = False):
+        debug_print(
+            "[AGMM debug, wait={}] do first left remote copy,"
+            " x_hbm_receiving_slot={}, x_hbm_working_slot={}",
+            int(wait),
+            x_hbm_receiving_slot,
+            x_hbm_working_slot,
+        )
+        left_remote_copy_op = pltpu.make_async_remote_copy(
+            src_ref=x_hbm_ref.at[0:m_per_device_per_direction],
+            dst_ref=x_hbm_scratch_ref.at[x_hbm_receiving_slot,
+                                         0:m_per_device_per_direction],
+            send_sem=send_sems.at[0, outer_step],
+            recv_sem=recv_sems.at[0, outer_step],
+            device_id=(left_neighbor, ),
+            device_id_type=pltpu.DeviceIdType.MESH,
+        )
+        _start_or_wait_copy(left_remote_copy_op, wait)
+
+    def _do_first_right_remote_copy(wait: bool = False):
+        debug_print(
+            "[AGMM debug, wait={}] do first right remote copy,"
+            " x_hbm_receiving_slot={}, x_hbm_working_slot={}",
+            int(wait),
+            x_hbm_receiving_slot,
+            x_hbm_working_slot,
+        )
+        right_remote_copy_op = pltpu.make_async_remote_copy(
+            src_ref=x_hbm_ref.at[m_per_device_per_direction:m_per_device],
+            dst_ref=x_hbm_scratch_ref.at[
+                x_hbm_receiving_slot, m_per_device_per_direction:m_per_device],
+            send_sem=send_sems.at[1, outer_step],
+            recv_sem=recv_sems.at[1, outer_step],
+            device_id=(right_neighbor, ),
+            device_id_type=pltpu.DeviceIdType.MESH,
+        )
+        _start_or_wait_copy(right_remote_copy_op, wait)
+
+    def _do_subsequent_left_remote_copy(wait: bool = False):
+        debug_print(
+            "[AGMM debug, wait={}] do subsequent left remote copy,"
+            " x_hbm_receiving_slot={}, x_hbm_working_slot={}",
+            int(wait),
+            x_hbm_receiving_slot,
+            x_hbm_working_slot,
+        )
+        left_remote_copy_op = pltpu.make_async_remote_copy(
+            src_ref=x_hbm_scratch_ref.at[x_hbm_working_slot,
+                                         0:m_per_device_per_direction],
+            dst_ref=x_hbm_scratch_ref.at[x_hbm_receiving_slot,
+                                         0:m_per_device_per_direction],
+            send_sem=send_sems.at[0, outer_step],
+            recv_sem=recv_sems.at[0, outer_step],
+            device_id=(left_neighbor, ),
+            device_id_type=pltpu.DeviceIdType.MESH,
+        )
+        _start_or_wait_copy(left_remote_copy_op, wait)
+
+    def _do_subsequent_right_remote_copy(wait: bool = False):
+        debug_print(
+            "[AGMM debug, wait={}] do subsequent right remote copy,"
+            " x_hbm_receiving_slot={}, x_hbm_working_slot={}",
+            int(wait),
+            x_hbm_receiving_slot,
+            x_hbm_working_slot,
+        )
+        right_remote_copy_op = pltpu.make_async_remote_copy(
+            src_ref=x_hbm_scratch_ref.at[
+                x_hbm_working_slot, m_per_device_per_direction:m_per_device],
+            dst_ref=x_hbm_scratch_ref.at[
+                x_hbm_receiving_slot, m_per_device_per_direction:m_per_device],
+            send_sem=send_sems.at[1, outer_step],
+            recv_sem=recv_sems.at[1, outer_step],
+            device_id=(right_neighbor, ),
+            device_id_type=pltpu.DeviceIdType.MESH,
+        )
+        _start_or_wait_copy(right_remote_copy_op, wait)
+
+    def _do_mxu():
+        working_global_step_id = global_step_id - 1
+        working_bk_i = working_global_step_id % grid_k
+        working_bn_i = working_global_step_id % gn_by_gk // grid_k
+        debug_print(
+            "[AGMM debug] do mxu, x_vmem_working_slot={}, o_receiving_slot={},"
+            " working_bk_i={}, working_bn_i={}",
+            x_vmem_working_slot,
+            o_receiving_slot,
+            working_bk_i,
+            working_bn_i,
+        )
+        k_slice = pl.ds(working_bk_i * bk, bk)
+        n_slice = pl.ds(working_bn_i * bn, bn)
+
+        if grid_k == 1:
+            if rhs_transpose:
+                lhs = x_vmem_scratch_ref.at[x_vmem_working_slot][...]
+                rhs = y_vmem_scratch_ref.at[n_slice, :][...]
+                o_vmem_scratch_ref.at[o_receiving_slot][...] = lax.dot_general(
+                    lhs,
+                    rhs,
+                    dimension_numbers=(((1, ), (1, )), ((), ())),
+                    preferred_element_type=jnp.float32,
+                ).astype(x_vmem_scratch_ref.dtype)
+            else:
+                o_vmem_scratch_ref.at[o_receiving_slot][...] = jnp.dot(
+                    x_vmem_scratch_ref.at[x_vmem_working_slot][...],
+                    y_vmem_scratch_ref.at[:, n_slice][...],
+                    preferred_element_type=jnp.float32,
+                ).astype(x_vmem_scratch_ref.dtype)
+        else:
+            # TODO(chengjiyao): optimize the vstore
+            if rhs_transpose:
+                lhs = x_vmem_scratch_ref.at[x_vmem_working_slot, :,
+                                            k_slice][...]
+                rhs = y_vmem_scratch_ref.at[n_slice, k_slice][...]
+                acc_vmem_scratch_ref[...] += lax.dot_general(
+                    lhs,
+                    rhs,
+                    dimension_numbers=(((1, ), (1, )), ((), ())),
+                    preferred_element_type=jnp.float32,
+                )
+            else:
+                acc_vmem_scratch_ref[...] += jnp.dot(
+                    x_vmem_scratch_ref.at[x_vmem_working_slot, :,
+                                          k_slice][...],
+                    y_vmem_scratch_ref.at[k_slice, n_slice][...],
+                    preferred_element_type=jnp.float32,
+                )
+
+            @pl.when(working_bk_i == grid_k - 1)
+            def _update():
+                debug_print(
+                    "[AGMM debug] update, o_receiving_slot={}",
+                    o_receiving_slot,
+                )
+                o_vmem_scratch_ref.at[o_receiving_slot][
+                    ...] = acc_vmem_scratch_ref[...].astype(
+                        x_vmem_scratch_ref.dtype)
+                # TODO(chengjiyao): based on kyuyeunk's suggestion, this logic
+                # can be optimized further. Right now it does the following:
+                #   1. the dot result is added into acc_vmem_scratch_ref
+                #      (load, add, store).
+                #   2. on the last k block, acc_vmem_scratch_ref is loaded
+                #      again and stored into o_vmem_scratch_ref.
+                #   3. acc_vmem_scratch_ref is then zero-initialized and
+                #      stored once more.
+                # A better way would be:
+                #   1. perform the dot.
+                #   2. if working_bk_i != 0, load acc_vmem_scratch_ref and add
+                #      the result from the previous step; otherwise skip this.
+                #   3. if working_bk_i == grid_k - 1, store the result of step
+                #      2 into o_vmem_scratch_ref; otherwise store it into
+                #      acc_vmem_scratch_ref.
+                acc_vmem_scratch_ref[...] = jnp.zeros_like(
+                    acc_vmem_scratch_ref)
+
+    def _do_o_local_copy(wait: bool = False):
+        working_global_step_id = global_step_id - grid_k - 1
+        working_bn_i = (working_global_step_id % gn_by_gk) // grid_k
+        n_slice = pl.ds(working_bn_i * bn, bn)
+        offset = (global_step_id - 2) // gn_by_gk
+        left_o_idx = (my_id + offset) % num_devices
+        left_o_idx = left_o_idx * 2
+        right_o_idx = (my_id - offset + num_devices) % num_devices
+        right_o_idx = right_o_idx * 2 + 1
+        debug_print(
+            "[AGMM debug, wait={}] do o local copy, o_working_slot={},"
+            " left_o_idx={}, right_o_idx={}, working_bn_i={}",
+            int(wait),
+            o_working_slot,
+            left_o_idx,
+            right_o_idx,
+            working_bn_i,
+        )
+        o_left_local_copy_op = pltpu.make_async_copy(
+            src_ref=o_vmem_scratch_ref.at[
+                o_working_slot, :m_per_device_per_direction],
+            dst_ref=o_hbm_ref.at[
+                pl.ds(
+                    m_per_device_per_direction * left_o_idx,
+                    m_per_device_per_direction,
+                ),
+                n_slice,
+            ],
+            sem=o_local_copy_sem,
+        )
+        o_right_local_copy_op = pltpu.make_async_copy(
+            src_ref=o_vmem_scratch_ref.at[o_working_slot,
+                                          m_per_device_per_direction:],
+            dst_ref=o_hbm_ref.at[
+                pl.ds(
+                    m_per_device_per_direction * right_o_idx,
+                    m_per_device_per_direction,
+                ),
+                n_slice,
+            ],
+            sem=o_local_copy_sem,
+        )
+        _start_or_wait_copy(o_left_local_copy_op, wait)
+        _start_or_wait_copy(o_right_local_copy_op, wait)
+
+    ### ------- Kernel start ------- ###
+    # TODO(chengjiyao): explore a fine-grained way to do the waits and signal
+
+    debug_print(
+        "===== starting a grid, outer_step={}, bn_i={}, bk_i={} =====",
+        outer_step,
+        bn_i,
+        bk_i,
+    )
+
+    @pl.when(global_step_id == 0)
+    @jax.named_scope("_start_first_remote_copy")
+    def _start_first_remote_copy():
+        if grid_k > 1:
+            acc_vmem_scratch_ref[...] = jnp.zeros_like(acc_vmem_scratch_ref)
+        # Barrier with both neighbors at the start, since we will be
+        # communicating with both.
+        util.local_barrier(left_neighbor, right_neighbor)
+        _do_first_left_remote_copy(wait=False)
+        _do_first_right_remote_copy(wait=False)
+
+    cond_start_subsequent_remote_copy = jnp.logical_and(
+        jnp.logical_and(outer_step > 0, outer_step < num_devices - 1),
+        global_step_id % gn_by_gk == 0,
+    )
+
+    @pl.when(cond_start_subsequent_remote_copy)
+    @jax.named_scope("_start_subsequent_remote_copy")
+    def _start_subsequent_remote_copy():
+        _do_subsequent_left_remote_copy(wait=False)
+        _do_subsequent_right_remote_copy(wait=False)
+
+    @pl.when(jnp.logical_and(outer_step == 0, bn_i == 0))
+    @jax.named_scope("_start_first_local_x_copy")
+    def _start_first_x_local_copy():
+        _do_first_x_local_copy(wait=False)
+
+    cond_subsequent_x_local_copy = jnp.logical_and(
+        jnp.logical_and(outer_step > 0, outer_step < num_devices), bn_i == 0)
+
+    @pl.when(cond_subsequent_x_local_copy)
+    @jax.named_scope("_start_subsequent_x_local_copy")
+    def _start_subsequent_x_local_copy():
+        _do_subsequent_x_left_local_copy(wait=False)
+        _do_subsequent_x_right_local_copy(wait=False)
+
+    @pl.when(outer_step == 0)
+    @jax.named_scope("_start_y_local_copy")
+    def _start_y_local_copy():
+        _do_y_local_copy(wait=False)
+
+    def _get_start_o_local_copy_cond():
+        if grid_k == 1:
+            return jnp.logical_and(global_step_id >= 2, global_step_id
+                                   < mxu_total_steps + 2)
+        else:
+            return jnp.logical_and(
+                jnp.logical_and(
+                    global_step_id >= grid_k + 1,
+                    global_step_id < mxu_total_steps + grid_k + 1,
+                ),
+                global_step_id % grid_k == 1,
+            )
+
+    @pl.when(_get_start_o_local_copy_cond())
+    @jax.named_scope("_start_o_local_copy")
+    def _start_o_local_copy():
+        _do_o_local_copy(wait=False)
+
+    @pl.when(
+        jnp.logical_and(global_step_id >= 1, global_step_id
+                        < 1 + mxu_total_steps))
+    @jax.named_scope("_mxu")
+    def _mxu():
+        _do_mxu()
+
+    def _get_wait_o_local_copy_cond():
+        if grid_k == 1:
+            return jnp.logical_and(global_step_id >= 2, global_step_id
+                                   < mxu_total_steps + 2)
+        else:
+            return jnp.logical_and(
+                jnp.logical_and(
+                    global_step_id >= grid_k + 1,
+                    global_step_id < mxu_total_steps + grid_k + 1,
+                ),
+                global_step_id % grid_k == 0,
+            )
+
+    @pl.when(_get_wait_o_local_copy_cond())
+    @jax.named_scope("_wait_o_local_copy")
+    def _wait_o_local_copy():
+        _do_o_local_copy(wait=True)
+
+    @pl.when(outer_step == 0)
+    @jax.named_scope("_wait_y_local_copy")
+    def _wait_y_local_copy():
+        _do_y_local_copy(wait=True)
+
+    @pl.when(jnp.logical_and(outer_step == 0, bn_i == 0))
+    @jax.named_scope("_wait_first_x_local_copy")
+    def _wait_first_x_local_copy():
+        _do_first_x_local_copy(wait=True)
+
+    @pl.when(cond_subsequent_x_local_copy)
+    @jax.named_scope("_wait_subsequent_x_local_copy")
+    def _wait_subsequent_x_local_copy():
+        _do_subsequent_x_left_local_copy(wait=True)
+        _do_subsequent_x_right_local_copy(wait=True)
+
+    @pl.when(global_step_id == gn_by_gk - 1)
+    @jax.named_scope("_wait_first_remote_copy")
+    def _wait_first_remote_copy():
+        _do_first_left_remote_copy(wait=True)
+        _do_first_right_remote_copy(wait=True)
+
+    cond_wait_subsequent_remote_copy = jnp.logical_and(
+        jnp.logical_and(outer_step > 0, outer_step < num_devices - 1),
+        global_step_id % gn_by_gk == gn_by_gk - 1,
+    )
+
+    @pl.when(cond_wait_subsequent_remote_copy)
+    @jax.named_scope("_wait_subsequent_remote_copy")
+    def _wait_subsequent_remote_copy():
+        _do_subsequent_left_remote_copy(wait=True)
+        _do_subsequent_right_remote_copy(wait=True)
+
+    ### ------- Kernel end ------- ###
+
+
+# FIXME(chengjiyao): make it accurate for the cases of quantization
+def get_vmem_estimate_bytes(
+    m,
+    n,
+    k,
+    bn,
+    acc_bytes,
+    tp_size,
+    x_dtype,
+    y_dtype,
+    out_dtype,
+):
+    """Returns the total vmem bytes used by the kernel."""
+    m_per_device = m // tp_size
+    n_per_device = n // tp_size
+    y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype) // 8
+    total_bytes = (
+        2 * m_per_device * k * dtypes.bit_width(x_dtype) //
+        8  # x_vmem_scratch_ref
+        + y_vmem_bytes  # y_vmem_scratch_ref
+        + 2 * m * bn * dtypes.bit_width(out_dtype) // 8  # o_vmem_scratch_ref
+        + acc_bytes  # acc_vmem_scratch_ref, jnp.float32
+    )
+    return total_bytes
+
+
+def validate_inputs(x, y, tp_size, rhs_transpose=False):
+    """Validates the inputs to the all_gather_matmul kernel."""
+    if x.ndim != 2 or y.ndim != 2:
+        raise ValueError(
+            f"Inputs must be 2D, got shapes {x.shape} and {y.shape}.")
+    if x.dtype != y.dtype:
+        raise ValueError(
+            f"Input dtypes must match, got {x.dtype} and {y.dtype}.")
+    m, k = x.shape
+    if rhs_transpose:
+        n, k_from_y = y.shape
+    else:
+        k_from_y, n = y.shape
+    if k != k_from_y:
+        raise ValueError(
+            "Incompatible shapes for matmul: contracting dimension mismatch:"
+            f" {x.shape} and {y.shape}.")
+
+    if k % 128 != 0:
+        raise ValueError(f"k ({k}) must be divisible by 128.")
+
+    if n % 128 != 0:
+        raise ValueError(f"n ({n}) must be divisible by 128.")
+
+    m_per_device_per_direction = m // tp_size // 2
+    if m_per_device_per_direction % 8 != 0:
+        raise ValueError(f"m ({m}) must be divisible by {tp_size * 2 * 8}.")
+
+    if m % (tp_size * 2) != 0:
+        raise ValueError(
+            f"x.shape[0] ({m}) must be divisible by tp_size * 2 ({tp_size * 2})."
+        )
+    if n % tp_size != 0:
+        raise ValueError(
+            f"y.shape[{0 if rhs_transpose else 1}] ({n}) must be divisible by"
+            f" tp_size ({tp_size}).")
+
+
+def all_gather_matmul(
+    x: jax.Array,
+    y: jax.Array,
+    mesh: jax.sharding.AbstractMesh,
+    axis_name: str,
+    collective_id: int | None = 0,
+    bn: int | None = None,
+    bk: int | None = None,
+    rhs_transpose: bool = False,
+):
+    """Performs all-gather on the input tensor and then a matmul.
+
+    Args:
+      x: LHS of the matmul before all-gather.
+      y: RHS of the matmul.
+      mesh: JAX mesh.
+      axis_name: Name of the axis to all-gather over.
+      collective_id: An integer used for barrier semaphore allocation.
+      bn: Block size in the n dimension. If None, a tuned value is used.
+      bk: Block size in the k dimension. If None, a tuned value is used.
+      rhs_transpose: If True, y is passed transposed, i.e. with shape [n, k].
+
+    Returns:
+      all-gather(x, axis=0) @ y
+    """
+    tp_size = mesh.shape[axis_name]
+    validate_inputs(x, y, tp_size, rhs_transpose)
+    m, k = x.shape
+    if rhs_transpose:
+        n, _ = y.shape
+        y_in_spec = P(axis_name, None)
+    else:
+        _, n = y.shape
+        y_in_spec = P(None, axis_name)
+    m_per_device = m // tp_size
+    n_per_device = n // tp_size
+    tuned_bn, tuned_bk = (
+        all_gather_matmul_tuned_block_sizes.get_tuned_block_sizes(
+            m, n, k,
+            jnp.dtype(x.dtype).name, tp_size))
+    if bn is None:
+        bn = tuned_bn if tuned_bn is not None else n
+    if bk is None:
+        bk = tuned_bk if tuned_bk is not None else k
+    grid_n = _cdiv(n_per_device, bn)
+    grid_k = _cdiv(k, bk)
+    acc_shape = (m_per_device, bn)
+    # NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
+    if grid_k == 1:
+        acc_shape = (8, 128)
+    acc_bytes = acc_shape[0] * acc_shape[1] * dtypes.bit_width(
+        jnp.float32) // 8
+    y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
+    estimated_vmem_bytes = get_vmem_estimate_bytes(
+        m,
+        n,
+        k,
+        bn,
+        acc_bytes,
+        tp_size,
+        x.dtype,
+        y.dtype,
+        x.dtype,
+    )
+    out_shape = [
+        jax.ShapeDtypeStruct((m, n_per_device), x.dtype),  # output
+        jax.ShapeDtypeStruct((tp_size - 1, m_per_device, k),
+                             x.dtype),  # x HBM scratch
+    ]
+    grid_spec = pltpu.PrefetchScalarGridSpec(
+        num_scalar_prefetch=0,
+        in_specs=[
+            pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+            pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+        ],
+        out_specs=[
+            pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+            pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+        ],
+        scratch_shapes=(
+            pltpu.SemaphoreType.DMA,  # x_local_copy_sem
+            pltpu.SemaphoreType.DMA,  # y_local_copy_sem
+            pltpu.SemaphoreType.DMA,  # o_local_copy_sem
+            pltpu.SemaphoreType.DMA(
+                (2, tp_size - 1)),  # left and right send semaphores
+            pltpu.SemaphoreType.DMA((
+                2,
+                tp_size - 1,
+            )),  # left and right recv semaphores
+            pltpu.VMEM((2, m_per_device, k), x.dtype),  # x vmem scratch
+            pltpu.VMEM(y_vmem_shape, y.dtype),  # y vmem scratch
+            pltpu.VMEM((2, m_per_device, bn), x.dtype),  # output vmem scratch
+            pltpu.VMEM(acc_shape, jnp.float32),  # acc vmem scratch
+        ),
+        grid=(tp_size + 2, grid_n, grid_k),
+    )
+    flops = 2 * m * k * n_per_device
+    bytes_accessed = x.dtype.itemsize * (m * k + k * n_per_device +
+                                         m * n_per_device)
+    cost_estimate = pl.CostEstimate(flops=flops,
+                                    bytes_accessed=bytes_accessed,
+                                    transcendentals=0)
+
+    @functools.partial(jax.jit, static_argnames=["bn", "bk", "rhs_transpose"])
+    def _all_gather_matmul_call(x, y, bn, bk, rhs_transpose):
+        return pl.pallas_call(
+            functools.partial(
+                _all_gather_kernel,
+                bn=bn,
+                bk=bk,
+                axis_name=axis_name,
+                rhs_transpose=rhs_transpose,
+            ),
+            out_shape=out_shape,
+            grid_spec=grid_spec,
+            compiler_params=pltpu.CompilerParams(
+                collective_id=collective_id,
+                vmem_limit_bytes=estimated_vmem_bytes + 8 * 1024 * 1024,
+            ),
+            cost_estimate=cost_estimate,
+            name=get_kernel_name(bn, bk, rhs_transpose),
+        )(x, y)[0]
+
+    shard_map_kernel = jax.jit(
+        jax.shard_map(
+            functools.partial(
+                _all_gather_matmul_call,
+                bn=bn,
+                bk=bk,
+                rhs_transpose=rhs_transpose,
+            ),
+            mesh=mesh,
+            in_specs=(P(axis_name, None), y_in_spec),
+            out_specs=P(None, axis_name),
+            check_vma=False,
+        ), )
+
+    return shard_map_kernel(x, y)
+
+
+def get_kernel_name(bn: int, bk: int, rhs_transpose: bool):
+    return (
+        f"all_gather_matmul_kernel_bn_{bn}_bk_{bk}_rhs_transpose_{rhs_transpose}"
+    )