tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (179) hide show
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,49 @@
1
+ import importlib
2
+ import unittest
3
+ from unittest.mock import patch
4
+
5
+
6
class TestPathwaysInit(unittest.TestCase):
    """Checks the ``VLLM_TPU_USING_PATHWAYS`` flag exposed by ``vllm.envs``.

    Every test patches ``JAX_PLATFORMS`` in ``os.environ``, reloads
    ``vllm.envs`` and then asserts on the value of the flag.
    """

    @staticmethod
    def _reload_vllm_envs():
        """Import ``vllm.envs``, reload it, and return the module.

        The import is kept local so that it happens while the
        ``patch.dict`` decorator of the calling test is active.
        """
        import vllm.envs as vllm_envs

        return importlib.reload(vllm_envs)

    @patch.dict("os.environ", {"JAX_PLATFORMS": "proxy,cpu"})
    def test_VLLM_TPU_USING_PATHWAYS_enabled(self):
        """Test when JAX_PLATFORMS contains 'proxy'."""
        reloaded_envs = self._reload_vllm_envs()

        # "proxy" is listed in JAX_PLATFORMS, so the flag must be truthy.
        self.assertTrue(reloaded_envs.VLLM_TPU_USING_PATHWAYS)

    @patch.dict("os.environ", {"JAX_PLATFORMS": "cpu"})
    def test_VLLM_TPU_USING_PATHWAYS_not_enabled(self):
        """Test when JAX_PLATFORMS does not contain 'proxy'."""
        reloaded_envs = self._reload_vllm_envs()

        # Without "proxy" in JAX_PLATFORMS the flag must be falsy.
        self.assertFalse(reloaded_envs.VLLM_TPU_USING_PATHWAYS)

    @patch.dict("os.environ", {"JAX_PLATFORMS": "PROXY,CPU"})
    def test_VLLM_TPU_USING_PATHWAYS_case_insensitive(self):
        """Test that JAX_PLATFORMS check is case insensitive."""
        reloaded_envs = self._reload_vllm_envs()

        # Upper-case "PROXY" must be recognised as well.
        self.assertTrue(reloaded_envs.VLLM_TPU_USING_PATHWAYS)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ unittest.main()
File without changes
@@ -0,0 +1,105 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ from absl.testing import absltest
5
+ from jax._src import test_util as jtu
6
+ from jax.sharding import Mesh
7
+
8
+ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
9
+
10
+ jax.config.parse_flags_with_absl()
11
+
12
+
13
def gen_moe_inputs(
    dtype,
    top_k,
    num_experts,
    hidden_size,
    intermediate_size,
    num_tokens,
    *,
    seed=1234,
):
    """Create deterministic random inputs for the fused MoE kernel test.

    Args:
        dtype: dtype of every returned array.
        top_k: number of expert indices drawn per token for the gating boost.
        num_experts: number of experts.
        hidden_size: token feature dimension.
        intermediate_size: per-expert intermediate dimension.
        num_tokens: number of tokens.
        seed: seed for the JAX PRNG key.

    Returns:
        A tuple ``(a, w1, w2, gating_output)`` with shapes
        ``(num_tokens, hidden_size)``,
        ``(num_experts, 2, hidden_size, intermediate_size)``,
        ``(num_experts, intermediate_size, hidden_size)`` and
        ``(num_tokens, num_experts)``.
    """
    root_key = jax.random.key(seed)
    token_key, w1_key, w2_key, gate_key, boost_key = jax.random.split(
        root_key, 5)

    # Tokens: cast to the target dtype first, then scale down.
    tokens_f32 = jax.random.normal(token_key, (num_tokens, hidden_size),
                                   dtype=jnp.float32)
    a = tokens_f32.astype(dtype) / 10

    # Weights: scale down in float32 first, then cast to the target dtype.
    w1_shape = (num_experts, 2, hidden_size, intermediate_size)
    w1_f32 = jax.random.normal(w1_key, w1_shape, dtype=jnp.float32) / 10
    w1 = w1_f32.astype(dtype)

    w2_shape = (num_experts, intermediate_size, hidden_size)
    w2_f32 = jax.random.normal(w2_key, w2_shape, dtype=jnp.float32) / 10
    w2 = w2_f32.astype(dtype)

    # Gating logits: Gaussian noise plus a small increasing ramp.
    gate_noise = jax.random.normal(gate_key, (num_tokens, num_experts),
                                   dtype=jnp.float32)
    gate_ramp = jnp.arange(num_tokens * num_experts,
                           dtype=jnp.float32).reshape(num_tokens,
                                                      num_experts) / 100
    gating_output = gate_noise + gate_ramp

    # To generate unique top-k!
    # NOTE(review): `maxval` is exclusive in jax.random.randint, so expert
    # index `num_experts - 1` is never boosted, and the draws for one token
    # are not guaranteed to be distinct -- confirm this is intended.
    top_k_indices = jax.random.randint(
        boost_key,
        (num_tokens, top_k),
        minval=0,
        maxval=num_experts - 1,
        dtype=jnp.int32,
    )
    selected = jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32)
    one_hot = jnp.sum(selected, axis=1) * 10

    gating_output = (gating_output + one_hot).astype(dtype)
    return a, w1, w2, gating_output
