tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,549 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import numpy as np
4
+ from absl.testing import absltest, parameterized
5
+ from jax._src import dtypes
6
+ from jax._src import test_util as jtu
7
+
8
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 import (
9
+ ragged_paged_attention_hd64, ref_ragged_paged_attention_hd64)
10
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
11
+ align_to, cdiv, get_dtype_packing)
12
+
13
+ jax.config.parse_flags_with_absl()
14
+
15
+
16
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
17
+ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
18
+
19
+ def _test_ragged_paged_attention_hd64(
20
+ self,
21
+ seq_lens, # List[(q_len, kv_len)]
22
+ num_heads, # [num_q_heads, num_kv_heads]
23
+ head_dim,
24
+ page_size,
25
+ q_dtype,
26
+ kv_dtype,
27
+ num_pages,
28
+ *,
29
+ num_kv_pages_per_block=8,
30
+ num_queries_per_block=64,
31
+ vmem_limit_bytes=100 * 1024 * 1024,
32
+ max_num_batched_tokens=512,
33
+ max_num_seq=8,
34
+ sliding_window: int | None = None,
35
+ soft_cap: float | None = None,
36
+ q_scale: float | None = None,
37
+ k_scale: float | None = None,
38
+ v_scale: float | None = None,
39
+ use_attention_sink: bool = False,
40
+ ):
41
+ assert head_dim == 64
42
+ rng = np.random.default_rng(1234)
43
+
44
+ def gen_random(shape, dtype):
45
+ return jnp.array(rng.random(size=shape,
46
+ dtype=np.float32)).astype(dtype)
47
+
48
+ if not jtu.is_device_tpu_at_least(version=4):
49
+ self.skipTest("Expect TPUv4+")
50
+ cu_q_lens = [0]
51
+ kv_lens = []
52
+ seq_lens = sorted(seq_lens, key=lambda x: x[0])
53
+ num_decoding_seqs = sum(q_len for q_len, _ in seq_lens if q_len == 1)
54
+ for q_len, kv_len in seq_lens:
55
+ assert q_len <= kv_len
56
+ cu_q_lens.append(cu_q_lens[-1] + q_len)
57
+ kv_lens.append(kv_len)
58
+
59
+ max_num_batched_tokens = max(align_to(cu_q_lens[-1], 128),
60
+ max_num_batched_tokens)
61
+ max_num_seq = max(align_to(len(seq_lens), 8), max_num_seq)
62
+ max_kv_len = max(kv_lens)
63
+ pages_per_seq = cdiv(max_kv_len, page_size)
64
+ num_q_heads, num_kv_heads = num_heads
65
+
66
+ q = gen_random((max_num_batched_tokens, num_q_heads, head_dim),
67
+ q_dtype)
68
+ k = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
69
+ kv_dtype)
70
+ v = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
71
+ kv_dtype)
72
+ attention_sink = gen_random(
73
+ (num_q_heads), jnp.float32) * 10**2 if use_attention_sink else None
74
+
75
+ page_cnt = 0
76
+ page_indices_list = []
77
+ kv_pages_list = []
78
+ kv_packing = get_dtype_packing(kv_dtype)
79
+ padded_head_dim = align_to(head_dim, 128)
80
+ padded_num_kv_heads = align_to(num_kv_heads, kv_packing)
81
+ for kv_len in kv_lens:
82
+ kv = gen_random(
83
+ (
84
+ kv_len,
85
+ padded_num_kv_heads // kv_packing,
86
+ kv_packing,
87
+ padded_head_dim,
88
+ ),
89
+ kv_dtype,
90
+ )
91
+ kv = jnp.pad(
92
+ kv,
93
+ (
94
+ (
95
+ 0,
96
+ cdiv(kv_len, page_size) * page_size - kv_len,
97
+ ),
98
+ (0, 0),
99
+ (0, 0),
100
+ (0, 0),
101
+ ),
102
+ constant_values=jnp.nan,
103
+ ).reshape(
104
+ -1,
105
+ page_size,
106
+ padded_num_kv_heads // kv_packing,
107
+ kv_packing,
108
+ padded_head_dim,
109
+ )
110
+ indices = page_cnt + jnp.arange(kv.shape[0], dtype=jnp.int32)
111
+ indices = jnp.pad(
112
+ indices,
113
+ ((0, pages_per_seq - indices.shape[0]), ),
114
+ constant_values=jnp.nan,
115
+ )
116
+ page_indices_list.append(indices)
117
+ page_cnt += kv.shape[0]
118
+ kv_pages_list.append(kv)
119
+
120
+ kv_cache = jnp.concatenate(kv_pages_list, axis=0)
121
+ kv_cache = jnp.pad(
122
+ kv_cache,
123
+ ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
124
+ (0, 0)),
125
+ constant_values=jnp.nan,
126
+ )
127
+ page_indices = jnp.stack(page_indices_list, axis=0)
128
+ page_indices = jnp.pad(
129
+ page_indices,
130
+ ((0, max_num_seq - page_indices.shape[0]), (0, 0)),
131
+ constant_values=jnp.nan,
132
+ )
133
+ page_indices = page_indices.reshape(-1)
134
+
135
+ cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
136
+ cu_q_lens = jnp.pad(cu_q_lens,
137
+ (0, max_num_seq + 1 - cu_q_lens.shape[0]))
138
+ kv_lens = jnp.array(kv_lens, dtype=jnp.int32)
139
+ kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0]))
140
+ distribution = jnp.array(
141
+ [num_decoding_seqs, num_decoding_seqs,
142
+ len(seq_lens)],
143
+ dtype=jnp.int32)
144
+
145
+ args = (
146
+ q,
147
+ k,
148
+ v,
149
+ kv_cache,
150
+ kv_lens,
151
+ page_indices,
152
+ cu_q_lens,
153
+ distribution,
154
+ attention_sink,
155
+ )
156
+
157
+ kwargs = {
158
+ "sliding_window": sliding_window,
159
+ "soft_cap": soft_cap,
160
+ "q_scale": q_scale,
161
+ "k_scale": k_scale,
162
+ "v_scale": v_scale,
163
+ }
164
+
165
+ expected, expected_kv_cache = ref_ragged_paged_attention_hd64(
166
+ *args,
167
+ **kwargs,
168
+ )
169
+
170
+ output, updated_kv_cache = ragged_paged_attention_hd64(
171
+ *args,
172
+ **kwargs,
173
+ num_kv_pages_per_block=num_kv_pages_per_block,
174
+ num_queries_per_block=num_queries_per_block,
175
+ vmem_limit_bytes=vmem_limit_bytes,
176
+ )
177
+ output = output[:cu_q_lens[distribution[-1]]]
178
+
179
+ dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
180
+ tols = {
181
+ 32: 0.15,
182
+ 16: 0.2,
183
+ 8: 0.2,
184
+ 4: 0.2,
185
+ }
186
+ tol = tols[dtype_bits]
187
+ self.assertAllClose(output, expected, atol=tol, rtol=tol)
188
+ mask = ~jnp.isnan(expected_kv_cache)
189
+ self.assertArraysEqual(updated_kv_cache[mask], expected_kv_cache[mask])
190
+ self.assertEqual(output.shape[-1], head_dim)
191
+
192
+ @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
193
+ def test_ragged_paged_attention_basic(self, dtype):
194
+ seq_lens = [(192, 328), (128, 180), (64, 255)]
195
+ num_heads = (32, 8)
196
+ head_dim = 64
197
+ page_size = 16
198
+ num_pages = 1000
199
+
200
+ self._test_ragged_paged_attention_hd64(
201
+ seq_lens,
202
+ num_heads,
203
+ head_dim,
204
+ page_size,
205
+ dtype,
206
+ dtype,
207
+ num_pages,
208
+ )
209
+
210
+ # TODO: support integer (int8, int4) and fp4 kv cache
211
+ @parameterized.product(
212
+ q_dtype=[jnp.bfloat16],
213
+ kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
214
+ kv_scales=[(0.5, 0.5), (None, None)],
215
+ )
216
+ def test_ragged_paged_attention_quantized_kv_cache(self, q_dtype, kv_dtype,
217
+ kv_scales):
218
+ if not jtu.is_device_tpu_at_least(version=5):
219
+ self.skipTest("Expect TPUv5+")
220
+ seq_lens = [(192, 328), (128, 180), (64, 255)]
221
+ num_heads = (32, 8)
222
+ head_dim = 64
223
+ page_size = 16
224
+ num_pages = 1000
225
+ k_scale, v_scale = kv_scales
226
+
227
+ self._test_ragged_paged_attention_hd64(
228
+ seq_lens,
229
+ num_heads,
230
+ head_dim,
231
+ page_size,
232
+ q_dtype,
233
+ kv_dtype,
234
+ num_pages,
235
+ k_scale=k_scale,
236
+ v_scale=v_scale,
237
+ )
238
+
239
+ @parameterized.product(
240
+ q_dtype=[jnp.bfloat16],
241
+ kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
242
+ q_scale=[0.5],
243
+ kv_scales=[(0.5, 0.5), (None, None)],
244
+ )
245
+ def test_ragged_paged_attention_quantized_attention(
246
+ self, q_dtype, kv_dtype, q_scale, kv_scales):
247
+ if not jtu.is_device_tpu_at_least(version=5):
248
+ self.skipTest("Expect TPUv5+")
249
+ seq_lens = [(192, 328), (128, 180), (64, 255)]
250
+ num_heads = (32, 8)
251
+ head_dim = 64
252
+ page_size = 16
253
+ num_pages = 1000
254
+ k_scale, v_scale = kv_scales
255
+
256
+ self._test_ragged_paged_attention_hd64(
257
+ seq_lens,
258
+ num_heads,
259
+ head_dim,
260
+ page_size,
261
+ q_dtype,
262
+ kv_dtype,
263
+ num_pages,
264
+ q_scale=q_scale,
265
+ k_scale=k_scale,
266
+ v_scale=v_scale,
267
+ )
268
+
269
+ @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
270
+ def test_ragged_paged_attention_decode_only(self, dtype):
271
+ seq_lens = [
272
+ (1, 18),
273
+ (1, 129),
274
+ (1, 597),
275
+ (1, 122),
276
+ (1, 64),
277
+ (1, 322),
278
+ (1, 463),
279
+ (1, 181),
280
+ (1, 1107),
281
+ (1, 123),
282
+ (1, 31),
283
+ (1, 18),
284
+ (1, 1229),
285
+ (1, 229),
286
+ (1, 87),
287
+ (1, 1328),
288
+ ]
289
+ num_heads = (32, 8)
290
+ head_dim = 64
291
+ page_size = 16
292
+ num_pages = 1000
293
+
294
+ self._test_ragged_paged_attention_hd64(
295
+ seq_lens,
296
+ num_heads,
297
+ head_dim,
298
+ page_size,
299
+ dtype,
300
+ dtype,
301
+ num_pages,
302
+ )
303
+
304
+ @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
305
+ def test_ragged_paged_attention_prefill_only(self, dtype):
306
+ seq_lens = [
307
+ (5, 18),
308
+ (15, 129),
309
+ (120, 597),
310
+ (100, 122),
311
+ (21, 64),
312
+ (32, 322),
313
+ (251, 463),
314
+ (40, 181),
315
+ (64, 1107),
316
+ (99, 123),
317
+ (10, 31),
318
+ (5, 18),
319
+ (3, 1229),
320
+ (120, 229),
321
+ (9, 87),
322
+ (2, 1328),
323
+ ]
324
+ num_heads = (32, 8)
325
+ head_dim = 64
326
+ page_size = 16
327
+ num_pages = 1000
328
+
329
+ self._test_ragged_paged_attention_hd64(
330
+ seq_lens,
331
+ num_heads,
332
+ head_dim,
333
+ page_size,
334
+ dtype,
335
+ dtype,
336
+ num_pages,
337
+ )
338
+
339
+ @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
340
+ def test_ragged_paged_attention_mixed(self, dtype):
341
+ seq_lens = [
342
+ (5, 18),
343
+ (1, 129),
344
+ (120, 597),
345
+ (1, 122),
346
+ (1, 64),
347
+ (32, 322),
348
+ (251, 463),
349
+ (1, 181),
350
+ (1, 1107),
351
+ (99, 123),
352
+ (1, 31),
353
+ (5, 18),
354
+ (3, 1229),
355
+ (117, 229),
356
+ (1, 87),
357
+ (1, 1328),
358
+ ]
359
+ num_heads = (32, 8)
360
+ head_dim = 64
361
+ page_size = 16
362
+ num_pages = 1000
363
+
364
+ self._test_ragged_paged_attention_hd64(
365
+ seq_lens,
366
+ num_heads,
367
+ head_dim,
368
+ page_size,
369
+ dtype,
370
+ dtype,
371
+ num_pages,
372
+ )
373
+
374
+ @parameterized.product(
375
+ num_seqs=[1, 17],
376
+ num_heads=[(32, 8), (12, 2), (5, 1), (3, 3)],
377
+ head_dim=[64],
378
+ dtype=[jnp.float32, jnp.bfloat16],
379
+ # num_kv_pages_per_block=[8, 16],
380
+ # num_queries_per_block=[16, 32],
381
+ )
382
+ def test_ragged_paged_attention_complex(
383
+ self,
384
+ num_seqs,
385
+ num_heads,
386
+ head_dim,
387
+ dtype,
388
+ # num_kv_pages_per_block,
389
+ # num_queries_per_block,
390
+ ):
391
+ rng = np.random.default_rng(1234)
392
+ q_lens = rng.integers(1, 100, num_seqs)
393
+ kv_lens = q_lens + rng.integers(0, 50, num_seqs)
394
+ seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
395
+ page_size = 16
396
+ num_pages = 1000
397
+
398
+ self._test_ragged_paged_attention_hd64(
399
+ seq_lens,
400
+ num_heads,
401
+ head_dim,
402
+ page_size,
403
+ dtype,
404
+ dtype,
405
+ num_pages,
406
+ # num_kv_pages_per_block=num_kv_pages_per_block,
407
+ # num_queries_per_block=num_queries_per_block,
408
+ )
409
+
410
+ @parameterized.product(sliding_window=[None, 5, 128], )
411
+ def test_ragged_paged_attention_sliding_window(
412
+ self,
413
+ sliding_window: int | None,
414
+ ):
415
+ num_seqs = 5
416
+ num_heads = (4, 4)
417
+ dtype = jnp.float32
418
+ rng = np.random.default_rng(1234)
419
+ q_lens = rng.integers(1, 100, num_seqs)
420
+ kv_lens = q_lens + rng.integers(0, 50, num_seqs)
421
+ seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
422
+ head_dim = 64
423
+ page_size = 16
424
+ num_pages = 1000
425
+
426
+ self._test_ragged_paged_attention_hd64(
427
+ seq_lens,
428
+ num_heads,
429
+ head_dim,
430
+ page_size,
431
+ dtype,
432
+ dtype,
433
+ num_pages,
434
+ sliding_window=sliding_window,
435
+ )
436
+
437
+ @parameterized.product(
438
+ sliding_window=[5, 128],
439
+ num_heads=[(4, 4), (8, 4), (64, 8)],
440
+ )
441
+ def test_ragged_paged_attention_sliding_window_with_attention_sink_hd64(
442
+ self,
443
+ sliding_window: int | None,
444
+ num_heads: tuple[int, int],
445
+ ):
446
+ num_seqs = 5
447
+ dtype = jnp.float32
448
+ rng = np.random.default_rng(1234)
449
+ q_lens = rng.integers(1, 100, num_seqs)
450
+ kv_lens = q_lens + rng.integers(0, 50, num_seqs)
451
+ seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
452
+ head_dim = 64
453
+ page_size = 16
454
+ num_pages = 1000
455
+
456
+ self._test_ragged_paged_attention_hd64(
457
+ seq_lens,
458
+ num_heads,
459
+ head_dim,
460
+ page_size,
461
+ dtype,
462
+ dtype,
463
+ num_pages,
464
+ sliding_window=sliding_window,
465
+ use_attention_sink=True,
466
+ )
467
+
468
+ @parameterized.product(soft_cap=[None, 50.0], )
469
+ def test_ragged_paged_attention_logit_soft_capping(
470
+ self,
471
+ soft_cap: float | None,
472
+ ):
473
+ num_heads = (16, 2)
474
+ num_seqs = 2
475
+ dtype = jnp.float32
476
+ rng = np.random.default_rng(1234)
477
+ q_lens = rng.integers(1, 100, num_seqs)
478
+ kv_lens = q_lens + rng.integers(0, 50, num_seqs)
479
+ seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
480
+ head_dim = 64
481
+ page_size = 16
482
+ num_pages = 1000
483
+
484
+ self._test_ragged_paged_attention_hd64(
485
+ seq_lens,
486
+ num_heads,
487
+ head_dim,
488
+ page_size,
489
+ dtype,
490
+ dtype,
491
+ num_pages,
492
+ soft_cap=soft_cap,
493
+ )
494
+
495
+ def test_ragged_paged_attention_sliding_window_should_be_positive(self):
496
+ dtype = jnp.float32
497
+ seq_lens = [(192, 328), (128, 180), (64, 255)]
498
+ num_heads = (32, 8)
499
+ head_dim = 64
500
+ page_size = 16
501
+ num_pages = 1000
502
+
503
+ with self.assertRaisesRegex(ValueError, "must be positive"):
504
+ self._test_ragged_paged_attention_hd64(
505
+ seq_lens,
506
+ num_heads,
507
+ head_dim,
508
+ page_size,
509
+ dtype,
510
+ dtype,
511
+ num_pages,
512
+ sliding_window=0,
513
+ )
514
+
515
+ with self.assertRaisesRegex(ValueError, "must be positive"):
516
+ self._test_ragged_paged_attention_hd64(
517
+ seq_lens,
518
+ num_heads,
519
+ head_dim,
520
+ page_size,
521
+ dtype,
522
+ dtype,
523
+ num_pages,
524
+ sliding_window=-1,
525
+ )
526
+
527
+ def test_ragged_paged_attention_soft_cap_cannot_be_zero(self):
528
+ dtype = jnp.float32
529
+ seq_lens = [(192, 328), (128, 180), (64, 255)]
530
+ num_heads = (32, 8)
531
+ head_dim = 64
532
+ page_size = 16
533
+ num_pages = 1000
534
+
535
+ with self.assertRaisesRegex(ValueError, "must not be 0.0"):
536
+ self._test_ragged_paged_attention_hd64(
537
+ seq_lens,
538
+ num_heads,
539
+ head_dim,
540
+ page_size,
541
+ dtype,
542
+ dtype,
543
+ num_pages,
544
+ soft_cap=0.0,
545
+ )
546
+
547
+
548
+ if __name__ == "__main__":
549
+ absltest.main(testLoader=jtu.JaxTestLoader())