tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (59)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -1
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/test_envs.py +110 -12
  9. tests/test_quantization.py +3 -0
  10. tests/test_utils.py +1 -2
  11. tpu_inference/distributed/tpu_connector.py +1 -1
  12. tpu_inference/envs.py +92 -8
  13. tpu_inference/executors/ray_distributed_executor.py +5 -1
  14. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  17. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  18. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  19. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  20. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  21. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
  22. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
  23. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  25. tpu_inference/layers/common/attention_interface.py +7 -1
  26. tpu_inference/layers/common/sharding.py +11 -7
  27. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  28. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  29. tpu_inference/layers/vllm/fused_moe.py +170 -208
  30. tpu_inference/layers/vllm/linear_common.py +43 -21
  31. tpu_inference/layers/vllm/quantization/common.py +11 -6
  32. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  33. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  34. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  35. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  36. tpu_inference/models/common/model_loader.py +78 -22
  37. tpu_inference/models/jax/deepseek_v3.py +185 -64
  38. tpu_inference/models/jax/gpt_oss.py +3 -3
  39. tpu_inference/models/jax/llama_eagle3.py +4 -5
  40. tpu_inference/models/jax/qwen2_5_vl.py +161 -47
  41. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  42. tpu_inference/models/jax/utils/weight_utils.py +203 -155
  43. tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
  44. tpu_inference/platforms/tpu_platform.py +29 -48
  45. tpu_inference/runner/compilation_manager.py +112 -46
  46. tpu_inference/runner/kv_cache.py +40 -20
  47. tpu_inference/runner/kv_cache_manager.py +40 -31
  48. tpu_inference/runner/persistent_batch_manager.py +40 -2
  49. tpu_inference/runner/structured_decoding_manager.py +2 -3
  50. tpu_inference/runner/tpu_runner.py +94 -51
  51. tpu_inference/runner/utils.py +2 -2
  52. tpu_inference/spec_decode/jax/eagle3.py +71 -22
  53. tpu_inference/utils.py +41 -14
  54. tpu_inference/worker/tpu_worker.py +43 -45
  55. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
  56. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
  57. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  58. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  59. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -16,17 +16,30 @@ DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
16
16
  DEFAULT_VMEM_LIMIT_BYTES = 100 * 1024 * 1024
17
17
 
18
18
 
19
+ def get_kv_cache_shape(
20
+ total_num_pages,
21
+ page_size,
22
+ kv_dim,
23
+ kv_dtype,
24
+ ):
25
+ kv_packing = get_dtype_packing(kv_dtype)
26
+ return (
27
+ total_num_pages,
28
+ align_to(page_size, kv_packing) // kv_packing,
29
+ kv_packing,
30
+ align_to(kv_dim, 128),
31
+ )
32
+
33
+
19
34
  @functools.partial(
20
35
  jax.jit,
21
- donate_argnames=("cache_kv_c", "cache_k_pe"),
36
+ donate_argnames=("cache_kv"),
22
37
  )
23
38
  def update_kv_cache(
24
39
  new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
25
40
  new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
26
- cache_kv_c: jax.
27
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
28
- cache_k_pe: jax.
29
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
41
+ cache_kv: jax.
42
+ Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim+r_dim]
30
43
  kv_lens: jax.Array, # i32[max_num_seqs]