49
+
50
+
51
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MoEKernelTest(jtu.JaxTestCase):
    """Compares ``fused_ep_moe`` with ``ref_moe`` on a ``1 x N`` device mesh."""

    def setUp(self):
        """Order the available devices and build the ("data", "model") mesh."""
        super().setUp()

        def _device_sort_key(device):
            # Primary key: first coordinate. Secondary key: second
            # coordinate, negated whenever the first coordinate is odd.
            first = device.coords[0]
            direction = -1 if first % 2 else 1
            return (first, direction * device.coords[1])

        self.mesh_devices = sorted(jax.devices(), key=_device_sort_key)
        device_grid = np.array(self.mesh_devices).reshape(1, -1)
        self.mesh = Mesh(device_grid, axis_names=("data", "model"))

    def test_basic(self):
        """Kernel output must match the reference within 2e-2 tolerances."""
        dtype = jnp.bfloat16
        top_k = 2
        num_tokens = 8 * 2

        a, w1, w2, gating_output = gen_moe_inputs(
            dtype,
            top_k,
            16,  # num_experts
            256,  # hidden_size
            256,  # intermediate_size
            num_tokens,
        )

        # Block-size arguments forwarded unchanged to the kernel.
        block_sizes = dict(
            bt=32,
            bf=512,
            bd1=512,
            bd2=512,
            btc=32,
            bfc=256,
            bd1c=256,
            bd2c=256,
        )
        kernel_result = fused_ep_moe(
            mesh=self.mesh,
            tokens=a,
            w1=w1,
            w2=w2,
            gating_output=gating_output,
            top_k=top_k,
            **block_sizes,
        )
        actual = jax.block_until_ready(kernel_result)

        expected = ref_moe(a, w1, w2, gating_output, top_k)
        self.assertAllClose(expected, actual, atol=2e-2, rtol=2e-2)
