tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in the public registry.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tests/kernels/ragged_paged_attention_kernel_v3_test.py ADDED
@@ -0,0 +1,504 @@
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from absl.testing import absltest, parameterized
+ from jax._src import dtypes
+ from jax._src import test_util as jtu
+
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import (
+     ragged_paged_attention, ref_ragged_paged_attention)
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
+     align_to, cdiv, get_dtype_packing)
+
+ jax.config.parse_flags_with_absl()
+
+
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
+ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
+
+     def _test_ragged_paged_attention(
+         self,
+         seq_lens,  # List[(q_len, kv_len)]
+         num_heads,  # [num_q_heads, num_kv_heads]
+         head_dim,
+         page_size,
+         q_dtype,
+         kv_dtype,
+         num_pages,
+         *,
+         num_kv_pages_per_block=8,
+         num_queries_per_block=64,
+         vmem_limit_bytes=100 * 1024 * 1024,
+         max_num_batched_tokens=512,
+         max_num_seq=8,
+         sliding_window: int | None = None,
+         soft_cap: float | None = None,
+         q_scale: float | None = None,
+         k_scale: float | None = None,
+         v_scale: float | None = None,
+     ):
+         rng = np.random.default_rng(1234)
+
+         def gen_random(shape, dtype):
+             return jnp.array(rng.random(size=shape,
+                                         dtype=np.float32)).astype(dtype)
+
+         if not jtu.is_device_tpu_at_least(version=4):
+             self.skipTest("Expect TPUv4+")
+         cu_q_lens = [0]
+         kv_lens = []
+         for q_len, kv_len in seq_lens:
+             assert q_len <= kv_len
+             cu_q_lens.append(cu_q_lens[-1] + q_len)
+             kv_lens.append(kv_len)
+
+         max_num_batched_tokens = max(align_to(cu_q_lens[-1], 128),
+                                      max_num_batched_tokens)
+         max_num_seq = max(align_to(len(seq_lens), 8), max_num_seq)
+         max_kv_len = max(kv_lens)
+         pages_per_seq = cdiv(max_kv_len, page_size)
+         num_q_heads, num_kv_heads = num_heads
+
+         q = gen_random((max_num_batched_tokens, num_q_heads, head_dim),
+                        q_dtype)
+         k = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
+                        kv_dtype)
+         v = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
+                        kv_dtype)
+         page_cnt = 0
+         page_indices_list = []
+         kv_pages_list = []
+         kv_packing = get_dtype_packing(kv_dtype)
+         padded_head_dim = align_to(head_dim, 128)
+         num_kv_heads_x2 = align_to(num_kv_heads * 2, kv_packing)
+         for kv_len in kv_lens:
+             kv = gen_random((
+                 kv_len,
+                 num_kv_heads_x2 // kv_packing,
+                 kv_packing,
+                 padded_head_dim,
+             ), kv_dtype)
+             kv = jnp.pad(
+                 kv,
+                 (
+                     (
+                         0,
+                         cdiv(kv_len, page_size) * page_size - kv_len,
+                     ),
+                     (0, 0),
+                     (0, 0),
+                     (0, 0),
+                 ),
+                 constant_values=jnp.nan,
+             ).reshape(
+                 -1,
+                 page_size,
+                 num_kv_heads_x2 // kv_packing,
+                 kv_packing,
+                 padded_head_dim,
+             )
+             indices = page_cnt + jnp.arange(kv.shape[0], dtype=jnp.int32)
+             indices = jnp.pad(
+                 indices,
+                 ((0, pages_per_seq - indices.shape[0]), ),
+                 constant_values=jnp.nan,
+             )
+             page_indices_list.append(indices)
+             page_cnt += kv.shape[0]
+             kv_pages_list.append(kv)
+
+         kv_cache = jnp.concatenate(kv_pages_list, axis=0)
+         kv_cache = jnp.pad(
+             kv_cache,
+             ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
+              (0, 0)),
+             constant_values=jnp.nan,
+         )
+         page_indices = jnp.stack(page_indices_list, axis=0)
+         page_indices = jnp.pad(
+             page_indices,
+             ((0, max_num_seq - page_indices.shape[0]), (0, 0)),
+             constant_values=jnp.nan,
+         )
+         page_indices = page_indices.reshape(-1)
+
+         cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
+         cu_q_lens = jnp.pad(cu_q_lens,
+                             (0, max_num_seq + 1 - cu_q_lens.shape[0]))
+         kv_lens = jnp.array(kv_lens, dtype=jnp.int32)
+         kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0]))
+         distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32)
+
+         args = (
+             q,
+             k,
+             v,
+             kv_cache,
+             kv_lens,
+             page_indices,
+             cu_q_lens,
+             distribution,
+         )
+
+         kwargs = {
+             "sliding_window": sliding_window,
+             "soft_cap": soft_cap,
+             "q_scale": q_scale,
+             "k_scale": k_scale,
+             "v_scale": v_scale,
+         }
+
+         expected, expected_kv_cache = ref_ragged_paged_attention(
+             *args,
+             **kwargs,
+         )
+
+         output, updated_kv_cache = ragged_paged_attention(
+             *args,
+             **kwargs,
+             num_kv_pages_per_block=num_kv_pages_per_block,
+             num_queries_per_block=num_queries_per_block,
+             vmem_limit_bytes=vmem_limit_bytes,
+         )
+         output = output[:cu_q_lens[distribution[-1]]]
+
+         dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+         tols = {
+             32: 0.15,
+             16: 0.2,
+             8: 0.2,
+             4: 0.2,
+         }
+         tol = tols[dtype_bits]
+         self.assertAllClose(output, expected, atol=tol, rtol=tol)
+         mask = ~jnp.isnan(expected_kv_cache)
+         self.assertArraysEqual(updated_kv_cache[mask], expected_kv_cache[mask])
+         self.assertEqual(output.shape[-1], head_dim)
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_basic(self, dtype):
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     # TODO: support integer (int8, int4) and fp4 kv cache
+     @parameterized.product(
+         q_dtype=[jnp.bfloat16],
+         kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
+         kv_scales=[(0.5, 0.5), (1.0, 1.0)],
+     )
+     def test_ragged_paged_attention_quantized_kv_cache(self, q_dtype, kv_dtype,
+                                                        kv_scales):
+         if not jtu.is_device_tpu_at_least(version=5):
+             self.skipTest("Expect TPUv5+")
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+         k_scale, v_scale = kv_scales
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             q_dtype,
+             kv_dtype,
+             num_pages,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+
+     @parameterized.product(
+         q_dtype=[jnp.bfloat16],
+         kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
+         q_scale=[0.5, 1.0],
+         kv_scales=[(0.5, 0.5), (1.0, 1.0)],
+     )
+     def test_ragged_paged_attention_quantized_attention(
+             self, q_dtype, kv_dtype, q_scale, kv_scales):
+         if not jtu.is_device_tpu_at_least(version=5):
+             self.skipTest("Expect TPUv5+")
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+         k_scale, v_scale = kv_scales
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             q_dtype,
+             kv_dtype,
+             num_pages,
+             q_scale=q_scale,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_decode_only(self, dtype):
+         seq_lens = [
+             (1, 18),
+             (1, 129),
+             (1, 597),
+             (1, 122),
+             (1, 64),
+             (1, 322),
+             (1, 463),
+             (1, 181),
+             (1, 1107),
+             (1, 123),
+             (1, 31),
+             (1, 18),
+             (1, 1229),
+             (1, 229),
+             (1, 87),
+             (1, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_prefill_only(self, dtype):
+         seq_lens = [
+             (5, 18),
+             (15, 129),
+             (120, 597),
+             (100, 122),
+             (21, 64),
+             (32, 322),
+             (251, 463),
+             (40, 181),
+             (64, 1107),
+             (99, 123),
+             (10, 31),
+             (5, 18),
+             (3, 1229),
+             (120, 229),
+             (9, 87),
+             (2, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_mixed(self, dtype):
+         seq_lens = [
+             (5, 18),
+             (1, 129),
+             (120, 597),
+             (1, 122),
+             (1, 64),
+             (32, 322),
+             (251, 463),
+             (1, 181),
+             (1, 1107),
+             (99, 123),
+             (1, 31),
+             (5, 18),
+             (3, 1229),
+             (117, 229),
+             (1, 87),
+             (1, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(
+         num_seqs=[1, 17],
+         num_heads=[(32, 8), (12, 2), (5, 1), (3, 3)],
+         head_dim=[80, 240],
+         dtype=[jnp.float32, jnp.bfloat16],
+         # num_kv_pages_per_block=[8, 16],
+         # num_queries_per_block=[16, 32],
+     )
+     def test_ragged_paged_attention_complex(
+             self,
+             num_seqs,
+             num_heads,
+             head_dim,
+             dtype,
+             # num_kv_pages_per_block,
+             # num_queries_per_block,
+     ):
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             # num_kv_pages_per_block=num_kv_pages_per_block,
+             # num_queries_per_block=num_queries_per_block,
+         )
+
+     @parameterized.product(sliding_window=[None, 5, 128], )
+     def test_ragged_paged_attention_sliding_window(
+             self,
+             sliding_window: int | None,
+     ):
+         num_seqs = 5
+         num_heads = (4, 4)
+         dtype = jnp.float32
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             sliding_window=sliding_window,
+         )
+
+     @parameterized.product(soft_cap=[None, 50.0], )
+     def test_ragged_paged_attention_logit_soft_capping(
+             self,
+             soft_cap: float | None,
+     ):
+         num_heads = (16, 2)
+         num_seqs = 2
+         dtype = jnp.float32
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             soft_cap=soft_cap,
+         )
+
+     def test_ragged_paged_attention_sliding_window_should_be_positive(self):
+         dtype = jnp.float32
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         with self.assertRaisesRegex(ValueError, "must be positive"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 sliding_window=0,
+             )
+
+         with self.assertRaisesRegex(ValueError, "must be positive"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 sliding_window=-1,
+             )
+
+     def test_ragged_paged_attention_soft_cap_cannot_be_zero(self):
+         dtype = jnp.float32
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         with self.assertRaisesRegex(ValueError, "must not be 0.0"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 soft_cap=0.0,
+             )
+
+
+ if __name__ == "__main__":
+     absltest.main(testLoader=jtu.JaxTestLoader())
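
The test above sizes all of its padded buffers with the align_to and cdiv helpers imported from tpu_inference.kernels.ragged_paged_attention.v3.util. As a minimal standalone sketch of that paging arithmetic, assuming the conventional definitions of both helpers (ceiling division and round-up alignment; the package's real implementations may differ in detail), the numbers for the basic test case work out as follows:

    # Sketch of the paging arithmetic behind _test_ragged_paged_attention.
    # cdiv/align_to are assumed to have their conventional meanings.

    def cdiv(a: int, b: int) -> int:
        # Ceiling division: number of b-sized blocks needed to cover a.
        return (a + b - 1) // b

    def align_to(x: int, a: int) -> int:
        # Round x up to the next multiple of a.
        return cdiv(x, a) * a

    seq_lens = [(192, 328), (128, 180), (64, 255)]  # (q_len, kv_len) per sequence
    page_size = 16

    total_q = sum(q for q, _ in seq_lens)                           # 384 query tokens
    max_num_batched_tokens = max(align_to(total_q, 128), 512)       # -> 512
    max_num_seq = max(align_to(len(seq_lens), 8), 8)                # -> 8
    pages_per_seq = cdiv(max(kv for _, kv in seq_lens), page_size)  # ceil(328/16) = 21

    print(max_num_batched_tokens, max_num_seq, pages_per_seq)       # 512 8 21

Unused page slots and padded KV entries are filled with NaN, which is why the test compares the updated KV cache only under the ~jnp.isnan mask.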
tests/lora/__init__.py ADDED
File without changes
tests/lora/conftest.py ADDED
@@ -0,0 +1,32 @@
+ import tempfile
+
+ import pytest
+ from vllm.config import set_current_vllm_config
+ from vllm.distributed import cleanup_dist_env_and_memory
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.engine.arg_utils import EngineArgs
+
+
+ @pytest.fixture
+ def dist_init():
+     engine_args = EngineArgs(
+         model="Qwen/Qwen2-1.5B-Instruct",
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+
+     vllm_config = engine_args.create_engine_config()
+
+     with set_current_vllm_config(vllm_config):
+         temp_file = tempfile.mkstemp()[1]
+         init_distributed_environment(
+             1,
+             0,
+             local_rank=0,
+             distributed_init_method=f"file://{temp_file}",
+             backend="gloo")
+         ensure_model_parallel_initialized(1, 1)
+         yield vllm_config
+         cleanup_dist_env_and_memory(shutdown_ray=True)
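
The dist_init fixture brings up a single-rank "gloo" process group, yields the constructed vllm config to the test, and tears the distributed environment down afterwards. A hypothetical consumer, just to illustrate how pytest injects it (the test name and assertion below are illustrative, not part of the package):

    # Hypothetical usage sketch: pytest injects dist_init by parameter name,
    # so the body runs inside the environment set up by the fixture.
    def test_with_dist_env(dist_init):
        vllm_config = dist_init  # the fixture yields the engine's vllm config
        assert vllm_config.model_config.max_model_len == 64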
tests/lora/test_bgmv.py ADDED
@@ -0,0 +1,43 @@
+ import jax
+ import torch
+ import torchax
+
+ from tpu_inference.lora.torch_lora_ops import bgmv_torch
+
+
+ def test_bgmv_torch():
+     num_tokens = 16
+     hidden_size = 128
+     max_loras = 9
+     max_lora_rank = 8
+
+     with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
+         inputs = torch.rand(num_tokens, hidden_size, device='jax')
+         loras = torch.rand(max_loras,
+                            1,
+                            max_lora_rank,
+                            hidden_size,
+                            device='jax')
+         idxs = torch.randint(0, max_loras, (num_tokens, ), device='jax')
+
+         actual = bgmv_torch(inputs, loras, idxs)
+         expected = _ref_bgmv_torch(inputs, loras, idxs)
+         torch.testing.assert_close(actual, expected, atol=3e-2, rtol=1e-3)
+
+
+ def _ref_bgmv_torch(inputs, loras, idxs):
+     if len(loras.shape) == 4:
+         loras = loras.squeeze(axis=1)
+
+     # Another equivalent ref impl is as the 2 lines below.
+     # selected_loras = loras[idxs]
+     # return torch.einsum('td,tld->tl', inputs, selected_loras)
+     num_tokens, _ = inputs.shape
+     outputs = []
+     for i in range(num_tokens):
+         input = inputs[i]  # [hidden_size]
+         lora = loras[idxs[i]]  # [max_lora_rank, hidden_size]
+         out = torch.matmul(lora, input)
+         outputs.append(out)
+
+     return torch.stack(outputs, axis=0)
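
The comment inside _ref_bgmv_torch notes that the per-token loop is equivalent to a gather followed by an einsum. A standalone CPU sketch of that equivalence (plain torch, independent of torchax and the TPU device used above; the shapes mirror the test's constants):

    import torch

    # out[t, l] = sum_d inputs[t, d] * loras[idxs[t], l, d]
    inputs = torch.rand(16, 128)        # [num_tokens, hidden_size]
    loras = torch.rand(9, 8, 128)       # [max_loras, max_lora_rank, hidden_size]
    idxs = torch.randint(0, 9, (16, ))  # one LoRA index per token

    # Per-token loop, as in _ref_bgmv_torch.
    loop_out = torch.stack([loras[idxs[i]] @ inputs[i] for i in range(16)])
    # Gather-then-einsum form from the reference's comment.
    einsum_out = torch.einsum('td,tld->tl', inputs, loras[idxs])
    torch.testing.assert_close(loop_out, einsum_out)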