tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
|
@@ -16,17 +16,30 @@ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
|
|
|
16
16
|
DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024
|
|
17
17
|
|
|
18
18
|
|
|
19
|
+
def get_kv_cache_shape(
|
|
20
|
+
total_num_pages,
|
|
21
|
+
page_size,
|
|
22
|
+
kv_dim,
|
|
23
|
+
kv_dtype,
|
|
24
|
+
):
|
|
25
|
+
kv_packing = get_dtype_packing(kv_dtype)
|
|
26
|
+
return (
|
|
27
|
+
total_num_pages,
|
|
28
|
+
align_to(page_size, kv_packing) // kv_packing,
|
|
29
|
+
kv_packing,
|
|
30
|
+
align_to(kv_dim, 128),
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
|
|
19
34
|
@functools.partial(
|
|
20
35
|
jax.jit,
|
|
21
|
-
donate_argnames=("
|
|
36
|
+
donate_argnames=("cache_kv"),
|
|
22
37
|
)
|
|
23
38
|
def update_kv_cache(
|
|
24
39
|
new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
|
|
25
40
|
new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
|
|
26
|
-
|
|
27
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
|
|
28
|
-
cache_k_pe: jax.
|
|
29
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
41
|
+
cache_kv: jax.
|
|
42
|
+
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
|
|
30
43
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
31
44
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
32
45
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -43,25 +56,21 @@ def update_kv_cache(
|
|
|
43
56
|
if actual_lkv_dim != lkv_dim:
|
|
44
57
|
new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)),
|
|
45
58
|
constant_values=0)
|
|
46
|
-
|
|
47
|
-
_, page_size_per_kv_packing, kv_packing,
|
|
48
|
-
|
|
49
|
-
assert lkv_dim == cache_lkv_dim
|
|
50
|
-
assert r_dim == cache_r_dim
|
|
59
|
+
kv_dim = r_dim + lkv_dim
|
|
60
|
+
_, page_size_per_kv_packing, kv_packing, cache_kv_dim = cache_kv.shape
|
|
61
|
+
assert kv_dim == cache_kv_dim
|
|
51
62
|
page_size = page_size_per_kv_packing * kv_packing
|
|
52
63
|
|
|
53
64
|
max_num_seqs = kv_lens.shape[0]
|
|
54
65
|
num_page_indices = page_indices.shape[0]
|
|
55
66
|
pages_per_seq = num_page_indices // max_num_seqs
|
|
56
67
|
|
|
57
|
-
def seq_loop_body(i,
|
|
58
|
-
cache_kv_c, cache_k_pe = caches
|
|
68
|
+
def seq_loop_body(i, cache_kv):
|
|
59
69
|
q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
|
|
60
70
|
q_len = q_end - q_start
|
|
61
71
|
kv_len = kv_lens[i]
|
|
62
72
|
|
|
63
|
-
def token_loop_body(j,
|
|
64
|
-
cache_kv_c_, cache_k_pe_ = caches_
|
|
73
|
+
def token_loop_body(j, cache_kv_):
|
|
65
74
|
token_idx_in_seq = kv_len - q_len + j
|
|
66
75
|
page_num_in_seq = token_idx_in_seq // page_size
|
|
67
76
|
page_indices_start = i * pages_per_seq
|
|
@@ -69,18 +78,17 @@ def update_kv_cache(
|
|
|
69
78
|
row = (token_idx_in_seq % page_size) // kv_packing
|
|
70
79
|
col = (token_idx_in_seq % page_size) % kv_packing
|
|
71
80
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
return
|
|
81
|
+
cache_kv_ = cache_kv_.at[page_idx, row, col,
|
|
82
|
+
..., :lkv_dim].set(new_kv_c[q_start + j])
|
|
83
|
+
cache_kv_ = cache_kv_.at[page_idx, row, col, ...,
|
|
84
|
+
lkv_dim:].set(new_k_pe[q_start + j])
|
|
85
|
+
return cache_kv_
|
|
86
|
+
|
|
87
|
+
return lax.fori_loop(0, q_len, token_loop_body, cache_kv)
|
|
77
88
|
|
|
78
|
-
|
|
79
|
-
(cache_kv_c, cache_k_pe))
|
|
89
|
+
cache_kv = lax.fori_loop(0, distribution[-1], seq_loop_body, cache_kv)
|
|
80
90
|
|
|
81
|
-
|
|
82
|
-
(cache_kv_c, cache_k_pe))
|
|
83
|
-
return cache_kv_c, cache_k_pe
|
|
91
|
+
return cache_kv
|
|
84
92
|
|
|
85
93
|
|
|
86
94
|
def ref_mla_ragged_paged_attention(
|
|
@@ -88,10 +96,8 @@ def ref_mla_ragged_paged_attention(
|
|
|
88
96
|
q_pe: jax.Array, # [num_tokens, actual_num_q_heads, actual_r_dim]
|
|
89
97
|
new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
|
|
90
98
|
new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
|
|
91
|
-
|
|
99
|
+
cache_kv: jax.
|
|
92
100
|
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
|
|
93
|
-
cache_k_pe: jax.
|
|
94
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
95
101
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
96
102
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
97
103
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -111,8 +117,7 @@ def ref_mla_ragged_paged_attention(
|
|
|
111
117
|
q_pe,
|
|
112
118
|
new_kv_c,
|
|
113
119
|
new_k_pe,
|
|
114
|
-
|
|
115
|
-
cache_k_pe,
|
|
120
|
+
cache_kv,
|
|
116
121
|
kv_lens,
|
|
117
122
|
page_indices,
|
|
118
123
|
cu_q_lens,
|
|
@@ -123,11 +128,10 @@ def ref_mla_ragged_paged_attention(
|
|
|
123
128
|
mask_value=mask_value,
|
|
124
129
|
)
|
|
125
130
|
|
|
126
|
-
|
|
131
|
+
updated_cache_kv = update_kv_cache(
|
|
127
132
|
new_kv_c,
|
|
128
133
|
new_k_pe,
|
|
129
|
-
|
|
130
|
-
cache_k_pe,
|
|
134
|
+
cache_kv,
|
|
131
135
|
kv_lens,
|
|
132
136
|
page_indices,
|
|
133
137
|
cu_q_lens,
|
|
@@ -154,13 +158,17 @@ def ref_mla_ragged_paged_attention(
|
|
|
154
158
|
assert num_page_indices % max_num_seqs == 0
|
|
155
159
|
pages_per_seq = num_page_indices // max_num_seqs
|
|
156
160
|
|
|
157
|
-
total_num_pages, page_size_per_kv_packing, kv_packing, _ =
|
|
161
|
+
total_num_pages, page_size_per_kv_packing, kv_packing, _ = updated_cache_kv.shape
|
|
158
162
|
page_size = page_size_per_kv_packing * kv_packing
|
|
159
163
|
assert lkv_dim == ql_nope.shape[-1]
|
|
160
164
|
assert r_dim == q_pe.shape[-1]
|
|
165
|
+
assert lkv_dim + r_dim == updated_cache_kv.shape[-1]
|
|
161
166
|
|
|
162
|
-
kv_c_cache =
|
|
163
|
-
|
|
167
|
+
kv_c_cache = updated_cache_kv[..., :lkv_dim].reshape(
|
|
168
|
+
total_num_pages, page_size, lkv_dim)
|
|
169
|
+
k_pe_cache = updated_cache_kv[...,
|
|
170
|
+
lkv_dim:].reshape(total_num_pages, page_size,
|
|
171
|
+
r_dim)
|
|
164
172
|
|
|
165
173
|
outputs = []
|
|
166
174
|
|
|
@@ -221,8 +229,7 @@ def ref_mla_ragged_paged_attention(
|
|
|
221
229
|
|
|
222
230
|
return (
|
|
223
231
|
jnp.concatenate(outputs, axis=0),
|
|
224
|
-
|
|
225
|
-
cache_k_pe,
|
|
232
|
+
updated_cache_kv,
|
|
226
233
|
)
|
|
227
234
|
|
|
228
235
|
|
|
@@ -232,10 +239,8 @@ def dynamic_validate_inputs(
|
|
|
232
239
|
q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
|
|
233
240
|
new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
|
|
234
241
|
new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
|
|
235
|
-
|
|
242
|
+
cache_kv: jax.
|
|
236
243
|
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
|
|
237
|
-
cache_k_pe: jax.
|
|
238
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
239
244
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
240
245
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
241
246
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -260,8 +265,7 @@ def dynamic_validate_inputs(
|
|
|
260
265
|
q_pe,
|
|
261
266
|
new_kv_c,
|
|
262
267
|
new_k_pe,
|
|
263
|
-
|
|
264
|
-
cache_k_pe,
|
|
268
|
+
cache_kv,
|
|
265
269
|
kv_lens,
|
|
266
270
|
page_indices,
|
|
267
271
|
cu_q_lens,
|
|
@@ -277,8 +281,8 @@ def dynamic_validate_inputs(
|
|
|
277
281
|
debug_mode=debug_mode,
|
|
278
282
|
)
|
|
279
283
|
max_num_tokens = ql_nope.shape[0]
|
|
280
|
-
total_num_pages =
|
|
281
|
-
_, page_size_per_kv_packing, kv_packing, _ =
|
|
284
|
+
total_num_pages = cache_kv.shape[0]
|
|
285
|
+
_, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
|
|
282
286
|
page_size = page_size_per_kv_packing * kv_packing
|
|
283
287
|
max_num_seqs = kv_lens.shape[0]
|
|
284
288
|
num_page_indices = page_indices.shape[0]
|
|
@@ -320,10 +324,8 @@ def static_validate_inputs(
|
|
|
320
324
|
q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
|
|
321
325
|
new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
|
|
322
326
|
new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
|
|
323
|
-
|
|
327
|
+
cache_kv: jax.
|
|
324
328
|
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
|
|
325
|
-
cache_k_pe: jax.
|
|
326
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
327
329
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
328
330
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
329
331
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -373,44 +375,34 @@ def static_validate_inputs(
|
|
|
373
375
|
|
|
374
376
|
actual_lkv_dim = ql_nope.shape[2]
|
|
375
377
|
actual_r_dim = q_pe.shape[2]
|
|
378
|
+
lkv_dim = align_to(actual_lkv_dim, 128)
|
|
379
|
+
r_dim = align_to(actual_r_dim, 128)
|
|
376
380
|
|
|
377
381
|
(
|
|
378
382
|
_,
|
|
379
383
|
page_size_per_kv_packing,
|
|
380
384
|
kv_packing,
|
|
381
|
-
|
|
382
|
-
) =
|
|
383
|
-
_, _, _, r_dim = cache_k_pe.shape
|
|
385
|
+
kv_dim,
|
|
386
|
+
) = cache_kv.shape
|
|
384
387
|
|
|
385
|
-
if lkv_dim !=
|
|
386
|
-
raise ValueError(
|
|
387
|
-
f"Expected {lkv_dim=} is equal to {align_to(actual_lkv_dim, 128)=}"
|
|
388
|
-
)
|
|
389
|
-
if r_dim != align_to(actual_r_dim, 128):
|
|
388
|
+
if lkv_dim + r_dim != kv_dim:
|
|
390
389
|
raise ValueError(
|
|
391
|
-
f"Expected {r_dim=}
|
|
390
|
+
f"Expected {lkv_dim=} + {r_dim=} to be equal to {kv_dim=}")
|
|
392
391
|
|
|
393
|
-
if not (
|
|
392
|
+
if not (cache_kv.dtype == new_kv_c.dtype):
|
|
394
393
|
raise ValueError(
|
|
395
|
-
f"Expected {
|
|
396
|
-
if not (
|
|
394
|
+
f"Expected {cache_kv.dtype=} to be equal to {new_kv_c.dtype=}.")
|
|
395
|
+
if not (cache_kv.dtype == new_k_pe.dtype):
|
|
397
396
|
raise ValueError(
|
|
398
|
-
f"Expected {
|
|
397
|
+
f"Expected {cache_kv.dtype=} to be equal to {new_k_pe.dtype=}.")
|
|
399
398
|
|
|
400
399
|
# Integer kv quantization is currently not supported.
|
|
401
|
-
if not jnp.issubdtype(
|
|
402
|
-
raise ValueError(
|
|
403
|
-
f"Expected {cache_kv_c.dtype=} to be a floating point.")
|
|
404
|
-
if not jnp.issubdtype(cache_k_pe.dtype, jnp.floating):
|
|
405
|
-
raise ValueError(
|
|
406
|
-
f"Expected {cache_k_pe.dtype=} to be a floating point.")
|
|
400
|
+
if not jnp.issubdtype(cache_kv.dtype, jnp.floating):
|
|
401
|
+
raise ValueError(f"Expected {cache_kv.dtype=} to be a floating point.")
|
|
407
402
|
|
|
408
|
-
if kv_packing != get_dtype_packing(
|
|
403
|
+
if kv_packing != get_dtype_packing(cache_kv.dtype):
|
|
409
404
|
raise ValueError(
|
|
410
|
-
f"{kv_packing=} does not match with {
|
|
411
|
-
if kv_packing != get_dtype_packing(cache_k_pe.dtype):
|
|
412
|
-
raise ValueError(
|
|
413
|
-
f"{kv_packing=} does not match with {cache_k_pe.dtype=}")
|
|
405
|
+
f"{kv_packing=} does not match with {cache_kv.dtype=}")
|
|
414
406
|
|
|
415
407
|
if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
|
|
416
408
|
== distribution.dtype):
|
|
@@ -475,14 +467,12 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
475
467
|
q_pe_hbm_ref, # [max_num_tokens, num_q_heads_per_q_packing, q_packing, r_dim]
|
|
476
468
|
new_kv_c_hbm_ref, # [max_num_tokens_per_kv_packing, kv_packing, lkv_dim]
|
|
477
469
|
new_k_pe_hbm_ref, # [max_num_tokens_per_kv_packing, kv_packing, r_dim]
|
|
478
|
-
|
|
479
|
-
cache_k_pe_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
470
|
+
cache_kv_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
|
|
480
471
|
# Output
|
|
481
472
|
o_hbm_ref, # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
|
|
482
|
-
|
|
483
|
-
updated_cache_k_pe_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
473
|
+
updated_cache_kv_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
|
|
484
474
|
# Scratch
|
|
485
|
-
bkvc_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim]
|
|
475
|
+
bkvc_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim].
|
|
486
476
|
bkpe_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, r_dim]
|
|
487
477
|
bq_nope_x2_ref, # [2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim]
|
|
488
478
|
bq_rope_x2_ref, # [2, bq_sz, num_q_heads_per_q_packing, q_packing, r_dim]
|
|
@@ -505,20 +495,24 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
505
495
|
debug_mode: bool = False,
|
|
506
496
|
):
|
|
507
497
|
assert ql_nope_hbm_ref.shape == o_hbm_ref.shape
|
|
508
|
-
|
|
509
|
-
|
|
498
|
+
# Validation checks on the dimensions
|
|
499
|
+
nope_dim = ql_nope_hbm_ref.shape[-1]
|
|
500
|
+
pe_dim = q_pe_hbm_ref.shape[-1]
|
|
501
|
+
assert nope_dim + pe_dim == cache_kv_hbm_ref.shape[-1]
|
|
502
|
+
|
|
510
503
|
_, num_q_heads_per_q_packing, q_packing, lkv_dim = ql_nope_hbm_ref.shape
|
|
511
504
|
r_dim = q_pe_hbm_ref.shape[-1]
|
|
512
505
|
num_q_heads = num_q_heads_per_q_packing * q_packing
|
|
513
506
|
total_num_pages, page_size_per_kv_packing, kv_packing, _ = (
|
|
514
|
-
|
|
507
|
+
cache_kv_hbm_ref.shape)
|
|
515
508
|
max_num_seqs = kv_lens_ref.shape[0]
|
|
516
509
|
num_page_indices = page_indices_ref.shape[0]
|
|
517
510
|
|
|
518
511
|
assert num_page_indices % max_num_seqs == 0
|
|
519
512
|
pages_per_seq = num_page_indices // max_num_seqs
|
|
520
513
|
q_dtype = ql_nope_hbm_ref.dtype
|
|
521
|
-
|
|
514
|
+
# Validate against the KV dtype.
|
|
515
|
+
kv_dtype = cache_kv_hbm_ref.dtype
|
|
522
516
|
assert q_pe_hbm_ref.dtype == q_dtype
|
|
523
517
|
assert o_hbm_ref.dtype == q_dtype
|
|
524
518
|
assert get_dtype_packing(q_dtype) == q_packing
|
|
@@ -561,8 +555,8 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
561
555
|
def flash_attention(
|
|
562
556
|
ql_nope, # [actual_bq_sz * num_q_heads, lkv_dim]
|
|
563
557
|
q_pe, # [actual_bq_sz * num_q_heads, r_dim]
|
|
564
|
-
kv_c, # [bkv_sz, lkv_dim]
|
|
565
|
-
k_pe, # [bkv_sz, r_dim]
|
|
558
|
+
kv_c, # [bkv_sz, lkv_dim] <- Correspond to data from bkvc_x2_ref
|
|
559
|
+
k_pe, # [bkv_sz, r_dim] <- Correspond to data from bpe_x2_ref
|
|
566
560
|
*,
|
|
567
561
|
bq_idx,
|
|
568
562
|
bkv_idx,
|
|
@@ -649,14 +643,9 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
649
643
|
sem = sems.at[0, bkv_sem_idx]
|
|
650
644
|
bkvc_vmem_ref = bkvc_x2_ref.at[bkv_sem_idx]
|
|
651
645
|
bkvpe_vmem_ref = bkpe_x2_ref.at[bkv_sem_idx]
|
|
652
|
-
|
|
653
|
-
reshaped_cache_kv_c_hbm_ref = cache_kv_c_hbm_ref.reshape(
|
|
646
|
+
reshaped_cache_hbm_ref = cache_kv_hbm_ref.reshape(
|
|
654
647
|
total_num_pages * page_size_per_kv_packing,
|
|
655
|
-
*
|
|
656
|
-
)
|
|
657
|
-
reshaped_cache_k_pe_hbm_ref = cache_k_pe_hbm_ref.reshape(
|
|
658
|
-
total_num_pages * page_size_per_kv_packing,
|
|
659
|
-
*cache_k_pe_hbm_ref.shape[2:],
|
|
648
|
+
*cache_kv_hbm_ref.shape[2:],
|
|
660
649
|
)
|
|
661
650
|
kv_len = kv_lens_ref[seq_idx]
|
|
662
651
|
kv_len_start = bkv_idx * bkv_sz
|
|
@@ -684,22 +673,22 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
684
673
|
kv_left_per_kv_packing - i * page_size_per_kv_packing,
|
|
685
674
|
)
|
|
686
675
|
_async_copy(
|
|
687
|
-
|
|
676
|
+
reshaped_cache_hbm_ref.at[pl.ds(
|
|
688
677
|
page_indices_ref[page_indices_offset + i] *
|
|
689
678
|
page_size_per_kv_packing,
|
|
690
679
|
sz_per_kv_packing,
|
|
691
|
-
)],
|
|
680
|
+
), ..., :nope_dim],
|
|
692
681
|
bkvc_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
|
|
693
682
|
sz_per_kv_packing)],
|
|
694
683
|
sem,
|
|
695
684
|
wait,
|
|
696
685
|
)
|
|
697
686
|
_async_copy(
|
|
698
|
-
|
|
687
|
+
reshaped_cache_hbm_ref.at[pl.ds(
|
|
699
688
|
page_indices_ref[page_indices_offset + i] *
|
|
700
689
|
page_size_per_kv_packing,
|
|
701
690
|
sz_per_kv_packing,
|
|
702
|
-
)],
|
|
691
|
+
), ..., nope_dim:],
|
|
703
692
|
bkvpe_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
|
|
704
693
|
sz_per_kv_packing)],
|
|
705
694
|
sem,
|
|
@@ -835,7 +824,6 @@ def _mla_ragged_paged_attention_kernel(
|
|
|
835
824
|
jnp.zeros_like(concated_bkvc_vec))
|
|
836
825
|
concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
|
|
837
826
|
kv_dtype)
|
|
838
|
-
|
|
839
827
|
bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
|
|
840
828
|
bkv_sz_per_kv_packing, r_dim))
|
|
841
829
|
bkpe_vec = bkpe_ref[...]
|
|
@@ -1082,17 +1070,16 @@ def prepare_outputs(
|
|
|
1082
1070
|
"vmem_limit_bytes",
|
|
1083
1071
|
"debug_mode",
|
|
1084
1072
|
),
|
|
1085
|
-
donate_argnames=("
|
|
1073
|
+
donate_argnames=("cache_kv"),
|
|
1086
1074
|
)
|
|
1087
1075
|
def mla_ragged_paged_attention(
|
|
1088
1076
|
ql_nope: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
|
|
1089
1077
|
q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
|
|
1090
1078
|
new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
|
|
1091
1079
|
new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
|
|
1080
|
+
# TODO(gpolovets): Explore separating out into lkv & pe KV caches.
|
|
1081
|
+
cache_kv: jax.
|
|
1082
|
+
Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim, 128)]
|
|
1096
1083
|
kv_lens: jax.Array, # i32[max_num_seqs]
|
|
1097
1084
|
page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
|
|
1098
1085
|
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
|
@@ -1124,8 +1111,7 @@ def mla_ragged_paged_attention(
|
|
|
1124
1111
|
q_pe: concatenated all sequences' rope.
|
|
1125
1112
|
new_kv_c: concatenated all sequences' kv_c values
|
|
1126
1113
|
new_k_pe: concatenated all sequences' k_pe values
|
|
1127
|
-
|
|
1128
|
-
cache_k_pe: the current k_pe cache.
|
|
1114
|
+
cache_kv: the current kv cache.
|
|
1129
1115
|
kv_lens: the length of each sequence in the kv cache.
|
|
1130
1116
|
page_indices: flattened page indices look-up table by (seq_id, page_id).
|
|
1131
1117
|
cu_q_lens: the cumulative sum of the effective query lengths. Similar to
|
|
@@ -1159,8 +1145,7 @@ def mla_ragged_paged_attention(
|
|
|
1159
1145
|
q_pe,
|
|
1160
1146
|
new_kv_c,
|
|
1161
1147
|
new_k_pe,
|
|
1162
|
-
|
|
1163
|
-
cache_k_pe,
|
|
1148
|
+
cache_kv,
|
|
1164
1149
|
kv_lens,
|
|
1165
1150
|
page_indices,
|
|
1166
1151
|
cu_q_lens,
|
|
@@ -1177,11 +1162,10 @@ def mla_ragged_paged_attention(
|
|
|
1177
1162
|
)
|
|
1178
1163
|
|
|
1179
1164
|
# TODO(chengjiyao): fuse kv cache update into the kernel.
|
|
1180
|
-
|
|
1165
|
+
cache_kv = update_kv_cache(
|
|
1181
1166
|
new_kv_c,
|
|
1182
1167
|
new_k_pe,
|
|
1183
|
-
|
|
1184
|
-
cache_k_pe,
|
|
1168
|
+
cache_kv,
|
|
1185
1169
|
kv_lens,
|
|
1186
1170
|
page_indices,
|
|
1187
1171
|
cu_q_lens,
|
|
@@ -1202,7 +1186,7 @@ def mla_ragged_paged_attention(
|
|
|
1202
1186
|
lkv_dim = new_kv_c.shape[-1]
|
|
1203
1187
|
r_dim = new_k_pe.shape[-1]
|
|
1204
1188
|
|
|
1205
|
-
_, page_size_per_kv_packing, kv_packing, _ =
|
|
1189
|
+
_, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
|
|
1206
1190
|
page_size = page_size_per_kv_packing * kv_packing
|
|
1207
1191
|
_, num_q_heads_per_q_packing, q_packing, _ = ql_nope.shape
|
|
1208
1192
|
max_num_seqs = kv_lens.shape[0]
|
|
@@ -1221,23 +1205,21 @@ def mla_ragged_paged_attention(
|
|
|
1221
1205
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1222
1206
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1223
1207
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1224
|
-
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1225
1208
|
]
|
|
1226
1209
|
|
|
1227
1210
|
out_specs = [
|
|
1228
1211
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1229
1212
|
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1230
|
-
pl.BlockSpec(memory_space=pltpu.HBM),
|
|
1231
1213
|
]
|
|
1232
1214
|
|
|
1233
1215
|
bkvc_double_buf = pltpu.VMEM(
|
|
1234
1216
|
(2, bkv_sz_per_kv_packing, kv_packing, lkv_dim),
|
|
1235
|
-
|
|
1217
|
+
cache_kv.dtype,
|
|
1236
1218
|
)
|
|
1237
1219
|
|
|
1238
1220
|
bkpe_double_buf = pltpu.VMEM(
|
|
1239
1221
|
(2, bkv_sz_per_kv_packing, kv_packing, r_dim),
|
|
1240
|
-
|
|
1222
|
+
cache_kv.dtype,
|
|
1241
1223
|
)
|
|
1242
1224
|
|
|
1243
1225
|
bq_nope_double_buf = pltpu.VMEM(
|
|
@@ -1320,30 +1302,26 @@ def mla_ragged_paged_attention(
|
|
|
1320
1302
|
),
|
|
1321
1303
|
out_shape=[
|
|
1322
1304
|
jax.ShapeDtypeStruct(shape=ql_nope.shape, dtype=ql_nope.dtype),
|
|
1323
|
-
jax.ShapeDtypeStruct(shape=
|
|
1324
|
-
dtype=
|
|
1325
|
-
jax.ShapeDtypeStruct(shape=cache_k_pe.shape,
|
|
1326
|
-
dtype=cache_k_pe.dtype),
|
|
1305
|
+
jax.ShapeDtypeStruct(shape=cache_kv.shape,
|
|
1306
|
+
dtype=cache_kv.dtype),
|
|
1327
1307
|
],
|
|
1328
1308
|
input_output_aliases={
|
|
1329
1309
|
7: 0,
|
|
1330
1310
|
11: 1,
|
|
1331
|
-
12: 2
|
|
1332
1311
|
},
|
|
1333
1312
|
name=scope_name,
|
|
1334
1313
|
))
|
|
1335
1314
|
|
|
1336
|
-
output,
|
|
1315
|
+
output, updated_kv = kernel(
|
|
1337
1316
|
*scalar_prefetches,
|
|
1338
1317
|
ql_nope,
|
|
1339
1318
|
q_pe,
|
|
1340
1319
|
new_kv_c,
|
|
1341
1320
|
new_k_pe,
|
|
1342
|
-
|
|
1343
|
-
cache_k_pe,
|
|
1321
|
+
cache_kv,
|
|
1344
1322
|
)
|
|
1345
1323
|
output = prepare_outputs(
|
|
1346
1324
|
output, actual_num_q_heads,
|
|
1347
1325
|
actual_lkv_dim) # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
|
|
1348
1326
|
|
|
1349
|
-
return output,
|
|
1327
|
+
return output, updated_kv
|
|
@@ -9,12 +9,58 @@ from jax._src import dtypes
|
|
|
9
9
|
from jax.experimental import pallas as pl
|
|
10
10
|
from jax.experimental.pallas import tpu as pltpu
|
|
11
11
|
|
|
12
|
+
from tpu_inference.kernels.quantized_matmul import util
|
|
12
13
|
from tpu_inference.kernels.quantized_matmul.tuned_block_sizes import (
|
|
13
14
|
TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
|
|
14
15
|
from tpu_inference.kernels.quantized_matmul.util import (get_kernel_name,
|
|
15
16
|
next_multiple,
|
|
16
17
|
unfold_args)
|
|
17
18
|
|
|
19
|
+
quantize_tensor = util.quantize_tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def xla_quantized_matmul(
|
|
23
|
+
x: jax.Array,
|
|
24
|
+
w_q: jax.Array,
|
|
25
|
+
w_scale: jax.Array,
|
|
26
|
+
quantize_activation=True,
|
|
27
|
+
) -> jax.Array:
|
|
28
|
+
"""
|
|
29
|
+
Reference (pure JAX) implementation of the quantized matmul kernel below.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
x: Activation.
|
|
33
|
+
w_q: Weight quantized array. [n_output_features, n_input_features]
|
|
34
|
+
w_s: Weight quantization scale. [n_output_features]
|
|
35
|
+
mesh: Mesh to shard on.
|
|
36
|
+
weight_sharding: PartitionSpec for the weight tensor.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
Output of the quantized matmul.
|
|
40
|
+
"""
|
|
41
|
+
if quantize_activation:
|
|
42
|
+
acc_dtype = jnp.float32
|
|
43
|
+
if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
|
|
44
|
+
acc_dtype = jnp.int32
|
|
45
|
+
|
|
46
|
+
x_q, x_scale = quantize_tensor(x, w_q.dtype)
|
|
47
|
+
out = jax.lax.dot_general(
|
|
48
|
+
x_q,
|
|
49
|
+
w_q,
|
|
50
|
+
dimension_numbers=(((1, ), (1, )), ((), ())),
|
|
51
|
+
preferred_element_type=acc_dtype,
|
|
52
|
+
).astype(jnp.float32)
|
|
53
|
+
out *= x_scale
|
|
54
|
+
else:
|
|
55
|
+
out = jax.lax.dot_general(
|
|
56
|
+
x,
|
|
57
|
+
w_q,
|
|
58
|
+
dimension_numbers=(((1, ), (1, )), ((), ())),
|
|
59
|
+
preferred_element_type=jnp.float32,
|
|
60
|
+
)
|
|
61
|
+
out *= jnp.expand_dims(w_scale, 0)
|
|
62
|
+
return out.astype(x.dtype)
|
|
63
|
+
|
|
18
64
|
|
|
19
65
|
def quantize_array(
|
|
20
66
|
x: jax.Array, # [bs_block_size, in_block_size]
|
|
@@ -50,11 +96,20 @@ def get_vmem_limit(
|
|
|
50
96
|
"""Calculate VMEM limit for the kernel."""
|
|
51
97
|
|
|
52
98
|
# Calculate in/out VMEM size.
|
|
53
|
-
x_size = batch_block_size *
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
99
|
+
x_size = (batch_block_size *
|
|
100
|
+
in_block_size * (dtypes.bit_width(x_dtype) if hasattr(
|
|
101
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(x_dtype)))
|
|
102
|
+
x_abs_max_size = (
|
|
103
|
+
batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
|
|
104
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
|
|
105
|
+
w_q_size = (out_block_size *
|
|
106
|
+
in_block_size * (dtypes.bit_width(w_q_dtype) if hasattr(
|
|
107
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(w_q_dtype)))
|
|
108
|
+
w_scale_size = (out_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
|
|
109
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
|
|
110
|
+
out_size = (batch_block_size *
|
|
111
|
+
out_block_size * (dtypes.bit_width(out_dtype) if hasattr(
|
|
112
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(out_dtype)))
|
|
58
113
|
|
|
59
114
|
vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
|
|
60
115
|
vmem_in_out *= 2 # Account for compute and vreg spills.
|
|
@@ -68,9 +123,15 @@ def get_vmem_limit(
|
|
|
68
123
|
vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
|
|
69
124
|
|
|
70
125
|
# Calculate scratch VMEM size.
|
|
71
|
-
acc_size = batch_block_size *
|
|
72
|
-
|
|
73
|
-
|
|
126
|
+
acc_size = (batch_block_size *
|
|
127
|
+
out_block_size * (dtypes.bit_width(acc_dtype) if hasattr(
|
|
128
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(acc_dtype)))
|
|
129
|
+
x_q_size = (batch_block_size *
|
|
130
|
+
in_block_size * (dtypes.bit_width(x_q_dtype) if hasattr(
|
|
131
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(x_q_dtype)))
|
|
132
|
+
x_scale_size = (
|
|
133
|
+
batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
|
|
134
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
|
|
74
135
|
|
|
75
136
|
vmem_scratch = acc_size if save_acc else 0
|
|
76
137
|
vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
|
|
@@ -200,7 +200,8 @@ def _prev_power_of_2(n: int) -> int:
|
|
|
200
200
|
def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
|
|
201
201
|
head_size: int, kv_cache_dtype) -> int:
|
|
202
202
|
"""Returns the size in bytes of one page of the KV cache."""
|
|
203
|
-
kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
|
|
203
|
+
kv_cache_dtype_bit_size = (dtypes.bit_width(kv_cache_dtype) if hasattr(
|
|
204
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(kv_cache_dtype))
|
|
204
205
|
padded_head_size = _ceil_div(
|
|
205
206
|
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
|
206
207
|
|