tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tests/kernels/fused_moe_v1_test.py

@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 from jax._src import test_util as jtu
 from jax.sharding import Mesh
 
@@ -10,6 +10,15 @@ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
 jax.config.parse_flags_with_absl()
 
 
+def cdiv(a, b):
+    assert b != 0
+    return (a + b - 1) // b
+
+
+def align_to(x, a):
+    return cdiv(x, a) * a
+
+
 def gen_moe_inputs(
     dtype,
     top_k,
@@ -19,11 +28,14 @@ def gen_moe_inputs(
     num_tokens,
     *,
     seed=1234,
+    has_bias=False,
 ):
     key = jax.random.key(seed)
-    k0, k1, k2, k4, k5 = jax.random.split(key, 5)
+    k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
+
     a = jax.random.normal(k0, (num_tokens, hidden_size),
                           dtype=jnp.float32).astype(dtype) / 10
+
     w1 = (jax.random.normal(
         k1,
         (num_experts, 2, hidden_size, intermediate_size),
@@ -31,21 +43,54 @@ def gen_moe_inputs(
     ) / 10).astype(dtype)
     w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                             dtype=jnp.float32) / 10).astype(dtype)
+
+    if has_bias:
+        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+    else:
+        b1 = b2 = None
+
     gating_output = (
-        jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32) +
+        jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
         jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
             num_tokens, num_experts) / 100)
+
     # To generate unique top-k!
-    top_k_indices = jax.random.randint(k5, (num_tokens, top_k),
+    top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
                                        minval=0,
                                        maxval=num_experts - 1,
                                        dtype=jnp.int32)
+
     one_hot = (jnp.sum(
         jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
         axis=1,
-    ) * 10)
+    ) * 30)
+
     gating_output = (gating_output + one_hot).astype(dtype)
-    return a, w1, w2, gating_output
+
+    return a, w1, w2, b1, b2, gating_output
+
+
+def sub_channel_quantize(x, quant_dtype, wsz=256):
+    """Quantizes x with sub-channel quantization on the 2nd minor."""
+    if jnp.issubdtype(quant_dtype, jnp.floating):
+        dtype_info = jnp.finfo(quant_dtype)
+    else:
+        dtype_info = jnp.iinfo(quant_dtype)
+    dtype_max = float(dtype_info.max)
+    w_lst, scale_lst = [], []
+    assert len(x.shape) >= 2
+    assert x.shape[-2] % wsz == 0
+    for i in range(0, x.shape[-2], wsz):
+        y = x[..., i:i + wsz, :]
+        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
+        scale = (abs_max / dtype_max).astype(jnp.float32)
+        w = (y / scale).astype(quant_dtype)
+        w_lst.append(w)
+        scale_lst.append(scale)
+    return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
 
 
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
@@ -63,42 +108,266 @@ class MoEKernelTest(jtu.JaxTestCase):
         self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
                          axis_names=("data", "model"))
 
-    def test_basic(self):
-        dtype = jnp.bfloat16
-        top_k = 2
-        num_experts = 16
-        hidden_size = 256
-        intermediate_size = 256
-        num_tokens = 8 * 2
-
-        a, w1, w2, gating_output = gen_moe_inputs(
+    def _test_moe(
+        self,
+        dtype,
+        top_k,
+        num_experts,
+        hidden_size,
+        intermediate_size,
+        num_tokens,
+        seed,
+        renormalize_topk_logits,
+        bt,
+        bf,
+        bd1,
+        bd2,
+        btc,
+        bfc,
+        bd1c,
+        bd2c,
+        act_fn="silu",
+        w_dtype=None,
+        subc_quant_wsz=None,
+        has_bias=False,
+        atol=2e-1,
+        rtol=2e-1,
+    ):
+        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
            dtype,
            top_k,
            num_experts,
            hidden_size,
            intermediate_size,
            num_tokens,
+            seed=seed,
+            has_bias=has_bias,
        )
+        w1_scale = None
+        w2_scale = None
+        if w_dtype is not None:
+            if subc_quant_wsz is None:
+                subc_quant_wsz = 256
+            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
+            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
 
-        actual = jax.block_until_ready(
-            fused_ep_moe(
-                mesh=self.mesh,
-                tokens=a,
-                w1=w1,
-                w2=w2,
-                gating_output=gating_output,
-                top_k=top_k,
-                bt=32,
-                bf=512,
-                bd1=512,
-                bd2=512,
-                btc=32,
-                bfc=256,
-                bd1c=256,
-                bd2c=256,
-            ))
-        expected = ref_moe(a, w1, w2, gating_output, top_k)
-        self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
+        actual = fused_ep_moe(
+            mesh=self.mesh,
+            tokens=a,
+            w1=w1,
+            w2=w2,
+            gating_output=gating_output,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            b1=b1,
+            b2=b2,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        )
+        expected = ref_moe(
+            a,
+            w1,
+            w2,
+            gating_output,
+            top_k,
+            b1=b1,
+            b2=b2,
+            renormalize_topk_logits=renormalize_topk_logits,
+            activation=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+        )
+        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
+
+    @parameterized.product(renormalize_topk_logits=[True, False], )
+    def test_basic(self, renormalize_topk_logits):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=32,
+            bf=1024,
+            bd1=1024,
+            bd2=1024,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
+    def test_activation(self, act_fn):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=True,
+            act_fn=act_fn,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    def test_benchmark_qwen_235(self):
+        num_experts = 128
+        top_k = 8
+        hidden_size = 4096
+        intermediate_size = 1536
+        dtype = jnp.bfloat16
+        num_tokens = 8 * 64
+        seed = 54321
+        renormalize_topk_logits = True
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=seed,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=64,
+            bf=768,
+            bd1=2048,
+            bd2=2048,
+            btc=64,
+            bfc=768,
+            bd1c=2048,
+            bd2c=2048,
+            act_fn="silu",
+            atol=5e-2,
+            rtol=5e-2,
+        )
+
+    def test_benchmark_qwen_30b_a3b(self):
+        num_experts = 128
+        top_k = 8
+        hidden_size = 2048
+        intermediate_size = 768
+        dtype = jnp.bfloat16
+        num_tokens = 512
+        seed = 54321
+        renormalize_topk_logits = True
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=seed,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=16,
+            bf=384,
+            bd1=512,
+            bd2=512,
+            btc=16,
+            bfc=384,
+            bd1c=256,
+            bd2c=256,
+            act_fn="silu",
+            atol=5e-2,
+            rtol=5e-2,
+        )
+
+    @parameterized.product(
+        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
+    def test_sub_channel_quantization(self, w_dtype):
+        if w_dtype in (
+                jnp.float8_e5m2,
+                jnp.float4_e2m1fn,
+        ) and not jtu.is_device_tpu_at_least(version=7):
+            self.skipTest("Expect TPUv7+")
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            w_dtype=w_dtype,
+            subc_quant_wsz=256,
+            bt=32,
+            bf=1024,
+            bd1=1024,
+            bd2=1024,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    def test_bias(self):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            has_bias=True,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
 
 
 if __name__ == "__main__":
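
Editor's note on the new sub_channel_quantize helper above: the sketch below is a minimal matching dequantizer, not part of the diff. The name sub_channel_dequantize is introduced here purely for illustration, and it assumes the same group size wsz on the second-to-last axis and the (groups, N)-shaped float32 scales that the helper returns.

import jax.numpy as jnp

def sub_channel_dequantize(w_q, scale, wsz=256):
    # w_q: (..., K, N) quantized values; scale: (..., K // wsz, N) float32
    # per-group scales, as produced by sub_channel_quantize above.
    groups = []
    for g, i in enumerate(range(0, w_q.shape[-2], wsz)):
        block = w_q[..., i:i + wsz, :].astype(jnp.float32)
        groups.append(block * scale[..., g:g + 1, :])
    return jnp.concatenate(groups, axis=-2)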
tests/kernels/mla_v1_test.py

@@ -42,6 +42,7 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
 
         padded_r_dim = align_to(r_dim, 128)
         padded_lkv_dim = align_to(lkv_dim, 128)
+        padded_kv_dim = padded_lkv_dim + padded_r_dim
         packing = get_dtype_packing(kv_dtype)
         q_lens = [s[0] for s in seq_lens]
         kv_lens_list = [s[1] for s in seq_lens]
@@ -69,13 +70,10 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
         new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)
 
-        cache_kv_c = gen_random(
-            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
+        cache_kv = gen_random(
+            (total_num_pages, page_size // packing, packing, padded_kv_dim),
             kv_dtype,
         )
-        cache_k_pe = gen_random(
-            (total_num_pages, page_size // packing, packing, padded_r_dim),
-            kv_dtype)
         kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
         page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
         cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
@@ -84,14 +82,13 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         ql_nope_for_kernel = ql_nope.copy()
         q_pe_for_kernel = q_pe.copy()
 
-        expected_out, expected_updated_kv_c, expeceted_updated_k_pe = (
+        expected_out, expected_updated_kv = (
             mla.ref_mla_ragged_paged_attention(
                 ql_nope,
                 q_pe,
                 new_kv_c,
                 new_k_pe,
-                cache_kv_c.copy(),
-                cache_k_pe.copy(),
+                cache_kv.copy(),
                 kv_lens,
                 page_indices,
                 cu_q_lens,
@@ -101,50 +98,141 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
                 soft_cap=soft_cap,
             ))
 
-        kernel_out, kernel_updated_kv_c, kernel_updated_k_pe = (
-            mla.mla_ragged_paged_attention(
-                ql_nope_for_kernel,
-                q_pe_for_kernel,
-                new_kv_c,
-                new_k_pe,
-                cache_kv_c.copy(),
-                cache_k_pe.copy(),
-                kv_lens,
-                page_indices,
-                cu_q_lens,
-                distribution,
-                sm_scale=sm_scale,
-                sliding_window=sliding_window,
-                soft_cap=soft_cap,
-                num_kv_pages_per_block=num_kv_pages_per_block,
-                num_queries_per_block=num_queries_per_block,
-                vmem_limit_bytes=vmem_limit_bytes,
-            ))
+        kernel_out, kernel_updated_kv = (mla.mla_ragged_paged_attention(
+            ql_nope_for_kernel,
+            q_pe_for_kernel,
+            new_kv_c,
+            new_k_pe,
+            cache_kv.copy(),
+            kv_lens,
+            page_indices,
+            cu_q_lens,
+            distribution,
+            sm_scale=sm_scale,
+            sliding_window=sliding_window,
+            soft_cap=soft_cap,
+            num_kv_pages_per_block=num_kv_pages_per_block,
+            num_queries_per_block=num_queries_per_block,
+            vmem_limit_bytes=vmem_limit_bytes,
+        ))
 
         self.assertEqual(expected_out.shape,
                          (total_q_len, num_heads, padded_lkv_dim))
         self.assertEqual(
-            expected_updated_kv_c.shape,
-            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
-        )
-        self.assertEqual(
-            expeceted_updated_k_pe.shape,
-            (total_num_pages, page_size // packing, packing, padded_r_dim),
+            expected_updated_kv.shape,
+            (total_num_pages, page_size // packing, packing, padded_kv_dim),
         )
         self.assertEqual(expected_out.dtype, kv_dtype)
-        self.assertEqual(expected_updated_kv_c.dtype, kv_dtype)
-        self.assertEqual(expeceted_updated_k_pe.dtype, kv_dtype)
+        self.assertEqual(expected_updated_kv.dtype, kv_dtype)
 
         self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2)
-        self.assertAllClose(expected_updated_kv_c,
-                            kernel_updated_kv_c,
-                            atol=0.2,
-                            rtol=0.2)
-        self.assertAllClose(expeceted_updated_k_pe,
-                            kernel_updated_k_pe,
+        self.assertAllClose(expected_updated_kv,
+                            kernel_updated_kv,
                             atol=0.2,
                             rtol=0.2)
 
+    def test_update_kv_cache(self):
+        lkv_dim = 4
+        r_dim = 4
+        padded_lkv_dim = align_to(lkv_dim, 128)
+        padded_r_dim = align_to(r_dim, 128)
+        kv_dtype = jnp.bfloat16
+        new_kv_c = jnp.arange(16, dtype=kv_dtype).reshape((4, lkv_dim))
+        new_k_pe = (jnp.arange(16, dtype=kv_dtype).reshape((4, r_dim)) + 100)
+        total_num_pages = 2
+        page_size = 4
+        cache_kv_shape = mla.get_kv_cache_shape(
+            total_num_pages,
+            page_size,
+            padded_lkv_dim + padded_r_dim,
+            kv_dtype,
+        )
+        cache_kv = jnp.zeros(cache_kv_shape, dtype=kv_dtype)
+
+        # two sequences, first with 3 tokens, second with 1 token
+        kv_lens = jnp.array([3, 1], dtype=jnp.int32)
+        # first seq uses page 0, second uses page 1
+        page_indices = jnp.array([0, -1, 1, -1], dtype=jnp.int32)
+        # three tokens for first seq, one for second
+        cu_q_lens = jnp.array([0, 3, 4], dtype=jnp.int32)
+        distribution = jnp.array([0, 0, 2], dtype=jnp.int32)
+
+        # manually compute the expected cache
+        padded_new_kv_c = jnp.pad(new_kv_c,
+                                  ((0, 0), (0, padded_lkv_dim - lkv_dim)),
+                                  constant_values=0)
+        padded_new_k_pe = jnp.pad(new_k_pe,
+                                  ((0, 0), (0, padded_r_dim - r_dim)),
+                                  constant_values=0)
+
+        expected_cache = cache_kv
+        # First sequence
+        # token 0
+        page_idx, row, col = 0, 0, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[0])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[0])
+        # token 1
+        page_idx, row, col = 0, 0, 1
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[1])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[1])
+        # token 2
+        page_idx, row, col = 0, 1, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[2])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[2])
+
+        # Second sequence
+        # token 0
+        page_idx, row, col = 1, 0, 0
+        expected_cache = expected_cache.at[page_idx, row,
+                                           col, :padded_lkv_dim].set(
+                                               padded_new_kv_c[3])
+        expected_cache = expected_cache.at[page_idx, row, col,
+                                           padded_lkv_dim:padded_lkv_dim +
+                                           padded_r_dim].set(
+                                               padded_new_k_pe[3])
+
+        updated_cache = mla.update_kv_cache(
+            new_kv_c,
+            new_k_pe,
+            cache_kv,
+            kv_lens,
+            page_indices,
+            cu_q_lens,
+            distribution,
+        )
+
+        self.assertAllClose(updated_cache, expected_cache)
+
+    def test_get_kv_cache_shape(self):
+        total_num_pages = 10
+        page_size = 16
+        lkv_dim = 128
+        kv_dtype = jnp.bfloat16
+        # The calculation for the expected shape is as follows:
+        # kv_packing is determined by the dtype, which is 2 for bfloat16.
+        # The second dimension is page_size / kv_packing = 16 / 2 = 8
+        # The third dimension is kv_packing = 2
+        # The fourth dimension is lkv_dim aligned to 128, which is 128
+        expected_shape = (10, 8, 2, 128)
+        self.assertEqual(
+            mla.get_kv_cache_shape(total_num_pages, page_size, lkv_dim,
+                                   kv_dtype), expected_shape)
+
     def test_ragged_paged_attention_basic(self):
         dtype = jnp.bfloat16
         seq_lens = [(192, 328), (128, 180), (64, 255)]
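
Editor's note on test_get_kv_cache_shape above: a worked sketch of the shape arithmetic its comments spell out. expected_kv_cache_shape is a hypothetical helper introduced here, and the 32-bit packing formula is an assumption that reproduces the "2 for bfloat16" value quoted in the test; it is not the kernel's actual implementation.

import jax.numpy as jnp

def expected_kv_cache_shape(total_num_pages, page_size, kv_dim, kv_dtype):
    # Assumed packing: 32 bits divided by the element width (2 for bfloat16).
    packing = 32 // (jnp.dtype(kv_dtype).itemsize * 8)
    aligned_kv_dim = (kv_dim + 127) // 128 * 128  # align the last dim to 128
    return (total_num_pages, page_size // packing, packing, aligned_kv_dim)

# Reproduces the value asserted in the test:
assert expected_kv_cache_shape(10, 16, 128, jnp.bfloat16) == (10, 8, 2, 128)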
tests/kernels/quantized_matmul_kernel_test.py

@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import functools
-
 import jax
 import jax.numpy as jnp
 from absl.testing import absltest, parameterized
@@ -10,6 +8,7 @@ from jax._src import test_util as jtu
 from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
                                                     util)
 
+xla_quantized_matmul = kernel.xla_quantized_matmul
 quantized_matmul_kernel = kernel.quantized_matmul_kernel
 quantize_tensor = util.quantize_tensor
 get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
@@ -17,37 +16,6 @@ get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
 jax.config.parse_flags_with_absl()
 
 
-@functools.partial(jax.jit, static_argnames=["quantize_activation"])
-def reference_quantized_matmul(
-    x: jax.Array,
-    w_q: jax.Array,
-    w_scale: jax.Array,
-    quantize_activation=True,
-):
-    if quantize_activation:
-        acc_dtype = jnp.float32
-        if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
-            acc_dtype = jnp.int32
-
-        x_q, x_scale = quantize_tensor(x, w_q.dtype)
-        out = jax.lax.dot_general(
-            x_q,
-            w_q,
-            dimension_numbers=(((1, ), (1, )), ((), ())),
-            preferred_element_type=acc_dtype,
-        ).astype(jnp.float32)
-        out *= x_scale
-    else:
-        out = jax.lax.dot_general(
-            x,
-            w_q,
-            dimension_numbers=(((1, ), (1, )), ((), ())),
-            preferred_element_type=jnp.float32,
-        )
-    out *= jnp.expand_dims(w_scale, 0)
-    return out.astype(x.dtype)
-
-
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class QuantizedMatmulKernelTest(jtu.JaxTestCase):
 
@@ -94,7 +62,7 @@ class QuantizedMatmulKernelTest(jtu.JaxTestCase):
            x_q_dtype=x_q_dtype,
            tuned_value=tuned_value,
        )
-        expected = reference_quantized_matmul(
+        expected = xla_quantized_matmul(
            x, w_q, w_scale, quantize_activation=quantize_activation)
 
        self.assertAllClose(output,
tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py

@@ -176,7 +176,9 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]
 
-        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(
+                jnp.dtype(kv_dtype)))
         tols = {
             32: 0.15,
             16: 0.2,
tests/kernels/ragged_paged_attention_kernel_v3_test.py

@@ -162,7 +162,9 @@ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]
 
-        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(
+                jnp.dtype(kv_dtype)))
         tols = {
             32: 0.15,
             16: 0.2,
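
Editor's note on the last two hunks: the version guard they add could be factored into a small helper, sketched below. kv_dtype_bits is a name introduced here for illustration only; dtypes.bit_width and dtypes.itemsize_bits are the private jax._src accessors already used in the diff, and which one exists depends on the installed JAX version.

import jax.numpy as jnp
from jax._src import dtypes

def kv_dtype_bits(kv_dtype):
    # Mirrors the hasattr fallback used in both tests above.
    d = jnp.dtype(kv_dtype)
    if hasattr(dtypes, "bit_width"):
        return dtypes.bit_width(d)
    return dtypes.itemsize_bits(d)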