tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of tpu-inference as possibly problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
Diff for tpu_inference/layers/jax/attention/deepseek_v3_attention.py (+232 -64):

```diff
@@ -11,9 +11,13 @@ from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
 from tpu_inference import utils
+from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
+from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
+    get_tuned_block_sizes
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.layers import RMSNorm
 from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
```
```diff
@@ -66,6 +70,7 @@ class MLA(nnx.Module):
     rope_input_ordering: str = "split"
     quant: Any | None = None
     rope_mscale_all_dim: float = 1.0
+    use_mla_kernel: bool = False
 
     rngs: InitVar[nnx.Rngs]
 
```
```diff
@@ -77,10 +82,10 @@ class MLA(nnx.Module):
         self.N = self.num_attention_heads
         self.K = self.num_key_value_heads
         self.D = self.hidden_size
-
         self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
 
-        assert self.N == self.K, "N and K must be equal for MLA"
+        if not self.use_mla_kernel:
+            assert self.N == self.K, "N and K must be equal for MLA"
 
         if self.rope_scaling["factor"] <= 1.0:
             yarn_mscale = 1.0
```
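The `yarn_mscale` branch in the trailing context follows the standard YaRN attention-scale rule used by DeepSeek-style RoPE scaling. A hedged sketch of that rule (the exact constants live in this module's rope implementation, which is not shown in this hunk):

```python
import math

def yarn_get_mscale(scale: float, mscale: float = 1.0) -> float:
    # For scaling factors <= 1 the attention scale is left untouched,
    # matching the "factor <= 1.0 -> yarn_mscale = 1.0" branch above.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

assert yarn_get_mscale(1.0) == 1.0
assert abs(yarn_get_mscale(40.0) - (0.1 * math.log(40.0) + 1.0)) < 1e-12
```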
```diff
@@ -122,14 +127,30 @@ class MLA(nnx.Module):
             self.dtype,
             random_init=self.random_init,
         )
-        self.kernel_kv_up_proj_ANH = create_param(
-            rngs,
-            (self.kv_lora_rank, self.N,
-             self.qk_nope_head_dim + self.v_head_dim),
-            self.anh_sharding,
-            self.dtype,
-            random_init=self.random_init,
-        )
+        if self.use_mla_kernel:
+            self.kernel_k_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N, self.qk_nope_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
+            self.kernel_v_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N, self.v_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
+        else:
+            self.kernel_kv_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N,
+                 self.qk_nope_head_dim + self.v_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
         self.kernel_o_proj_NHD = create_param(
             rngs, (self.N, self.v_head_dim, self.D),
             self.nhd_sharding,
```
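The kernel path stores the KV up-projection as two separate parameters (`kernel_k_up_proj_ANH` and `kernel_v_up_proj_ANH`) instead of the single combined `kernel_kv_up_proj_ANH`. In the non-kernel forward pass the combined projection output is sliced at `qk_nope_head_dim` along the last axis, so the two layouts are related by a simple split. A minimal sketch of that relationship, assuming a checkpoint stores the combined weight (how this package's weight loader actually maps checkpoints is not shown in this diff):

```python
import jax
import jax.numpy as jnp

# Small illustrative dims; the real DeepSeek-V3 values are much larger.
kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 16, 4, 8, 8

w_kv_ANH = jax.random.normal(
    jax.random.PRNGKey(0),
    (kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim))

# Split layout used when use_mla_kernel=True: slice the last axis exactly
# where the non-kernel path slices kv_nope_SNH.
w_k_ANH = w_kv_ANH[..., :qk_nope_head_dim]
w_v_ANH = w_kv_ANH[..., qk_nope_head_dim:]

assert w_k_ANH.shape == (kv_lora_rank, num_heads, qk_nope_head_dim)
assert w_v_ANH.shape == (kv_lora_rank, num_heads, v_head_dim)
```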
```diff
@@ -195,10 +216,16 @@ class MLA(nnx.Module):
             q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
             q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
             q_rope_TNH = self.rope.apply_rope(md.input_positions, q_rope_TNH)
-            # Concatenate the nope and rope queries.
-            q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
-            # Multiply the query by scaling factor
-            q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+            if self.use_mla_kernel:
+                # Absorb the k up-projection matrix into q
+                q_TNA = jnp.einsum("TNH,ANH -> TNA", q_nope_TNH,
+                                   self.kernel_k_up_proj_ANH.value)
+                q_TNA = nnx.with_sharding_constraint(q_TNA, self.query_tnh)
+            else:
+                # Concatenate the nope and rope queries.
+                q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
+                # Multiply the query by scaling factor
+                q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
 
         with jax.named_scope("kv_proj"):
             # KV down projection.
```
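The kernel branch folds the k up-projection into the query ("weight absorption"), so attention scores can later be taken directly against the latent `kv_SA` without materializing per-head keys. A minimal numerical check of that identity (illustrative dims; the axis letters follow the module's T/N/H/A/S naming, nothing here touches the real kernel API):

```python
import jax
import jax.numpy as jnp

T, S, N, H, A = 3, 5, 2, 4, 6  # q tokens, kv tokens, heads, nope dim, latent dim
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
q_nope_TNH = jax.random.normal(k0, (T, N, H))
kv_SA = jax.random.normal(k1, (S, A))          # latent KV, one vector per token
w_k_up_ANH = jax.random.normal(k2, (A, N, H))  # k up-projection weight

# Reference: up-project keys per head, then take q . k.
k_nope_SNH = jnp.einsum("SA,ANH->SNH", kv_SA, w_k_up_ANH)
scores_ref = jnp.einsum("TNH,SNH->TNS", q_nope_TNH, k_nope_SNH)

# Absorbed: fold the up-projection into q, score directly against the latent.
q_TNA = jnp.einsum("TNH,ANH->TNA", q_nope_TNH, w_k_up_ANH)
scores_abs = jnp.einsum("TNA,SA->TNS", q_TNA, kv_SA)

assert jnp.allclose(scores_ref, scores_abs, atol=1e-5)
```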
```diff
@@ -209,21 +236,27 @@ class MLA(nnx.Module):
             # Reshape k_rope_BSH to include head dimension for RoPE application
             k_rope_SNH = k_rope_SH[..., None, :]
             k_rope_SNH = self.rope.apply_rope(md.input_positions, k_rope_SNH)
-            k_rope_SNH = jnp.broadcast_to(
-                k_rope_SNH,
-                (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+            assert k_rope_SNH.shape[1] == 1
+            k_rope_SH = k_rope_SNH[:, 0, :]
+
             kv_SA = kv_SA[..., :self.kv_lora_rank]
             kv_SA = self.kv_rms_norm(kv_SA)
-            # KV up projection.
-            kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
-                                     self.kernel_kv_up_proj_ANH.value)
-            # Split the latent kv vector into k nope vector and v vector.
-            k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
-            v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
-            # Concatenate the key vector.
-            k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
-            k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
-            v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)
+            kv_SA = nnx.with_sharding_constraint(kv_SA, self.keyvalue_skh)
+
+            if not self.use_mla_kernel:
+                k_rope_SNH = jnp.broadcast_to(
+                    k_rope_SNH,
+                    (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+                # KV up projection.
+                kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
+                                         self.kernel_kv_up_proj_ANH.value)
+                # Split the latent kv vector into k nope vector and v vector.
+                k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
+                v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
+                # Concatenate the key vector.
+                k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
+                k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
+                v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)
 
         with jax.named_scope("attn_op"):
             # TODO(wenxindongwork): K and V have different head dimension,
```
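In the kernel path only the latent vector `kv_SA` and the single shared rope key `k_rope_SH` feed the attention call, instead of the expanded per-head `k_SNH`/`v_SNH`. A rough back-of-the-envelope on what that saves per cached token, using assumed DeepSeek-V3-style dimensions (128 heads, 128 nope + 64 rope head dims, 128 v head dim, 512 latent rank) and ignoring the extra padding to multiples of 128 in the expanded path:

```python
num_heads, qk_nope, qk_rope, v_dim, kv_lora_rank = 128, 128, 64, 128, 512

# Expanded path: one K (nope + rope) and one V vector per query head.
per_token_expanded = num_heads * ((qk_nope + qk_rope) + v_dim)
# Latent path: one shared latent vector plus one shared rope key.
per_token_latent = kv_lora_rank + qk_rope

print(per_token_expanded)   # 40960 values per token
print(per_token_latent)     # 576 values per token
print(per_token_expanded / per_token_latent)  # roughly a 71x smaller cache line
```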
```diff
@@ -234,44 +267,66 @@ class MLA(nnx.Module):
             # q, k, v head dimension to be multiple of 128. For now, we will
             # pad the q, k, v dimension to multiple of 128.
             # We should update the MLA kv cache implementation in the future.
-            multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
-            q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.qk_head_dim)))
-            k_SNH = jnp.pad(k_SNH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.qk_head_dim)))
-            v_SNH = jnp.pad(v_SNH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.v_head_dim)))
+            if not self.use_mla_kernel:  # MLA kernel handles padding
+                multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
+                q_TNH = jnp.pad(q_TNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.qk_head_dim)))
+                k_SNH = jnp.pad(k_SNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.qk_head_dim)))
+                v_SNH = jnp.pad(v_SNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.v_head_dim)))
+
             q_scale = k_scale = v_scale = None
-            if self.kv_cache_quantized_dtype:
-                # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
-                # q_scale = self._q_scale
-                k_scale = self._k_scale
-                v_scale = self._v_scale
-                k_SNH, v_SNH = utils.quantize_kv(k_SNH, v_SNH,
-                                                 self.kv_cache_quantized_dtype,
-                                                 k_scale, v_scale)
-            new_kv_cache, outputs_TNH = self.attention(
-                is_prefill,
-                kv_cache,
-                q_TNH,
-                k_SNH,
-                v_SNH,
-                attention_metadata,
-                self.mesh,
-                q_scale,
-                k_scale,
-                v_scale,
-            )
-            # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
-            # We shall add the MLA kv cache implementation in the future.
-            outputs_TNH = outputs_TNH[..., :self.v_head_dim]
 
-        with jax.named_scope("o_proj"):
-            outputs_TNH = nnx.with_sharding_constraint(
-                outputs_TNH, self.activation_attention_out_td)
-            o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
-                              self.kernel_o_proj_NHD.value)
-        return new_kv_cache, o_TD
+            # TODO(gpolovets): MLA does not currently support quantized KV!
+            if not self.use_mla_kernel:
+                if self.kv_cache_quantized_dtype:
+                    # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+                    k_scale = self._k_scale
+                    v_scale = self._v_scale
+                    k_SNH, v_SNH = utils.quantize_kv(
+                        k_SNH, v_SNH, self.kv_cache_quantized_dtype, k_scale,
+                        v_scale)
+
+                new_kv_cache, outputs_TNH = self.attention(
+                    is_prefill,
+                    kv_cache,
+                    q_TNH,
+                    k_SNH,
+                    v_SNH,
+                    attention_metadata,
+                    self.mesh,
+                    q_scale,
+                    k_scale,
+                    v_scale,
+                )
+                # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
+                # We shall add the MLA kv cache implementation in the future.
+                outputs_TNH = outputs_TNH[..., :self.v_head_dim]
+
+            else:
+                new_kv_cache, outputs_TNA = self.mla_attention(
+                    kv_cache,
+                    q_TNA,
+                    q_rope_TNH,
+                    kv_SA,
+                    k_rope_SH,
+                    attention_metadata,
+                    self.mesh,
+                )
+                outputs_TNH = jnp.einsum("TNA,ANH -> TNH", outputs_TNA,
+                                         self.kernel_v_up_proj_ANH.value)
+
+        with jax.named_scope("o_proj"):
+            outputs_TNH = nnx.with_sharding_constraint(
+                outputs_TNH, self.activation_attention_out_td)
+            o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
+                              self.kernel_o_proj_NHD.value)
+
+        return new_kv_cache, o_TD
 
     def attention(
         self,
```
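The non-kernel branch keeps rounding every head dimension up to the next multiple of 128 before calling `ragged_paged_attention`, while the MLA kernel path skips this because the kernel handles padding itself. A small check of the round-up arithmetic with illustrative shapes:

```python
import jax.numpy as jnp

qk_head_dim = 192  # e.g. 128 nope + 64 rope dims
multiple_of_128 = ((qk_head_dim - 1) // 128 + 1) * 128
assert multiple_of_128 == 256

# Pad only the trailing head-dim axis, matching the ((0, 0), (0, 0), (0, pad))
# pad widths used above.
q_TNH = jnp.ones((4, 8, qk_head_dim))
q_padded = jnp.pad(q_TNH, ((0, 0), (0, 0), (0, multiple_of_128 - qk_head_dim)))
assert q_padded.shape == (4, 8, 256)
```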
```diff
@@ -326,13 +381,14 @@ class MLA(nnx.Module):
         out_specs = (self.attn_o_tnh, P(None, None, "model"))
 
         def _ragged_paged_attention(*args):
-            return ragged_paged_attention(
+            outputs = ragged_paged_attention(
                 *args,
                 sm_scale=self.scale,
                 q_scale=q_scale,
                 k_scale=k_scale,
                 v_scale=v_scale,
             )
+            return outputs
 
         output_TNH, kv_cache = jax.jit(
             shard_map.shard_map(
```
```diff
@@ -352,3 +408,115 @@ class MLA(nnx.Module):
                 md.request_distribution,
             )
         return kv_cache, output_TNH
+
+    def mla_attention(
+        self,
+        kv_cache: KVCache,
+        q_TNA: jax.Array,
+        q_rope_TNH: jax.Array,
+        k_SA: jax.Array,
+        k_rope_SH: jax.Array,
+        attention_metadata: AttentionMetadata,
+        mesh: Mesh,
+    ) -> Tuple[KVCache, jax.Array]:
+        """Performs scaled dot-product attention and updates the KV cache.
+
+        This function handles the core attention logic, which varies between
+        prefill and generation modes. In prefill, it computes self-attention
+        over the input sequence with a causal mask. In generation, it attends
+        to the full history of keys and values stored in the cache.
+
+        Args:
+            kv_cache: The key-value cache to be updated and used.
+            q_TNA: Query tensor of shape `(query_seq, num_attention_heads, lkv_dim)`.
+            q_rope_TNH: Query rope tensor of shape `(query_seq, num_attention_heads, rope_dim)`.
+            k_SA: Key tensor of shape `(kv_seq, lkv_dim)`.
+            k_rope_SH: Key rope tensor of shape `(kv_seq, rope_dim)`.
+            attention_metadata: Metadata containing sequence lengths.
+            mesh: The JAX device mesh (unused in this specific function but
+                kept for potential future use or API consistency).
+            q_scale: Quantization scale for q.
+            k_scale: Quantization scale for k.
+            v_scale: Quantization scale for v.
+
+        Returns:
+            A tuple containing:
+            - The updated KV cache.
+            - The attention output tensor of shape
+              `(seq, num_q_heads, head_dim)`.
+        """
+        md = attention_metadata
+        in_specs = (
+            self.query_tnh,  # q
+            self.query_tnh,  # q_rope
+            self.keyvalue_skh,  # k
+            self.keyvalue_skh,  # k_rope
+            P(ShardingAxisName.MLP_TENSOR),  # kv_cache
+            P(ShardingAxisName.ATTN_DATA),  # md.seq_lens: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # page_indices_flat: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # query_start_loc: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # distribution: Replicated
+        )
+
+        out_specs = (self.attn_o_tnh, P(ShardingAxisName.MLP_TENSOR))
+
+        def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
+
+            def _initialize_block_sizes():
+                # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
+                # Referring to get_tuned_block_sizes() in kernels/ragged_paged_attention/v3/tuned_block_sizes.py: 'TPU v7'/128/'q_bfloat16_kv_bfloat16/q_head-128_kv_head-1_head-128'/4096
+                max_num_tokens = q.shape[0]
+                max_num_seqs = md.seq_lens.shape[0]
+                num_page_indices = md.block_tables.shape[0]
+                assert num_page_indices % max_num_seqs == 0
+                pages_per_seq = num_page_indices // max_num_seqs
+                # num_kv_pages_per_block = min(pages_per_seq, 16)
+                bkv_p, bq_sz = get_tuned_block_sizes(
+                    q.dtype,
+                    kv_cache.dtype,
+                    self.num_attention_heads,
+                    1,
+                    self.qk_nope_head_dim,
+                    kv_cache.shape[1],  # page size
+                    max_num_tokens,
+                    pages_per_seq,
+                )
+                num_kv_pages_per_block = min(min(pages_per_seq, bkv_p), 4)
+                num_queries_per_block = min(min(max_num_tokens, bq_sz),
+                                            4)  # OOMS at 8
+                return num_kv_pages_per_block, num_queries_per_block
+
+            num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes(
+            )
+            output, kv_cache = mla_ragged_paged_attention(
+                q,
+                q_rope,
+                k,
+                k_rope,
+                kv_cache,
+                *args,
+                sm_scale=self.scale,
+                num_kv_pages_per_block=num_kv_pages_per_block,
+                num_queries_per_block=num_queries_per_block)
+
+            return kv_cache, output
+
+        kv_cache, output_TNH = jax.jit(
+            shard_map.shard_map(
+                _mla_ragged_paged_attention,
+                mesh=mesh,
+                in_specs=in_specs,
+                out_specs=out_specs,
+                check_rep=False,
+            ), )(
+                q_TNA,
+                q_rope_TNH,
+                k_SA,
+                k_rope_SH,
+                kv_cache,
+                md.seq_lens,
+                md.block_tables,
+                md.query_start_loc,
+                md.request_distribution,
+            )
+        return kv_cache, output_TNH
```
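The `_initialize_block_sizes` helper above clamps the tuned block sizes so the kernel grid never exceeds what the batch actually needs and stays small enough for VMEM (the inline comment notes that 8 queries per block OOMs). The same heuristic, isolated as a plain function with made-up inputs (the real tuned values come from `get_tuned_block_sizes`):

```python
def pick_block_sizes(pages_per_seq: int, max_num_tokens: int,
                     tuned_bkv_p: int, tuned_bq_sz: int) -> tuple[int, int]:
    # Never use more KV pages or query rows per block than the request has,
    # and cap both at 4 to stay within VMEM.
    num_kv_pages_per_block = min(min(pages_per_seq, tuned_bkv_p), 4)
    num_queries_per_block = min(min(max_num_tokens, tuned_bq_sz), 4)
    return num_kv_pages_per_block, num_queries_per_block

assert pick_block_sizes(pages_per_seq=64, max_num_tokens=256,
                        tuned_bkv_p=16, tuned_bq_sz=32) == (4, 4)
assert pick_block_sizes(pages_per_seq=2, max_num_tokens=1,
                        tuned_bkv_p=16, tuned_bq_sz=32) == (2, 1)
```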
Diff for tpu_inference/layers/jax/attention/gpt_oss_attention.py (+5 -5):

```diff
@@ -158,17 +158,17 @@ class GptOssAttention(nnx.Module):
     ) -> Tuple[KVCache, jax.Array]:
         """Performs scaled dot-product attention by calling the ragged_paged_attention kernel."""
         md = attention_metadata
-        kv_cache_spec = P(None, None, "model")
+        kv_cache_spec = P("data", None, "model")
 
         in_specs = (
             self.query_tnh,  # q
             self.keyvalue_skh,  # k
             self.keyvalue_skh,  # v
             kv_cache_spec,  # kv_cache
-            P(),  # md.seq_lens
-            P(),  # page_indices_flat
-            P(),  # query_start_loc
-            P(),  # distribution
+            P("data"),  # md.seq_lens
+            P("data"),  # page_indices_flat
+            P("data"),  # query_start_loc
+            P("data"),  # distribution
             P(('model')),  # sinks
         )
         out_specs = (self.attn_o_tnh, kv_cache_spec)
```