31
44
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
32
45
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -43,25 +56,21 @@ def update_kv_cache(
43
56
  if actual_lkv_dim != lkv_dim:
44
57
  new_kv_c = jnp.pad(new_kv_c, ((0, 0), (0, lkv_dim - actual_lkv_dim)),
45
58
  constant_values=0)
46
-
47
- _, page_size_per_kv_packing, kv_packing, cache_lkv_dim = cache_kv_c.shape
48
- _, _, _, cache_r_dim = cache_k_pe.shape
49
- assert lkv_dim == cache_lkv_dim
50
- assert r_dim == cache_r_dim
59
+ kv_dim = r_dim + lkv_dim
60
+ _, page_size_per_kv_packing, kv_packing, cache_kv_dim = cache_kv.shape
61
+ assert kv_dim == cache_kv_dim
51
62
  page_size = page_size_per_kv_packing * kv_packing
52
63
 
53
64
  max_num_seqs = kv_lens.shape[0]
54
65
  num_page_indices = page_indices.shape[0]
55
66
  pages_per_seq = num_page_indices // max_num_seqs
56
67
 
57
- def seq_loop_body(i, caches):
58
- cache_kv_c, cache_k_pe = caches
68
+ def seq_loop_body(i, cache_kv):
59
69
  q_start, q_end = cu_q_lens[i], cu_q_lens[i + 1]
60
70
  q_len = q_end - q_start
61
71
  kv_len = kv_lens[i]
62
72
 
63
- def token_loop_body(j, caches_):
64
- cache_kv_c_, cache_k_pe_ = caches_
73
+ def token_loop_body(j, cache_kv_):
65
74
  token_idx_in_seq = kv_len - q_len + j
66
75
  page_num_in_seq = token_idx_in_seq // page_size
67
76
  page_indices_start = i * pages_per_seq
@@ -69,18 +78,17 @@ def update_kv_cache(
69
78
  row = (token_idx_in_seq % page_size) // kv_packing
70
79
  col = (token_idx_in_seq % page_size) % kv_packing
71
80
 
72
- cache_kv_c_ = cache_kv_c_.at[page_idx, row,
73
- col].set(new_kv_c[q_start + j])
74
- cache_k_pe_ = cache_k_pe_.at[page_idx, row,
75
- col].set(new_k_pe[q_start + j])
76
- return cache_kv_c_, cache_k_pe_
81
+ cache_kv_ = cache_kv_.at[page_idx, row, col,
82
+ ..., :lkv_dim].set(new_kv_c[q_start + j])
83
+ cache_kv_ = cache_kv_.at[page_idx, row, col, ...,
84
+ lkv_dim:].set(new_k_pe[q_start + j])
85
+ return cache_kv_
86
+
87
+ return lax.fori_loop(0, q_len, token_loop_body, cache_kv)
77
88
 
78
- return lax.fori_loop(0, q_len, token_loop_body,
79
- (cache_kv_c, cache_k_pe))
89
+ cache_kv = lax.fori_loop(0, distribution[-1], seq_loop_body, cache_kv)
80
90
 
81
- cache_kv_c, cache_k_pe = lax.fori_loop(0, distribution[-1], seq_loop_body,
82
- (cache_kv_c, cache_k_pe))
83
- return cache_kv_c, cache_k_pe
91
+ return cache_kv
84
92
 
85
93
 
86
94
  def ref_mla_ragged_paged_attention(
@@ -88,10 +96,8 @@ def ref_mla_ragged_paged_attention(
88
96
  q_pe: jax.Array, # [num_tokens, actual_num_q_heads, actual_r_dim]
89
97
  new_kv_c: jax.Array, # [num_tokens, actual_lkv_dim]
90
98
  new_k_pe: jax.Array, # [num_tokens, actual_r_dim]
91
- cache_kv_c: jax.
99
+ cache_kv: jax.
92
100
  Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
93
- cache_k_pe: jax.
94
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
95
101
  kv_lens: jax.Array, # i32[max_num_seqs]
96
102
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
97
103
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -111,8 +117,7 @@ def ref_mla_ragged_paged_attention(
111
117
  q_pe,
112
118
  new_kv_c,
113
119
  new_k_pe,
114
- cache_kv_c,
115
- cache_k_pe,
120
+ cache_kv,
116
121
  kv_lens,
117
122
  page_indices,
118
123
  cu_q_lens,
@@ -123,11 +128,10 @@ def ref_mla_ragged_paged_attention(
123
128
  mask_value=mask_value,
124
129
  )
125
130
 
126
- cache_kv_c, cache_k_pe = update_kv_cache(
131
+ updated_cache_kv = update_kv_cache(
127
132
  new_kv_c,
128
133
  new_k_pe,
129
- cache_kv_c,
130
- cache_k_pe,
134
+ cache_kv,
131
135
  kv_lens,
132
136
  page_indices,
133
137
  cu_q_lens,
@@ -154,13 +158,17 @@ def ref_mla_ragged_paged_attention(
154
158
  assert num_page_indices % max_num_seqs == 0
155
159
  pages_per_seq = num_page_indices // max_num_seqs
156
160
 
157
- total_num_pages, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
161
+ total_num_pages, page_size_per_kv_packing, kv_packing, _ = updated_cache_kv.shape
158
162
  page_size = page_size_per_kv_packing * kv_packing
159
163
  assert lkv_dim == ql_nope.shape[-1]
160
164
  assert r_dim == q_pe.shape[-1]
165
+ assert lkv_dim + r_dim == updated_cache_kv.shape[-1]
161
166
 
162
- kv_c_cache = cache_kv_c.reshape(total_num_pages, page_size, lkv_dim)
163
- k_pe_cache = cache_k_pe.reshape(total_num_pages, page_size, r_dim)
167
+ kv_c_cache = updated_cache_kv[..., :lkv_dim].reshape(
168
+ total_num_pages, page_size, lkv_dim)
169
+ k_pe_cache = updated_cache_kv[...,
170
+ lkv_dim:].reshape(total_num_pages, page_size,
171
+ r_dim)
164
172
 
165
173
  outputs = []
166
174
 
@@ -221,8 +229,7 @@ def ref_mla_ragged_paged_attention(
221
229
 
222
230
  return (
223
231
  jnp.concatenate(outputs, axis=0),
224
- cache_kv_c,
225
- cache_k_pe,
232
+ updated_cache_kv,
226
233
  )
227
234
 
228
235
 
@@ -232,10 +239,8 @@ def dynamic_validate_inputs(
232
239
  q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
233
240
  new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
234
241
  new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
235
- cache_kv_c: jax.
242
+ cache_kv: jax.
236
243
  Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
237
- cache_k_pe: jax.
238
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
239
244
  kv_lens: jax.Array, # i32[max_num_seqs]
240
245
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
241
246
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -260,8 +265,7 @@ def dynamic_validate_inputs(
260
265
  q_pe,
261
266
  new_kv_c,
262
267
  new_k_pe,
263
- cache_kv_c,
264
- cache_k_pe,
268
+ cache_kv,
265
269
  kv_lens,
266
270
  page_indices,
267
271
  cu_q_lens,
@@ -277,8 +281,8 @@ def dynamic_validate_inputs(
277
281
  debug_mode=debug_mode,
278
282
  )
279
283
  max_num_tokens = ql_nope.shape[0]
280
- total_num_pages = cache_kv_c.shape[0]
281
- _, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
284
+ total_num_pages = cache_kv.shape[0]
285
+ _, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
282
286
  page_size = page_size_per_kv_packing * kv_packing
283
287
  max_num_seqs = kv_lens.shape[0]
284
288
  num_page_indices = page_indices.shape[0]
@@ -320,10 +324,8 @@ def static_validate_inputs(
320
324
  q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
321
325
  new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
322
326
  new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
323
- cache_kv_c: jax.
327
+ cache_kv: jax.
324
328
  Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
325
- cache_k_pe: jax.
326
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
327
329
  kv_lens: jax.Array, # i32[max_num_seqs]
328
330
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
329
331
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -373,44 +375,34 @@ def static_validate_inputs(
373
375
 
374
376
  actual_lkv_dim = ql_nope.shape[2]
375
377
  actual_r_dim = q_pe.shape[2]
378
+ lkv_dim = align_to(actual_lkv_dim, 128)
379
+ r_dim = align_to(actual_r_dim, 128)
376
380
 
377
381
  (
378
382
  _,
379
383
  page_size_per_kv_packing,
380
384
  kv_packing,
381
- lkv_dim,
382
- ) = cache_kv_c.shape
383
- _, _, _, r_dim = cache_k_pe.shape
385
+ kv_dim,
386
+ ) = cache_kv.shape
384
387
 
385
- if lkv_dim != align_to(actual_lkv_dim, 128):
386
- raise ValueError(
387
- f"Expected {lkv_dim=} is equal to {align_to(actual_lkv_dim, 128)=}"
388
- )
389
- if r_dim != align_to(actual_r_dim, 128):
388
+ if lkv_dim + r_dim != kv_dim:
390
389
  raise ValueError(
391
- f"Expected {r_dim=} is equal to {align_to(actual_r_dim, 128)=}")
390
+ f"Expected {lkv_dim=} + {r_dim=} to be equal to {kv_dim=}")
392
391
 
393
- if not (cache_kv_c.dtype == new_kv_c.dtype):
392
+ if not (cache_kv.dtype == new_kv_c.dtype):
394
393
  raise ValueError(
395
- f"Expected {cache_kv_c.dtype=} to be equal to {new_kv_c.dtype=}.")
396
- if not (cache_k_pe.dtype == new_k_pe.dtype):
394
+ f"Expected {cache_kv.dtype=} to be equal to {new_kv_c.dtype=}.")
395
+ if not (cache_kv.dtype == new_k_pe.dtype):
397
396
  raise ValueError(
398
- f"Expected {cache_k_pe.dtype=} to be equal to {new_k_pe.dtype=}.")
397
+ f"Expected {cache_kv.dtype=} to be equal to {new_k_pe.dtype=}.")
399
398
 
400
399
  # Integer kv quantization is currently not supported.
401
- if not jnp.issubdtype(cache_kv_c.dtype, jnp.floating):
402
- raise ValueError(
403
- f"Expected {cache_kv_c.dtype=} to be a floating point.")
404
- if not jnp.issubdtype(cache_k_pe.dtype, jnp.floating):
405
- raise ValueError(
406
- f"Expected {cache_k_pe.dtype=} to be a floating point.")
400
+ if not jnp.issubdtype(cache_kv.dtype, jnp.floating):
401
+ raise ValueError(f"Expected {cache_kv.dtype=} to be a floating point.")
407
402
 
408
- if kv_packing != get_dtype_packing(cache_kv_c.dtype):
403
+ if kv_packing != get_dtype_packing(cache_kv.dtype):
409
404
  raise ValueError(
410
- f"{kv_packing=} does not match with {cache_kv_c.dtype=}")
411
- if kv_packing != get_dtype_packing(cache_k_pe.dtype):
412
- raise ValueError(
413
- f"{kv_packing=} does not match with {cache_k_pe.dtype=}")
405
+ f"{kv_packing=} does not match with {cache_kv.dtype=}")
414
406
 
415
407
  if not (jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype
416
408
  == distribution.dtype):
@@ -475,14 +467,12 @@ def _mla_ragged_paged_attention_kernel(
475
467
  q_pe_hbm_ref, # [max_num_tokens, num_q_heads_per_q_packing, q_packing, r_dim]
476
468
  new_kv_c_hbm_ref, # [max_num_tokens_per_kv_packing, kv_packing, lkv_dim]
477
469
  new_k_pe_hbm_ref, # [max_num_tokens_per_kv_packing, kv_packing, r_dim]
478
- cache_kv_c_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
479
- cache_k_pe_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
470
+ cache_kv_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
480
471
  # Output
481
472
  o_hbm_ref, # [max_num_tokens, num_q_heads_per_q_packing, q_packing, lkv_dim]
482
- updated_cache_kv_c_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
483
- updated_cache_k_pe_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
473
+ updated_cache_kv_hbm_ref, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim + r_dim, 128)]
484
474
  # Scratch
485
- bkvc_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim]
475
+ bkvc_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, lkv_dim].
486
476
  bkpe_x2_ref, # [2, bkv_sz_per_kv_packing, kv_packing, r_dim]
487
477
  bq_nope_x2_ref, # [2, bq_sz, num_q_heads_per_q_packing, q_packing, lkv_dim]
488
478
  bq_rope_x2_ref, # [2, bq_sz, num_q_heads_per_q_packing, q_packing, r_dim]
@@ -505,20 +495,24 @@ def _mla_ragged_paged_attention_kernel(
505
495
  debug_mode: bool = False,
506
496
  ):
507
497
  assert ql_nope_hbm_ref.shape == o_hbm_ref.shape
508
- assert ql_nope_hbm_ref.shape[-1] == cache_kv_c_hbm_ref.shape[-1]
509
- assert q_pe_hbm_ref.shape[-1] == cache_k_pe_hbm_ref.shape[-1]
498
+ # Validation checks on the dimensions
499
+ nope_dim = ql_nope_hbm_ref.shape[-1]
500
+ pe_dim = q_pe_hbm_ref.shape[-1]
501
+ assert nope_dim + pe_dim == cache_kv_hbm_ref.shape[-1]
502
+
510
503
  _, num_q_heads_per_q_packing, q_packing, lkv_dim = ql_nope_hbm_ref.shape
511
504
  r_dim = q_pe_hbm_ref.shape[-1]
512
505
  num_q_heads = num_q_heads_per_q_packing * q_packing
513
506
  total_num_pages, page_size_per_kv_packing, kv_packing, _ = (
514
- cache_kv_c_hbm_ref.shape)
507
+ cache_kv_hbm_ref.shape)
515
508
  max_num_seqs = kv_lens_ref.shape[0]
516
509
  num_page_indices = page_indices_ref.shape[0]
517
510
 
518
511
  assert num_page_indices % max_num_seqs == 0
519
512
  pages_per_seq = num_page_indices // max_num_seqs
520
513
  q_dtype = ql_nope_hbm_ref.dtype
521
- kv_dtype = cache_kv_c_hbm_ref.dtype
514
+ # Validate against the KV dtype.
515
+ kv_dtype = cache_kv_hbm_ref.dtype
522
516
  assert q_pe_hbm_ref.dtype == q_dtype
523
517
  assert o_hbm_ref.dtype == q_dtype
524
518
  assert get_dtype_packing(q_dtype) == q_packing
@@ -561,8 +555,8 @@ def _mla_ragged_paged_attention_kernel(
561
555
  def flash_attention(
562
556
  ql_nope, # [actual_bq_sz * num_q_heads, lkv_dim]
563
557
  q_pe, # [actual_bq_sz * num_q_heads, r_dim]
564
- kv_c, # [bkv_sz, lkv_dim]
565
- k_pe, # [bkv_sz, r_dim]
558
+ kv_c, # [bkv_sz, lkv_dim] <- Correspond to data from bkvc_x2_ref
559
+ k_pe, # [bkv_sz, r_dim] <- Correspond to data from bpe_x2_ref
566
560
  *,
567
561
  bq_idx,
568
562
  bkv_idx,
@@ -649,14 +643,9 @@ def _mla_ragged_paged_attention_kernel(
649
643
  sem = sems.at[0, bkv_sem_idx]
650
644
  bkvc_vmem_ref = bkvc_x2_ref.at[bkv_sem_idx]
651
645
  bkvpe_vmem_ref = bkpe_x2_ref.at[bkv_sem_idx]
652
-
653
- reshaped_cache_kv_c_hbm_ref = cache_kv_c_hbm_ref.reshape(
646
+ reshaped_cache_hbm_ref = cache_kv_hbm_ref.reshape(
654
647
  total_num_pages * page_size_per_kv_packing,
655
- *cache_kv_c_hbm_ref.shape[2:],
656
- )
657
- reshaped_cache_k_pe_hbm_ref = cache_k_pe_hbm_ref.reshape(
658
- total_num_pages * page_size_per_kv_packing,
659
- *cache_k_pe_hbm_ref.shape[2:],
648
+ *cache_kv_hbm_ref.shape[2:],
660
649
  )
661
650
  kv_len = kv_lens_ref[seq_idx]
662
651
  kv_len_start = bkv_idx * bkv_sz
@@ -684,22 +673,22 @@ def _mla_ragged_paged_attention_kernel(
684
673
  kv_left_per_kv_packing - i * page_size_per_kv_packing,
685
674
  )
686
675
  _async_copy(
687
- reshaped_cache_kv_c_hbm_ref.at[pl.ds(
676
+ reshaped_cache_hbm_ref.at[pl.ds(
688
677
  page_indices_ref[page_indices_offset + i] *
689
678
  page_size_per_kv_packing,
690
679
  sz_per_kv_packing,
691
- )],
680
+ ), ..., :nope_dim],
692
681
  bkvc_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
693
682
  sz_per_kv_packing)],
694
683
  sem,
695
684
  wait,
696
685
  )
697
686
  _async_copy(
698
- reshaped_cache_k_pe_hbm_ref.at[pl.ds(
687
+ reshaped_cache_hbm_ref.at[pl.ds(
699
688
  page_indices_ref[page_indices_offset + i] *
700
689
  page_size_per_kv_packing,
701
690
  sz_per_kv_packing,
702
- )],
691
+ ), ..., nope_dim:],
703
692
  bkvpe_vmem_ref.at[pl.ds(i * page_size_per_kv_packing,
704
693
  sz_per_kv_packing)],
705
694
  sem,
@@ -835,7 +824,6 @@ def _mla_ragged_paged_attention_kernel(
835
824
  jnp.zeros_like(concated_bkvc_vec))
836
825
  concated_bkvc_vec = pltpu.bitcast(concated_bkvc_vec.astype(repack_ty),
837
826
  kv_dtype)
838
-
839
827
  bkpe_ref = (bkpe_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(
840
828
  bkv_sz_per_kv_packing, r_dim))
841
829
  bkpe_vec = bkpe_ref[...]
@@ -1082,17 +1070,16 @@ def prepare_outputs(
1082
1070
  "vmem_limit_bytes",
1083
1071
  "debug_mode",
1084
1072
  ),
1085
- donate_argnames=("cache_kv_c", "cache_k_pe"),
1073
+ donate_argnames=("cache_kv"),
1086
1074
  )
1087
1075
  def mla_ragged_paged_attention(
1088
1076
  ql_nope: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
1089
1077
  q_pe: jax.Array, # [max_num_tokens, actual_num_q_heads, actual_r_dim]
1090
1078
  new_kv_c: jax.Array, # [max_num_tokens, actual_lkv_dim]
1091
1079
  new_k_pe: jax.Array, # [max_num_tokens, actual_r_dim]
1092
- cache_kv_c: jax.
1093
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, lkv_dim]
1094
- cache_k_pe: jax.
1095
- Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, r_dim]
1080
+ # TODO(gpolovets): Explore separating out into lkv & pe KV caches.
1081
+ cache_kv: jax.
1082
+ Array, # [total_num_pages, page_size_per_kv_packing, kv_packing, align_to(lkv_dim, 128)]
1096
1083
  kv_lens: jax.Array, # i32[max_num_seqs]
1097
1084
  page_indices: jax.Array, # i32[max_num_seqs * pages_per_seq]
1098
1085
  cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
@@ -1124,8 +1111,7 @@ def mla_ragged_paged_attention(
1124
1111
  q_pe: concatenated all sequences' rope.
1125
1112
  new_kv_c: concatenated all sequences' kv_c values
1126
1113
  new_k_pe: concatenated all sequences' k_pe values
1127
- cache_kv_c: the current kv_c cache.
1128
- cache_k_pe: the current k_pe cache.
1114
+ cache_kv: the current kv cache.
1129
1115
  kv_lens: the length of each sequence in the kv cache.
1130
1116
  page_indices: flattened page indices look-up table by (seq_id, page_id).
1131
1117
  cu_q_lens: the cumulative sum of the effective query lengths. Similar to
@@ -1159,8 +1145,7 @@ def mla_ragged_paged_attention(
1159
1145
  q_pe,
1160
1146
  new_kv_c,
1161
1147
  new_k_pe,
1162
- cache_kv_c,
1163
- cache_k_pe,
1148
+ cache_kv,
1164
1149
  kv_lens,
1165
1150
  page_indices,
1166
1151
  cu_q_lens,
@@ -1177,11 +1162,10 @@ def mla_ragged_paged_attention(
1177
1162
  )
1178
1163
 
1179
1164
  # TODO(chengjiyao): fuse kv cache update into the kernel.
1180
- cache_kv_c, cache_k_pe = update_kv_cache(
1165
+ cache_kv = update_kv_cache(
1181
1166
  new_kv_c,
1182
1167
  new_k_pe,
1183
- cache_kv_c,
1184
- cache_k_pe,
1168
+ cache_kv,
1185
1169
  kv_lens,
1186
1170
  page_indices,
1187
1171
  cu_q_lens,
@@ -1202,7 +1186,7 @@ def mla_ragged_paged_attention(
1202
1186
  lkv_dim = new_kv_c.shape[-1]
1203
1187
  r_dim = new_k_pe.shape[-1]
1204
1188
 
1205
- _, page_size_per_kv_packing, kv_packing, _ = cache_kv_c.shape
1189
+ _, page_size_per_kv_packing, kv_packing, _ = cache_kv.shape
1206
1190
  page_size = page_size_per_kv_packing * kv_packing
1207
1191
  _, num_q_heads_per_q_packing, q_packing, _ = ql_nope.shape
1208
1192
  max_num_seqs = kv_lens.shape[0]
@@ -1221,23 +1205,21 @@ def mla_ragged_paged_attention(
1221
1205
  pl.BlockSpec(memory_space=pltpu.HBM),
1222
1206
  pl.BlockSpec(memory_space=pltpu.HBM),
1223
1207
  pl.BlockSpec(memory_space=pltpu.HBM),
1224
- pl.BlockSpec(memory_space=pltpu.HBM),
1225
1208
  ]
1226
1209
 
1227
1210
  out_specs = [
1228
1211
  pl.BlockSpec(memory_space=pltpu.HBM),
1229
1212
  pl.BlockSpec(memory_space=pltpu.HBM),
1230
- pl.BlockSpec(memory_space=pltpu.HBM),
1231
1213
  ]
1232
1214
 
1233
1215
  bkvc_double_buf = pltpu.VMEM(
1234
1216
  (2, bkv_sz_per_kv_packing, kv_packing, lkv_dim),
1235
- cache_kv_c.dtype,
1217
+ cache_kv.dtype,
1236
1218
  )
1237
1219
 
1238
1220
  bkpe_double_buf = pltpu.VMEM(
1239
1221
  (2, bkv_sz_per_kv_packing, kv_packing, r_dim),
1240
- cache_k_pe.dtype,
1222
+ cache_kv.dtype,
1241
1223
  )
1242
1224
 
1243
1225
  bq_nope_double_buf = pltpu.VMEM(
@@ -1320,30 +1302,26 @@ def mla_ragged_paged_attention(
1320
1302
  ),
1321
1303
  out_shape=[
1322
1304
  jax.ShapeDtypeStruct(shape=ql_nope.shape, dtype=ql_nope.dtype),
1323
- jax.ShapeDtypeStruct(shape=cache_kv_c.shape,
1324
- dtype=cache_kv_c.dtype),
1325
- jax.ShapeDtypeStruct(shape=cache_k_pe.shape,
1326
- dtype=cache_k_pe.dtype),
1305
+ jax.ShapeDtypeStruct(shape=cache_kv.shape,
1306
+ dtype=cache_kv.dtype),
1327
1307
  ],
1328
1308
  input_output_aliases={
1329
1309
  7: 0,
1330
1310
  11: 1,
1331
- 12: 2
1332
1311
  },
1333
1312
  name=scope_name,
1334
1313
  ))
1335
1314
 
1336
- output, updated_kv_c, updated_k_pe = kernel(
1315
+ output, updated_kv = kernel(
1337
1316
  *scalar_prefetches,
1338
1317
  ql_nope,
1339
1318
  q_pe,
1340
1319
  new_kv_c,
1341
1320
  new_k_pe,
1342
- cache_kv_c,
1343
- cache_k_pe,
1321
+ cache_kv,
1344
1322
  )
1345
1323
  output = prepare_outputs(
1346
1324
  output, actual_num_q_heads,
1347
1325
  actual_lkv_dim) # [max_num_tokens, actual_num_q_heads, actual_lkv_dim]
1348
1326
 
1349
- return output, updated_kv_c, updated_k_pe
1327
+ return output, updated_kv
@@ -9,12 +9,58 @@ from jax._src import dtypes
9
9
  from jax.experimental import pallas as pl
10
10
  from jax.experimental.pallas import tpu as pltpu
11
11
 
12
+ from tpu_inference.kernels.quantized_matmul import util
12
13
  from tpu_inference.kernels.quantized_matmul.tuned_block_sizes import (
13
14
  TunedValue, get_device_vmem_limit, get_tuned_block_sizes)
14
15
  from tpu_inference.kernels.quantized_matmul.util import (get_kernel_name,
15
16
  next_multiple,
16
17
  unfold_args)
17
18
 
19
+ quantize_tensor = util.quantize_tensor
20
+
21
+
22
+ def xla_quantized_matmul(
23
+ x: jax.Array,
24
+ w_q: jax.Array,
25
+ w_scale: jax.Array,
26
+ quantize_activation=True,
27
+ ) -> jax.Array:
28
+ """
29
+ Reference (pure JAX) implementation of the quantized matmul kernel below.
30
+
31
+ Args:
32
+ x: Activation.
33
+ w_q: Weight quantized array. [n_output_features, n_input_features]
34
+ w_s: Weight quantization scale. [n_output_features]
35
+ mesh: Mesh to shard on.
36
+ weight_sharding: PartitionSpec for the weight tensor.
37
+
38
+ Returns:
39
+ Output of the quantized matmul.
40
+ """
41
+ if quantize_activation:
42
+ acc_dtype = jnp.float32
43
+ if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
44
+ acc_dtype = jnp.int32
45
+
46
+ x_q, x_scale = quantize_tensor(x, w_q.dtype)
47
+ out = jax.lax.dot_general(
48
+ x_q,
49
+ w_q,
50
+ dimension_numbers=(((1, ), (1, )), ((), ())),
51
+ preferred_element_type=acc_dtype,
52
+ ).astype(jnp.float32)
53
+ out *= x_scale
54
+ else:
55
+ out = jax.lax.dot_general(
56
+ x,
57
+ w_q,
58
+ dimension_numbers=(((1, ), (1, )), ((), ())),
59
+ preferred_element_type=jnp.float32,
60
+ )
61
+ out *= jnp.expand_dims(w_scale, 0)
62
+ return out.astype(x.dtype)
63
+
18
64
 
19
65
  def quantize_array(
20
66
  x: jax.Array, # [bs_block_size, in_block_size]
@@ -50,11 +96,20 @@ def get_vmem_limit(
50
96
  """Calculate VMEM limit for the kernel."""
51
97
 
52
98
  # Calculate in/out VMEM size.
53
- x_size = batch_block_size * in_block_size * dtypes.bit_width(x_dtype)
54
- x_abs_max_size = batch_block_size * dtypes.bit_width(scale_dtype)
55
- w_q_size = out_block_size * in_block_size * dtypes.bit_width(w_q_dtype)
56
- w_scale_size = out_block_size * dtypes.bit_width(scale_dtype)
57
- out_size = batch_block_size * out_block_size * dtypes.bit_width(out_dtype)
99
+ x_size = (batch_block_size *
100
+ in_block_size * (dtypes.bit_width(x_dtype) if hasattr(
101
+ dtypes, "bit_width") else dtypes.itemsize_bits(x_dtype)))
102
+ x_abs_max_size = (
103
+ batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
104
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
105
+ w_q_size = (out_block_size *
106
+ in_block_size * (dtypes.bit_width(w_q_dtype) if hasattr(
107
+ dtypes, "bit_width") else dtypes.itemsize_bits(w_q_dtype)))
108
+ w_scale_size = (out_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
109
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
110
+ out_size = (batch_block_size *
111
+ out_block_size * (dtypes.bit_width(out_dtype) if hasattr(
112
+ dtypes, "bit_width") else dtypes.itemsize_bits(out_dtype)))
58
113
 
59
114
  vmem_in_out = x_size + x_abs_max_size + w_q_size + w_scale_size + out_size
60
115
  vmem_in_out *= 2 # Account for compute and vreg spills.
@@ -68,9 +123,15 @@ def get_vmem_limit(
68
123
  vmem_in_out += out_size if (n_batch > 1 or n_out > 1) else 0
69
124
 
70
125
  # Calculate scratch VMEM size.
71
- acc_size = batch_block_size * out_block_size * dtypes.bit_width(acc_dtype)
72
- x_q_size = batch_block_size * in_block_size * dtypes.bit_width(x_q_dtype)
73
- x_scale_size = batch_block_size * dtypes.bit_width(scale_dtype)
126
+ acc_size = (batch_block_size *
127
+ out_block_size * (dtypes.bit_width(acc_dtype) if hasattr(
128
+ dtypes, "bit_width") else dtypes.itemsize_bits(acc_dtype)))
129
+ x_q_size = (batch_block_size *
130
+ in_block_size * (dtypes.bit_width(x_q_dtype) if hasattr(
131
+ dtypes, "bit_width") else dtypes.itemsize_bits(x_q_dtype)))
132
+ x_scale_size = (
133
+ batch_block_size * (dtypes.bit_width(scale_dtype) if hasattr(
134
+ dtypes, "bit_width") else dtypes.itemsize_bits(scale_dtype)))
74
135
 
75
136
  vmem_scratch = acc_size if save_acc else 0
76
137
  vmem_scratch += x_q_size + x_scale_size if save_x_q else 0
@@ -655,7 +655,8 @@ def cdiv(a, b):
655
655
 
656
656
 
657
657
  def get_dtype_packing(dtype):
658
- bits = dtypes.bit_width(dtype)
658
+ bits = (dtypes.bit_width(dtype)
659
+ if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
659
660
  return 32 // bits
660
661
 
661
662
 
@@ -200,7 +200,8 @@ def _prev_power_of_2(n: int) -> int:
200
200
  def _get_page_size_bytes(block_size: int, num_combined_kv_heads: int,
201
201
  head_size: int, kv_cache_dtype) -> int:
202
202
  """Returns the size in bytes of one page of the KV cache."""
203
- kv_cache_dtype_bit_size = dtypes.bit_width(kv_cache_dtype)
203
+ kv_cache_dtype_bit_size = (dtypes.bit_width(kv_cache_dtype) if hasattr(
204
+ dtypes, "bit_width") else dtypes.itemsize_bits(kv_cache_dtype))
204
205
  padded_head_size = _ceil_div(
205
206
  head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
206
207