102
+
103
+
104
+ if __name__ == "__main__":
105
+ absltest.main(testLoader=jtu.JaxTestLoader())
@@ -0,0 +1,396 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ from absl.testing import absltest, parameterized
5
+ from jax._src import test_util as jtu
6
+
7
+ import tpu_inference.kernels.mla.v1.kernel as mla
8
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
9
+ align_to, cdiv, get_dtype_packing)
10
+
11
+ jax.config.parse_flags_with_absl()
12
+
13
+
14
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
    """Compares ``mla.mla_ragged_paged_attention`` with its reference.

    Each test builds random inputs, runs both
    ``mla.ref_mla_ragged_paged_attention`` and
    ``mla.mla_ragged_paged_attention`` on them, and checks that the outputs
    and the updated caches agree.
    """

    def _test_mla_ragged_paged_attention(
        self,
        seq_lens,  # List[(q_len, kv_len)]
        num_heads,
        lkv_dim,
        r_dim,
        page_size,
        q_dtype,
        kv_dtype,
        num_pages,
        *,
        num_kv_pages_per_block=8,
        num_queries_per_block=8,
        vmem_limit_bytes=100 * 1024 * 1024,
        sm_scale=1.0,
        sliding_window: int | None = None,
        soft_cap: float | None = None,
    ):
        """Run reference and kernel on random inputs and compare results.

        Args:
            seq_lens: one ``(q_len, kv_len)`` pair per sequence.
            num_heads: number of query heads.
            lkv_dim: last dimension of ``ql_nope`` / ``new_kv_c``.
            r_dim: last dimension of ``q_pe`` / ``new_k_pe``.
            page_size: number of cache slots per page.
            q_dtype: dtype of the query inputs.
            kv_dtype: dtype of the new KV inputs and of the caches.
            num_pages: minimum number of cache pages to allocate.
            num_kv_pages_per_block: forwarded to the kernel only.
            num_queries_per_block: forwarded to the kernel only.
            vmem_limit_bytes: forwarded to the kernel only.
            sm_scale: forwarded to both implementations.
            sliding_window: forwarded to both implementations.
            soft_cap: forwarded to both implementations.
        """
        if not jtu.is_device_tpu_at_least(version=4):
            self.skipTest("Expect TPUv4+")
        rng = np.random.default_rng(1234)

        def gen_random(shape, dtype):
            samples = rng.random(size=shape, dtype=np.float32)
            return jnp.array(samples).astype(dtype)

        padded_r_dim = align_to(r_dim, 128)
        padded_lkv_dim = align_to(lkv_dim, 128)
        packing = get_dtype_packing(kv_dtype)

        q_lens = [pair[0] for pair in seq_lens]
        kv_lens_list = [pair[1] for pair in seq_lens]

        # Cumulative query lengths, starting at 0.
        cu_q_lens_list = [0]
        running_total = 0
        for q_len in q_lens:
            running_total += q_len
            cu_q_lens_list.append(running_total)
        total_q_len = running_total

        max_kv_len = max(kv_lens_list, default=0)
        pages_per_seq = cdiv(max_kv_len, page_size)

        # Every sequence gets consecutive page ids; its row is padded with -1
        # up to `pages_per_seq` entries.
        page_indices_list = []
        next_free_page = 0
        for kv_len in kv_lens_list:
            used_pages = cdiv(kv_len, page_size)
            page_indices_list += range(next_free_page,
                                       next_free_page + used_pages)
            page_indices_list += [-1] * (pages_per_seq - used_pages)
            next_free_page += used_pages

        total_num_pages = max(num_pages, next_free_page)

        # The order of the gen_random calls fixes the random data; keep it.
        ql_nope = gen_random((total_q_len, num_heads, lkv_dim), q_dtype)
        q_pe = gen_random((total_q_len, num_heads, r_dim), q_dtype)
        new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
        new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)

        kv_c_cache_shape = (total_num_pages, page_size // packing, packing,
                            padded_lkv_dim)
        k_pe_cache_shape = (total_num_pages, page_size // packing, packing,
                            padded_r_dim)
        cache_kv_c = gen_random(kv_c_cache_shape, kv_dtype)
        cache_k_pe = gen_random(k_pe_cache_shape, kv_dtype)

        kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
        page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
        cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
        # NOTE(review): the meaning of the three entries is defined by the
        # kernel module, not visible here -- confirm against mla.v1.kernel.
        distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32)
        seq_metadata = (kv_lens, page_indices, cu_q_lens, distribution)

        ql_nope_for_kernel = ql_nope.copy()
        q_pe_for_kernel = q_pe.copy()

        # Options understood by both the reference and the kernel.
        shared_options = dict(
            sm_scale=sm_scale,
            sliding_window=sliding_window,
            soft_cap=soft_cap,
        )

        # Each implementation receives its own copy of the caches.
        reference_results = mla.ref_mla_ragged_paged_attention(
            ql_nope,
            q_pe,
            new_kv_c,
            new_k_pe,
            cache_kv_c.copy(),
            cache_k_pe.copy(),
            *seq_metadata,
            **shared_options,
        )
        expected_out, expected_updated_kv_c, expected_updated_k_pe = (
            reference_results)

        kernel_results = mla.mla_ragged_paged_attention(
            ql_nope_for_kernel,
            q_pe_for_kernel,
            new_kv_c,
            new_k_pe,
            cache_kv_c.copy(),
            cache_k_pe.copy(),
            *seq_metadata,
            **shared_options,
            num_kv_pages_per_block=num_kv_pages_per_block,
            num_queries_per_block=num_queries_per_block,
            vmem_limit_bytes=vmem_limit_bytes,
        )
        kernel_out, kernel_updated_kv_c, kernel_updated_k_pe = kernel_results

        # Shape checks on the reference results.
        self.assertEqual(expected_out.shape,
                         (total_q_len, num_heads, padded_lkv_dim))
        self.assertEqual(expected_updated_kv_c.shape, kv_c_cache_shape)
        self.assertEqual(expected_updated_k_pe.shape, k_pe_cache_shape)

        # Dtype checks on the reference results.
        for reference_array in (expected_out, expected_updated_kv_c,
                                expected_updated_k_pe):
            self.assertEqual(reference_array.dtype, kv_dtype)

        # Numerical comparison of reference against kernel.
        compared_pairs = (
            (expected_out, kernel_out),
            (expected_updated_kv_c, kernel_updated_kv_c),
            (expected_updated_k_pe, kernel_updated_k_pe),
        )
        for expected, actual in compared_pairs:
            self.assertAllClose(expected, actual, atol=0.2, rtol=0.2)

    def _run_with_default_dims(self, seq_lens, dtype, **options):
        """Run the comparison with the dimensions shared by every test here.

        ``dtype`` is used for both the query and the KV inputs; ``options``
        are forwarded as keyword arguments.
        """
        self._test_mla_ragged_paged_attention(
            seq_lens,
            128,  # num_heads
            512,  # lkv_dim
            64,  # r_dim
            16,  # page_size
            dtype,  # q_dtype
            dtype,  # kv_dtype
            1000,  # num_pages
            **options,
        )

    @staticmethod
    def _random_seq_lens(num_seqs):
        """Draw ``num_seqs`` random ``(q_len, kv_len)`` pairs, kv_len >= q_len."""
        rng = np.random.default_rng(1234)
        q_lens = rng.integers(1, 100, num_seqs)
        kv_lens = q_lens + rng.integers(0, 50, num_seqs)
        return list(zip(q_lens.tolist(), kv_lens.tolist()))

    def test_ragged_paged_attention_basic(self):
        """Three bfloat16 sequences with multi-token queries."""
        seq_lens = [(192, 328), (128, 180), (64, 255)]
        self._run_with_default_dims(seq_lens, jnp.bfloat16)

    @parameterized.product(dtype=[jnp.bfloat16], )
    def test_ragged_paged_attention_decode_only(self, dtype):
        """Every sequence has a query length of 1."""
        kv_lengths = (18, 129, 597, 122, 64, 322, 463, 181, 1107, 123, 31, 18,
                      1229, 229, 87, 1328)
        seq_lens = [(1, kv_len) for kv_len in kv_lengths]
        self._run_with_default_dims(seq_lens, dtype)

    @parameterized.product(dtype=[jnp.bfloat16], )
    def test_ragged_paged_attention_prefill_only(self, dtype):
        """Every sequence has a query length greater than 1."""
        seq_lens = [
            (5, 18),
            (15, 129),
            (120, 597),
            (100, 122),
            (21, 64),
            (32, 322),
            (251, 463),
            (40, 181),
            (64, 1107),
            (99, 123),
            (10, 31),
            (5, 18),
            (3, 1229),
            (120, 229),
            (9, 87),
            (2, 1328),
        ]
        self._run_with_default_dims(seq_lens, dtype)

    @parameterized.product(dtype=[jnp.bfloat16], )
    def test_ragged_paged_attention_mixed(self, dtype):
        """Mix of single-token and multi-token queries."""
        seq_lens = [
            (5, 18),
            (1, 129),
            (120, 597),
            (1, 122),
            (1, 64),
            (32, 322),
            (251, 463),
            (1, 181),
            (1, 1107),
            (99, 123),
            (1, 31),
            (5, 18),
            (3, 1229),
            (117, 229),
            (1, 87),
            (1, 1328),
        ]
        self._run_with_default_dims(seq_lens, dtype)

    @parameterized.product(sliding_window=[None, 5, 128], )
    def test_ragged_paged_attention_sliding_window(
        self,
        sliding_window: int | None,
    ):
        """Five random float32 sequences with and without a sliding window."""
        seq_lens = self._random_seq_lens(5)
        self._run_with_default_dims(
            seq_lens,
            jnp.float32,
            sliding_window=sliding_window,
        )

    @parameterized.product(soft_cap=[None, 50.0], )
    def test_ragged_paged_attention_logit_soft_capping(
        self,
        soft_cap: float | None,
    ):
        """Two random float32 sequences with and without a soft cap."""
        seq_lens = self._random_seq_lens(2)
        self._run_with_default_dims(seq_lens, jnp.float32, soft_cap=soft_cap)

    def test_ragged_paged_attention_sliding_window_should_be_positive(self):
        """A sliding window of 0 or -1 must raise ``ValueError``."""
        seq_lens = [(192, 328), (128, 180), (64, 255)]

        for invalid_window in (0, -1):
            with self.assertRaisesRegex(ValueError, "must be positive"):
                self._run_with_default_dims(
                    seq_lens,
                    jnp.float32,
                    sliding_window=invalid_window,
                )

    def test_ragged_paged_attention_soft_cap_cannot_be_zero(self):
        """A soft cap of 0.0 must raise ``ValueError``."""
        seq_lens = [(192, 328), (128, 180), (64, 255)]

        with self.assertRaisesRegex(ValueError, "must not be 0.0"):
            self._run_with_default_dims(seq_lens, jnp.float32, soft_cap=0.0)
393
+
394
+
395
+ if __name__ == "__main__":
396
+ absltest.main(testLoader=jtu.JaxTestLoader())