tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/kernels/collectives/all_gather_matmul.py
@@ -0,0 +1,735 @@
# SPDX-License-Identifier: Apache-2.0
"""All-gather matmul kernel."""

import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax._src import dtypes
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from tpu_inference.kernels.collectives import (
    all_gather_matmul_tuned_block_sizes, util)

P = jax.sharding.PartitionSpec


def _cdiv(x, y):
    return (x + y - 1) // y


# TODO(chengjiyao): try unrolling the loop instead of using pallas_call grid
# TODO(chengjiyao): try m tiling
# TODO(chengjiyao): try using [bm, bk] and [bk, bn] scratches memory shape for
# large bm
# TODO(chengjiyao): try splitting to two parts when n_per_device is large:
# output_0, gathered_x = ag-matmul(x, y_0)
# output_1 = matmul(gathered_x, y_1)
# output = concat(output_0, output_1)
# TODO(chengjiyao): investigate the register spilling
def _all_gather_kernel(
    # Inputs
    x_hbm_ref,  # [m_per_device, k]
    y_hbm_ref,  # [k, n_per_device]
    # Outputs
    o_hbm_ref,  # [m, n_per_device]
    x_hbm_scratch_ref,  # [num_devices - 1, m_per_device, k]
    # Scratches
    x_local_copy_sem,  # []
    y_local_copy_sem,  # []
    o_local_copy_sem,  # []
    send_sems,  # [2, num_devices - 1] for left and right
    recv_sems,  # [2, num_devices - 1] for left and right
    x_vmem_scratch_ref,  # [2, m_per_device, k]
    y_vmem_scratch_ref,  # [k, n_per_device]
    o_vmem_scratch_ref,  # [2, m_per_device, bn]
    acc_vmem_scratch_ref,  # [m_per_device, bn] of jnp.float32
    axis_name: str,
    bn: int,
    bk: int,
    debug_mode=False,
    rhs_transpose: bool = False,
):
    """Pallas kernel for the all-gather matmul.

    Args:
      x_hbm_ref: LHS of the matmul before all-gather.
      y_hbm_ref: RHS of the matmul.
      o_hbm_ref: Output of the matmul.
      x_hbm_scratch_ref: Scratch memory for LHS of the matmul.
      x_local_copy_sem: DMA semaphore for a local HBM-VMEM copy.
      y_local_copy_sem: DMA semaphore for a local HBM-VMEM copy.
      o_local_copy_sem: DMA semaphore for a local HBM-VMEM copy.
      send_sems: DMA semaphores for the remote sends (left and right).
      recv_sems: DMA semaphores for the remote receives (left and right).
      x_vmem_scratch_ref: Scratch memory for LHS of the matmul.
      y_vmem_scratch_ref: Scratch memory for RHS of the matmul.
      o_vmem_scratch_ref: Scratch memory for output of the matmul.
      acc_vmem_scratch_ref: float32 accumulator scratch, used when grid_k > 1.
      axis_name: Name of the mesh axis to all-gather over.
      bn: Block size in the n dimension.
      bk: Block size in the k dimension.
      debug_mode: If True, prints debug information from device 0.
      rhs_transpose: If True, y is stored as [n_per_device, k].
    """
    num_devices = pl.num_programs(0) - 2
    grid_n = pl.num_programs(1)
    grid_k = pl.num_programs(2)
    outer_step = pl.program_id(0)
    bn_i = pl.program_id(1)
    bk_i = pl.program_id(2)
    global_step_id = outer_step * grid_n * grid_k + bn_i * grid_k + bk_i
    mxu_total_steps = num_devices * grid_n * grid_k
    gn_by_gk = grid_n * grid_k
    my_id = lax.axis_index(axis_name)
    left_neighbor = lax.rem(my_id + num_devices - 1, jnp.int32(num_devices))
    right_neighbor = lax.rem(my_id + 1, jnp.int32(num_devices))
    x_hbm_receiving_slot = outer_step
    x_hbm_working_slot = outer_step - 1
    x_vmem_receiving_slot = outer_step % 2
    x_vmem_working_slot = (global_step_id - 1) // gn_by_gk % 2
    o_receiving_slot = lax.rem((global_step_id + grid_k - 1) // grid_k, 2)
    o_working_slot = 1 - o_receiving_slot
    m_per_device, _ = x_hbm_ref.shape
    m_per_device_per_direction = m_per_device // 2
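
    # NOTE: annotation, not part of the original file. The first grid
    # dimension runs tp_size + 2 steps because the kernel is software
    # pipelined: DMA starts lead the MXU work by one step and the output
    # copies trail it by one more, so two drain steps are appended (and
    # num_devices above subtracts them back out). The receiving/working
    # slot pairs implement the double buffering between those stages.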

    def debug_print(msg, *args):
        if debug_mode:

            @pl.when(my_id == 0)
            def _debug_print():
                pl.debug_print(msg, *args)

    def _start_or_wait_copy(
        op: jax._src.pallas.mosaic.primitives.AsyncCopyDescriptor,
        wait: bool = False,
    ):
        if wait:
            op.wait()
        else:
            op.start()
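
    # NOTE: annotation, not part of the original file. Each copy helper
    # below is invoked twice with identical arguments: once with
    # wait=False to enqueue the async DMA, and again in a later grid step
    # with wait=True to block on its semaphore, letting the copies overlap
    # with the MXU work issued in between.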

    def _do_first_x_local_copy(wait: bool = False):
        debug_print(
            "[AGMM debug, wait={}] do first x local copy, x_vmem_receiving_slot={},"
            " bk_i={}",
            int(wait),
            x_vmem_receiving_slot,
            bk_i,
        )
        k_slice = pl.ds(bk_i * bk, bk)
        x_local_copy_op = pltpu.make_async_copy(
            src_ref=x_hbm_ref.at[:, k_slice],
            dst_ref=x_vmem_scratch_ref.at[x_vmem_receiving_slot, :, k_slice],
            sem=x_local_copy_sem,
        )
        _start_or_wait_copy(x_local_copy_op, wait)

    def _do_subsequent_x_left_local_copy(wait: bool = False):
        debug_print(
            "[AGMM debug, wait={}] do subsequent x left local copy,"
            " x_hbm_working_slot={}, x_vmem_receiving_slot={}, bk_i={}",
            int(wait),
            x_hbm_working_slot,
            x_vmem_receiving_slot,
            bk_i,
        )
        k_slice = pl.ds(bk_i * bk, bk)
        x_local_copy_op = pltpu.make_async_copy(
            src_ref=x_hbm_scratch_ref.at[
                x_hbm_working_slot,
                :m_per_device_per_direction,
                k_slice,
            ],
            dst_ref=x_vmem_scratch_ref.at[
                x_vmem_receiving_slot,
                :m_per_device_per_direction,
                k_slice,
            ],
            sem=x_local_copy_sem,
        )
        _start_or_wait_copy(x_local_copy_op, wait)

    def _do_subsequent_x_right_local_copy(wait: bool = False):
        debug_print(
            "[AGMM debug, wait={}] do subsequent x right local copy,"
            " x_hbm_working_slot={}, x_vmem_receiving_slot={}, bk_i={}",
            int(wait),
            x_hbm_working_slot,
            x_vmem_receiving_slot,
            bk_i,
        )
        x_local_copy_op = pltpu.make_async_copy(
            src_ref=x_hbm_scratch_ref.at[
                x_hbm_working_slot,
                m_per_device_per_direction:,
                pl.ds(bk_i * bk, bk),
            ],
            dst_ref=x_vmem_scratch_ref.at[
                x_vmem_receiving_slot,
                m_per_device_per_direction:,
                pl.ds(bk_i * bk, bk),
            ],
            sem=x_local_copy_sem,
        )
        _start_or_wait_copy(x_local_copy_op, wait)

    def _do_y_local_copy(wait: bool = False):
        debug_print(
            "[AGMM debug, wait={}] do y local copy, bk_i={}, bn_i={}",
            int(wait),
            bk_i,
            bn_i,
        )
        k_slice = pl.ds(bk_i * bk, bk)
        n_slice = pl.ds(bn_i * bn, bn)
        if rhs_transpose:
            y_local_copy_op = pltpu.make_async_copy(
                src_ref=y_hbm_ref.at[n_slice, k_slice],
                dst_ref=y_vmem_scratch_ref.at[n_slice, k_slice],
                sem=y_local_copy_sem,
            )
        else:
            y_local_copy_op = pltpu.make_async_copy(
                src_ref=y_hbm_ref.at[k_slice, n_slice],
                dst_ref=y_vmem_scratch_ref.at[k_slice, n_slice],
                sem=y_local_copy_sem,
            )
        _start_or_wait_copy(y_local_copy_op, wait)

    def _do_first_left_remote_copy(wait: bool = False):
        debug_print(
            "[AGMM debug, wait={}] do first left remote copy,"
            " x_hbm_receiving_slot={}, x_hbm_working_slot={}",
            int(wait),
            x_hbm_receiving_slot,
            x_hbm_working_slot,
        )
        left_remote_copy_op = pltpu.make_async_remote_copy(
            src_ref=x_hbm_ref.at[0:m_per_device_per_direction],
            dst_ref=x_hbm_scratch_ref.at[x_hbm_receiving_slot,
                                         0:m_per_device_per_direction],
            send_sem=send_sems.at[0, outer_step],
            recv_sem=recv_sems.at[0, outer_step],
            device_id=(left_neighbor, ),
            device_id_type=pltpu.DeviceIdType.MESH,
        )
        _start_or_wait_copy(left_remote_copy_op, wait)

    def _do_first_right_remote_copy(wait: bool = False):
        debug_print(
            "[AGMM debug, wait={}] do first right remote copy,"
            " x_hbm_receiving_slot={}, x_hbm_working_slot={}",
            int(wait),
            x_hbm_receiving_slot,
            x_hbm_working_slot,
        )
        right_remote_copy_op = pltpu.make_async_remote_copy(
            src_ref=x_hbm_ref.at[m_per_device_per_direction:m_per_device],
            dst_ref=x_hbm_scratch_ref.at[
                x_hbm_receiving_slot, m_per_device_per_direction:m_per_device],
            send_sem=send_sems.at[1, outer_step],
            recv_sem=recv_sems.at[1, outer_step],
            device_id=(right_neighbor, ),
            device_id_type=pltpu.DeviceIdType.MESH,
        )
        _start_or_wait_copy(right_remote_copy_op, wait)
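
    # NOTE: annotation, not part of the original file. The all-gather runs
    # as a bidirectional ring: each device sends the first half of its
    # rows to its left neighbor and the second half to its right neighbor,
    # so the gather completes in num_devices - 1 hops per direction while
    # moving only half of the data each way.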

    def _do_subsequent_left_remote_copy(wait: bool = False):
        debug_print(
            "[AGMM debug, wait={}] do subsequent left remote copy,"
            " x_hbm_receiving_slot={}, x_hbm_working_slot={}",
            int(wait),
            x_hbm_receiving_slot,
            x_hbm_working_slot,
        )
        left_remote_copy_op = pltpu.make_async_remote_copy(
            src_ref=x_hbm_scratch_ref.at[x_hbm_working_slot,
                                         0:m_per_device_per_direction],
            dst_ref=x_hbm_scratch_ref.at[x_hbm_receiving_slot,
                                         0:m_per_device_per_direction],
            send_sem=send_sems.at[0, outer_step],
            recv_sem=recv_sems.at[0, outer_step],
            device_id=(left_neighbor, ),
            device_id_type=pltpu.DeviceIdType.MESH,
        )
        _start_or_wait_copy(left_remote_copy_op, wait)

    def _do_subsequent_right_remote_copy(wait: bool = False):
        debug_print(
            "[AGMM debug, wait={}] do subsequent right remote copy,"
            " x_hbm_receiving_slot={}, x_hbm_working_slot={}",
            int(wait),
            x_hbm_receiving_slot,
            x_hbm_working_slot,
        )
        right_remote_copy_op = pltpu.make_async_remote_copy(
            src_ref=x_hbm_scratch_ref.at[
                x_hbm_working_slot, m_per_device_per_direction:m_per_device],
            dst_ref=x_hbm_scratch_ref.at[
                x_hbm_receiving_slot, m_per_device_per_direction:m_per_device],
            send_sem=send_sems.at[1, outer_step],
            recv_sem=recv_sems.at[1, outer_step],
            device_id=(right_neighbor, ),
            device_id_type=pltpu.DeviceIdType.MESH,
        )
        _start_or_wait_copy(right_remote_copy_op, wait)

    def _do_mxu():
        working_global_step_id = global_step_id - 1
        working_bk_i = working_global_step_id % grid_k
        working_bn_i = working_global_step_id % gn_by_gk // grid_k
        debug_print(
            "[AGMM debug] do mxu, x_vmem_working_slot={}, o_receiving_slot={},"
            " working_bk_i={}, working_bn_i={}",
            x_vmem_working_slot,
            o_receiving_slot,
            working_bk_i,
            working_bn_i,
        )
        k_slice = pl.ds(working_bk_i * bk, bk)
        n_slice = pl.ds(working_bn_i * bn, bn)

        if grid_k == 1:
            if rhs_transpose:
                lhs = x_vmem_scratch_ref.at[x_vmem_working_slot][...]
                rhs = y_vmem_scratch_ref.at[n_slice, :][...]
                o_vmem_scratch_ref.at[o_receiving_slot][...] = lax.dot_general(
                    lhs,
                    rhs,
                    dimension_numbers=(((1, ), (1, )), ((), ())),
                    preferred_element_type=jnp.float32,
                ).astype(x_vmem_scratch_ref.dtype)
            else:
                o_vmem_scratch_ref.at[o_receiving_slot][...] = jnp.dot(
                    x_vmem_scratch_ref.at[x_vmem_working_slot][...],
                    y_vmem_scratch_ref.at[:, n_slice][...],
                    preferred_element_type=jnp.float32,
                ).astype(x_vmem_scratch_ref.dtype)
        else:
            # TODO(chengjiyao): optimize the vstore
            if rhs_transpose:
                lhs = x_vmem_scratch_ref.at[x_vmem_working_slot, :,
                                            k_slice][...]
                rhs = y_vmem_scratch_ref.at[n_slice, k_slice][...]
                acc_vmem_scratch_ref[...] += lax.dot_general(
                    lhs,
                    rhs,
                    dimension_numbers=(((1, ), (1, )), ((), ())),
                    preferred_element_type=jnp.float32,
                )
            else:
                acc_vmem_scratch_ref[...] += jnp.dot(
                    x_vmem_scratch_ref.at[x_vmem_working_slot, :,
                                          k_slice][...],
                    y_vmem_scratch_ref.at[k_slice, n_slice][...],
                    preferred_element_type=jnp.float32,
                )

        @pl.when(working_bk_i == grid_k - 1)
        def _update():
            debug_print(
                "[AGMM debug] update, o_receiving_slot={}",
                o_receiving_slot,
            )
            o_vmem_scratch_ref.at[o_receiving_slot][...] = (
                acc_vmem_scratch_ref[...].astype(x_vmem_scratch_ref.dtype))
            # TODO(chengjiyao): based on kyuyeunk's suggestion, this logic
            # can be further optimized. Right now the accumulation above
            # performs the dot, loads acc_vmem_scratch_ref, adds the dot
            # result to it, and stores the sum back; this update then
            # loads acc_vmem_scratch_ref once more and zero-initializes it
            # below. A better way would be:
            # 1. perform the dot
            # 2. if working_bk_i != 0, load acc_vmem_scratch_ref and add
            #    the result from the previous step; otherwise skip this.
            # 3. if working_bk_i == grid_k - 1, store the result from
            #    step 2 into o_vmem_scratch_ref; otherwise store it into
            #    acc_vmem_scratch_ref.
            acc_vmem_scratch_ref[...] = jnp.zeros_like(
                acc_vmem_scratch_ref)

    def _do_o_local_copy(wait: bool = False):
        working_global_step_id = global_step_id - grid_k - 1
        working_bn_i = (working_global_step_id % gn_by_gk) // grid_k
        n_slice = pl.ds(working_bn_i * bn, bn)
        offset = (global_step_id - 2) // gn_by_gk
        left_o_idx = (my_id + offset) % num_devices
        left_o_idx = left_o_idx * 2
        right_o_idx = (my_id - offset + num_devices) % num_devices
        right_o_idx = right_o_idx * 2 + 1
        debug_print(
            "[AGMM debug, wait={}] do o local copy, o_working_slot={},"
            " left_o_idx={}, right_o_idx={}, working_bn_i={}",
            int(wait),
            o_working_slot,
            left_o_idx,
            right_o_idx,
            working_bn_i,
        )
        o_left_local_copy_op = pltpu.make_async_copy(
            src_ref=o_vmem_scratch_ref.at[
                o_working_slot, :m_per_device_per_direction],
            dst_ref=o_hbm_ref.at[
                pl.ds(
                    m_per_device_per_direction * left_o_idx,
                    m_per_device_per_direction,
                ),
                n_slice,
            ],
            sem=o_local_copy_sem,
        )
        o_right_local_copy_op = pltpu.make_async_copy(
            src_ref=o_vmem_scratch_ref.at[o_working_slot,
                                          m_per_device_per_direction:],
            dst_ref=o_hbm_ref.at[
                pl.ds(
                    m_per_device_per_direction * right_o_idx,
                    m_per_device_per_direction,
                ),
                n_slice,
            ],
            sem=o_local_copy_sem,
        )
        _start_or_wait_copy(o_left_local_copy_op, wait)
        _start_or_wait_copy(o_right_local_copy_op, wait)
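
    # NOTE: annotation, not part of the original file. The gathered output
    # is laid out as 2 * num_devices row chunks of size
    # m_per_device_per_direction: even chunk indices hold first halves
    # (flowing leftward, originating from device my_id + offset) and odd
    # indices hold second halves (flowing rightward, from device
    # my_id - offset), so left_o_idx / right_o_idx land each block in the
    # row range owned by its originating device.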

    ### ------- Kernel start ------- ###
    # TODO(chengjiyao): explore a fine-grained way to do the waits and signal

    debug_print(
        "===== starting a grid, outer_step={}, bn_i={}, bk_i={} =====",
        outer_step,
        bn_i,
        bk_i,
    )

    @pl.when(global_step_id == 0)
    @jax.named_scope("_start_first_remote_copy")
    def _start_first_remote_copy():
        if grid_k > 1:
            acc_vmem_scratch_ref[...] = jnp.zeros_like(acc_vmem_scratch_ref)
        # Barrier with both neighbors at the start, since we will be
        # communicating with both.
        util.local_barrier(left_neighbor, right_neighbor)
        _do_first_left_remote_copy(wait=False)
        _do_first_right_remote_copy(wait=False)

    cond_start_subsequent_remote_copy = jnp.logical_and(
        jnp.logical_and(outer_step > 0, outer_step < num_devices - 1),
        global_step_id % gn_by_gk == 0,
    )

    @pl.when(cond_start_subsequent_remote_copy)
    @jax.named_scope("_start_subsequent_remote_copy")
    def _start_subsequent_remote_copy():
        _do_subsequent_left_remote_copy(wait=False)
        _do_subsequent_right_remote_copy(wait=False)

    @pl.when(jnp.logical_and(outer_step == 0, bn_i == 0))
    @jax.named_scope("_start_first_local_x_copy")
    def _start_first_x_local_copy():
        _do_first_x_local_copy(wait=False)

    cond_subsequent_x_local_copy = jnp.logical_and(
        jnp.logical_and(outer_step > 0, outer_step < num_devices), bn_i == 0)

    @pl.when(cond_subsequent_x_local_copy)
    @jax.named_scope("_start_subsequent_x_local_copy")
    def _start_subsequent_x_local_copy():
        _do_subsequent_x_left_local_copy(wait=False)
        _do_subsequent_x_right_local_copy(wait=False)

    @pl.when(outer_step == 0)
    @jax.named_scope("_start_y_local_copy")
    def _start_y_local_copy():
        _do_y_local_copy(wait=False)

    def _get_start_o_local_copy_cond():
        if grid_k == 1:
            return jnp.logical_and(global_step_id >= 2,
                                   global_step_id < mxu_total_steps + 2)
        else:
            return jnp.logical_and(
                jnp.logical_and(
                    global_step_id >= grid_k + 1,
                    global_step_id < mxu_total_steps + grid_k + 1,
                ),
                global_step_id % grid_k == 1,
            )

    @pl.when(_get_start_o_local_copy_cond())
    @jax.named_scope("_start_o_local_copy")
    def _start_o_local_copy():
        _do_o_local_copy(wait=False)

    @pl.when(
        jnp.logical_and(global_step_id >= 1,
                        global_step_id < 1 + mxu_total_steps))
    @jax.named_scope("_mxu")
    def _mxu():
        _do_mxu()

    def _get_wait_o_local_copy_cond():
        if grid_k == 1:
            return jnp.logical_and(global_step_id >= 2,
                                   global_step_id < mxu_total_steps + 2)
        else:
            return jnp.logical_and(
                jnp.logical_and(
                    global_step_id >= grid_k + 1,
                    global_step_id < mxu_total_steps + grid_k + 1,
                ),
                global_step_id % grid_k == 0,
            )

    @pl.when(_get_wait_o_local_copy_cond())
    @jax.named_scope("_wait_o_local_copy")
    def _wait_o_local_copy():
        _do_o_local_copy(wait=True)

    @pl.when(outer_step == 0)
    @jax.named_scope("_wait_y_local_copy")
    def _wait_y_local_copy():
        _do_y_local_copy(wait=True)

    @pl.when(jnp.logical_and(outer_step == 0, bn_i == 0))
    @jax.named_scope("_wait_first_x_local_copy")
    def _wait_first_x_local_copy():
        _do_first_x_local_copy(wait=True)

    @pl.when(cond_subsequent_x_local_copy)
    @jax.named_scope("_wait_subsequent_x_local_copy")
    def _wait_subsequent_x_local_copy():
        _do_subsequent_x_left_local_copy(wait=True)
        _do_subsequent_x_right_local_copy(wait=True)

    @pl.when(global_step_id == gn_by_gk - 1)
    @jax.named_scope("_wait_first_remote_copy")
    def _wait_first_remote_copy():
        _do_first_left_remote_copy(wait=True)
        _do_first_right_remote_copy(wait=True)

    cond_wait_subsequent_remote_copy = jnp.logical_and(
        jnp.logical_and(outer_step > 0, outer_step < num_devices - 1),
        global_step_id % gn_by_gk == gn_by_gk - 1,
    )

    @pl.when(cond_wait_subsequent_remote_copy)
    @jax.named_scope("_wait_subsequent_remote_copy")
    def _wait_subsequent_remote_copy():
        _do_subsequent_left_remote_copy(wait=True)
        _do_subsequent_right_remote_copy(wait=True)

    ### ------- Kernel end ------- ###


# FIXME(chengjiyao): make it accurate for the cases of quantization
def get_vmem_estimate_bytes(
    m,
    n,
    k,
    bn,
    acc_bytes,
    tp_size,
    x_dtype,
    y_dtype,
    out_dtype,
):
    """Returns the total vmem bytes used by the kernel."""
    m_per_device = m // tp_size
    n_per_device = n // tp_size
    y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype) // 8
    total_bytes = (
        2 * m_per_device * k * dtypes.bit_width(x_dtype) // 8  # x_vmem_scratch_ref
        + y_vmem_bytes  # y_vmem_scratch_ref
        + 2 * m * bn * dtypes.bit_width(out_dtype) // 8  # o_vmem_scratch_ref
        + acc_bytes  # acc_vmem_scratch_ref, jnp.float32
    )
    return total_bytes
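
# NOTE: annotation, not part of the original file. A worked example of the
# estimate above with illustrative sizes m = 2048, n = 2048, k = 1024,
# bn = 256, tp_size = 8 and bfloat16 throughout: m_per_device = 256 and
# n_per_device = 256, so the x scratch costs 2 * 256 * 1024 * 2 B = 1 MiB,
# the y scratch 256 * 1024 * 2 B = 0.5 MiB, and the o scratch
# 2 * 2048 * 256 * 2 B = 2 MiB, plus acc_bytes for the accumulator.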


def validate_inputs(x, y, tp_size, rhs_transpose=False):
    """Validates the inputs to the all_gather_matmul kernel."""
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError(
            f"Inputs must be 2D, got shapes {x.shape} and {y.shape}.")
    if x.dtype != y.dtype:
        raise ValueError(
            f"Input dtypes must match, got {x.dtype} and {y.dtype}.")
    m, k = x.shape
    if rhs_transpose:
        n, k_from_y = y.shape
    else:
        k_from_y, n = y.shape
    if k != k_from_y:
        raise ValueError(
            "Incompatible shapes for matmul: contracting dimension mismatch:"
            f" {x.shape} and {y.shape}.")

    if k % 128 != 0:
        raise ValueError(f"k ({k}) must be divisible by 128.")

    if n % 128 != 0:
        raise ValueError(f"n ({n}) must be divisible by 128.")

    m_per_device_per_direction = m // tp_size // 2
    if m_per_device_per_direction % 8 != 0:
        raise ValueError(f"m ({m}) must be divisible by {tp_size * 2 * 8}.")

    if m % (tp_size * 2) != 0:
        raise ValueError(
            f"x.shape[0] ({m}) must be divisible by tp_size * 2 ({tp_size * 2})."
        )
    if n % tp_size != 0:
        raise ValueError(
            f"y.shape[{0 if rhs_transpose else 1}] ({n}) must be divisible by"
            f" tp_size ({tp_size}).")


def all_gather_matmul(
    x: jax.Array,
    y: jax.Array,
    mesh: jax.sharding.AbstractMesh,
    axis_name: str,
    collective_id: int | None = 0,
    bn: int | None = None,
    bk: int | None = None,
    rhs_transpose: bool = False,
):
    """Performs all-gather on the input tensor and then a matmul.

    Args:
      x: LHS of the matmul before all-gather.
      y: RHS of the matmul.
      mesh: JAX mesh.
      axis_name: Name of the axis to all-gather over.
      collective_id: An integer used for barrier semaphore allocation.
      bn: Block size in the n dimension.
      bk: Block size in the k dimension.
      rhs_transpose: If True, y is transposed.

    Returns:
      all-gather(x, axis=0) @ y
    """
    tp_size = mesh.shape[axis_name]
    validate_inputs(x, y, tp_size, rhs_transpose)
    m, k = x.shape
    if rhs_transpose:
        n, _ = y.shape
        y_in_spec = P(axis_name, None)
    else:
        _, n = y.shape
        y_in_spec = P(None, axis_name)
    m_per_device = m // tp_size
    n_per_device = n // tp_size
    tuned_bn, tuned_bk = (
        all_gather_matmul_tuned_block_sizes.get_tuned_block_sizes(
            m, n, k,
            jnp.dtype(x.dtype).name, tp_size))
    if bn is None:
        bn = tuned_bn if tuned_bn is not None else n
    if bk is None:
        bk = tuned_bk if tuned_bk is not None else k
    grid_n = _cdiv(n_per_device, bn)
    grid_k = _cdiv(k, bk)
    acc_shape = (m_per_device, bn)
    # NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
    if grid_k == 1:
        acc_shape = (8, 128)
    acc_bytes = acc_shape[0] * acc_shape[1] * dtypes.bit_width(
        jnp.float32) // 8
    y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
    estimated_vmem_bytes = get_vmem_estimate_bytes(
        m,
        n,
        k,
        bn,
        acc_bytes,
        tp_size,
        x.dtype,
        y.dtype,
        x.dtype,
    )
    out_shape = [
        jax.ShapeDtypeStruct((m, n_per_device), x.dtype),  # output
        jax.ShapeDtypeStruct((tp_size - 1, m_per_device, k),
                             x.dtype),  # x HBM scratch
    ]
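
    # NOTE: annotation, not part of the original file. The x HBM scratch
    # is declared as a second kernel output so it is materialized in HBM
    # and persists across grid steps; only the first output (the matmul
    # result) is returned to the caller below.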

    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
            pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
        ],
        out_specs=[
            pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
            pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
        ],
        scratch_shapes=(
            pltpu.SemaphoreType.DMA,  # x_local_copy_sem
            pltpu.SemaphoreType.DMA,  # y_local_copy_sem
            pltpu.SemaphoreType.DMA,  # o_local_copy_sem
            pltpu.SemaphoreType.DMA(
                (2, tp_size - 1)),  # left and right send semaphores
            pltpu.SemaphoreType.DMA(
                (2, tp_size - 1)),  # left and right recv semaphores
            pltpu.VMEM((2, m_per_device, k), x.dtype),  # x vmem scratch
            pltpu.VMEM(y_vmem_shape, y.dtype),  # y vmem scratch
            pltpu.VMEM((2, m_per_device, bn), x.dtype),  # output vmem scratch
            pltpu.VMEM(acc_shape, jnp.float32),  # acc vmem scratch
        ),
        grid=(tp_size + 2, grid_n, grid_k),
    )
    flops = 2 * m * k * n_per_device
    bytes_accessed = x.dtype.itemsize * (m * k + k * n_per_device +
                                         m * n_per_device)
    cost_estimate = pl.CostEstimate(flops=flops,
                                    bytes_accessed=bytes_accessed,
                                    transcendentals=0)

    @functools.partial(jax.jit, static_argnames=["bn", "bk", "rhs_transpose"])
    def _all_gather_matmul_call(x, y, bn, bk, rhs_transpose):
        return pl.pallas_call(
            functools.partial(
                _all_gather_kernel,
                bn=bn,
                bk=bk,
                axis_name=axis_name,
                rhs_transpose=rhs_transpose,
            ),
            out_shape=out_shape,
            grid_spec=grid_spec,
            compiler_params=pltpu.CompilerParams(
                collective_id=collective_id,
                vmem_limit_bytes=estimated_vmem_bytes + 8 * 1024 * 1024,
            ),
            cost_estimate=cost_estimate,
            name=get_kernel_name(bn, bk, rhs_transpose),
        )(x, y)[0]

    shard_map_kernel = jax.jit(
        jax.shard_map(
            functools.partial(
                _all_gather_matmul_call,
                bn=bn,
                bk=bk,
                rhs_transpose=rhs_transpose,
            ),
            mesh=mesh,
            in_specs=(P(axis_name, None), y_in_spec),
            out_specs=P(None, axis_name),
            check_vma=False,
        ))

    return shard_map_kernel(x, y)


def get_kernel_name(bn: int, bk: int, rhs_transpose: bool):
    return (
        f"all_gather_matmul_kernel_bn_{bn}_bk_{bk}_rhs_transpose_{rhs_transpose}"
    )
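
For orientation, a minimal sketch of how the new entry point might be driven follows. It is illustrative only and not part of the diff: the 4-device mesh, the "tp" axis name, the shapes, and the explicit block sizes are assumptions, and it needs a multi-device TPU runtime.

# Illustrative usage sketch, not part of the package: assumes 4 TPU
# devices, a mesh axis named "tp", and bfloat16 operands.
import jax
import jax.numpy as jnp

from tpu_inference.kernels.collectives.all_gather_matmul import (
    all_gather_matmul)

P = jax.sharding.PartitionSpec
mesh = jax.make_mesh((4, ), ("tp", ))

# x is row-sharded and y is column-sharded over "tp", matching the
# kernel's in_specs of (P("tp", None), P(None, "tp")).
x = jax.device_put(jnp.ones((1024, 512), jnp.bfloat16),
                   jax.sharding.NamedSharding(mesh, P("tp", None)))
y = jax.device_put(jnp.ones((512, 1024), jnp.bfloat16),
                   jax.sharding.NamedSharding(mesh, P(None, "tp")))

# Computes all-gather(x, axis=0) @ y; block sizes are passed explicitly
# here instead of relying on the tuned-block-size table.
out = all_gather_matmul(x, y, mesh, "tp", bn=256, bk=512)
assert out.shape == (1024, 1024)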