tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
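
Only two of the 168 files are reproduced below. The first, tests/kernels/ragged_paged_attention_kernel_v3_test.py, builds its paged KV cache with three small helpers imported from tpu_inference.kernels.ragged_paged_attention.v3.util. A minimal sketch of their apparent semantics, inferred from how the test uses them (the package's real implementations may differ):

    # Sketch only: plausible reimplementations inferred from usage in the
    # test below, not copied from the package.
    import jax.numpy as jnp
    from jax._src import dtypes


    def cdiv(a: int, b: int) -> int:
        # Ceiling division: number of b-sized blocks needed to cover a.
        return -(-a // b)


    def align_to(x: int, n: int) -> int:
        # Round x up to the next multiple of n.
        return cdiv(x, n) * n


    def get_dtype_packing(dtype) -> int:
        # Elements of `dtype` per 32-bit word, e.g. 1 for float32,
        # 2 for bfloat16, 4 for float8.
        return 32 // dtypes.bit_width(jnp.dtype(dtype))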
tests/kernels/ragged_paged_attention_kernel_v3_test.py ADDED
@@ -0,0 +1,504 @@
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from absl.testing import absltest, parameterized
+ from jax._src import dtypes
+ from jax._src import test_util as jtu
+
+ from tpu_inference.kernels.ragged_paged_attention.v3.kernel import (
+     ragged_paged_attention, ref_ragged_paged_attention)
+ from tpu_inference.kernels.ragged_paged_attention.v3.util import (
+     align_to, cdiv, get_dtype_packing)
+
+ jax.config.parse_flags_with_absl()
+
+
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
+ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
+
+     def _test_ragged_paged_attention(
+         self,
+         seq_lens,  # List[(q_len, kv_len)]
+         num_heads,  # [num_q_heads, num_kv_heads]
+         head_dim,
+         page_size,
+         q_dtype,
+         kv_dtype,
+         num_pages,
+         *,
+         num_kv_pages_per_block=8,
+         num_queries_per_block=64,
+         vmem_limit_bytes=100 * 1024 * 1024,
+         max_num_batched_tokens=512,
+         max_num_seq=8,
+         sliding_window: int | None = None,
+         soft_cap: float | None = None,
+         q_scale: float | None = None,
+         k_scale: float | None = None,
+         v_scale: float | None = None,
+     ):
+         rng = np.random.default_rng(1234)
+
+         def gen_random(shape, dtype):
+             return jnp.array(rng.random(size=shape,
+                                         dtype=np.float32)).astype(dtype)
+
+         if not jtu.is_device_tpu_at_least(version=4):
+             self.skipTest("Expect TPUv4+")
+         cu_q_lens = [0]
+         kv_lens = []
+         for q_len, kv_len in seq_lens:
+             assert q_len <= kv_len
+             cu_q_lens.append(cu_q_lens[-1] + q_len)
+             kv_lens.append(kv_len)
+
+         max_num_batched_tokens = max(align_to(cu_q_lens[-1], 128),
+                                      max_num_batched_tokens)
+         max_num_seq = max(align_to(len(seq_lens), 8), max_num_seq)
+         max_kv_len = max(kv_lens)
+         pages_per_seq = cdiv(max_kv_len, page_size)
+         num_q_heads, num_kv_heads = num_heads
+
+         q = gen_random((max_num_batched_tokens, num_q_heads, head_dim),
+                        q_dtype)
+         k = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
+                        kv_dtype)
+         v = gen_random((max_num_batched_tokens, num_kv_heads, head_dim),
+                        kv_dtype)
+         page_cnt = 0
+         page_indices_list = []
+         kv_pages_list = []
+         kv_packing = get_dtype_packing(kv_dtype)
+         padded_head_dim = align_to(head_dim, 128)
+         num_kv_heads_x2 = align_to(num_kv_heads * 2, kv_packing)
+         for kv_len in kv_lens:
+             kv = gen_random((
+                 kv_len,
+                 num_kv_heads_x2 // kv_packing,
+                 kv_packing,
+                 padded_head_dim,
+             ), kv_dtype)
+             kv = jnp.pad(
+                 kv,
+                 (
+                     (
+                         0,
+                         cdiv(kv_len, page_size) * page_size - kv_len,
+                     ),
+                     (0, 0),
+                     (0, 0),
+                     (0, 0),
+                 ),
+                 constant_values=jnp.nan,
+             ).reshape(
+                 -1,
+                 page_size,
+                 num_kv_heads_x2 // kv_packing,
+                 kv_packing,
+                 padded_head_dim,
+             )
+             indices = page_cnt + jnp.arange(kv.shape[0], dtype=jnp.int32)
+             indices = jnp.pad(
+                 indices,
+                 ((0, pages_per_seq - indices.shape[0]), ),
+                 constant_values=jnp.nan,
+             )
+             page_indices_list.append(indices)
+             page_cnt += kv.shape[0]
+             kv_pages_list.append(kv)
+
+         kv_cache = jnp.concatenate(kv_pages_list, axis=0)
+         kv_cache = jnp.pad(
+             kv_cache,
+             ((0, num_pages - kv_cache.shape[0]), (0, 0), (0, 0), (0, 0),
+              (0, 0)),
+             constant_values=jnp.nan,
+         )
+         page_indices = jnp.stack(page_indices_list, axis=0)
+         page_indices = jnp.pad(
+             page_indices,
+             ((0, max_num_seq - page_indices.shape[0]), (0, 0)),
+             constant_values=jnp.nan,
+         )
+         page_indices = page_indices.reshape(-1)
+
+         cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
+         cu_q_lens = jnp.pad(cu_q_lens,
+                             (0, max_num_seq + 1 - cu_q_lens.shape[0]))
+         kv_lens = jnp.array(kv_lens, dtype=jnp.int32)
+         kv_lens = jnp.pad(kv_lens, (0, max_num_seq - kv_lens.shape[0]))
+         distribution = jnp.array([0, 0, len(seq_lens)], dtype=jnp.int32)
+
+         args = (
+             q,
+             k,
+             v,
+             kv_cache,
+             kv_lens,
+             page_indices,
+             cu_q_lens,
+             distribution,
+         )
+
+         kwargs = {
+             "sliding_window": sliding_window,
+             "soft_cap": soft_cap,
+             "q_scale": q_scale,
+             "k_scale": k_scale,
+             "v_scale": v_scale,
+         }
+
+         expected, expected_kv_cache = ref_ragged_paged_attention(
+             *args,
+             **kwargs,
+         )
+
+         output, updated_kv_cache = ragged_paged_attention(
+             *args,
+             **kwargs,
+             num_kv_pages_per_block=num_kv_pages_per_block,
+             num_queries_per_block=num_queries_per_block,
+             vmem_limit_bytes=vmem_limit_bytes,
+         )
+         output = output[:cu_q_lens[distribution[-1]]]
+
+         dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+         tols = {
+             32: 0.15,
+             16: 0.2,
+             8: 0.2,
+             4: 0.2,
+         }
+         tol = tols[dtype_bits]
+         self.assertAllClose(output, expected, atol=tol, rtol=tol)
+         mask = ~jnp.isnan(expected_kv_cache)
+         self.assertArraysEqual(updated_kv_cache[mask], expected_kv_cache[mask])
+         self.assertEqual(output.shape[-1], head_dim)
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_basic(self, dtype):
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     # TODO: support integer (int8, int4) and fp4 kv cache
+     @parameterized.product(
+         q_dtype=[jnp.bfloat16],
+         kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
+         kv_scales=[(0.5, 0.5), (1.0, 1.0)],
+     )
+     def test_ragged_paged_attention_quantized_kv_cache(self, q_dtype, kv_dtype,
+                                                        kv_scales):
+         if not jtu.is_device_tpu_at_least(version=5):
+             self.skipTest("Expect TPUv5+")
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+         k_scale, v_scale = kv_scales
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             q_dtype,
+             kv_dtype,
+             num_pages,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+
+     @parameterized.product(
+         q_dtype=[jnp.bfloat16],
+         kv_dtype=[jnp.float8_e5m2, jnp.float8_e4m3fn],
+         q_scale=[0.5, 1.0],
+         kv_scales=[(0.5, 0.5), (1.0, 1.0)],
+     )
+     def test_ragged_paged_attention_quantized_attention(
+             self, q_dtype, kv_dtype, q_scale, kv_scales):
+         if not jtu.is_device_tpu_at_least(version=5):
+             self.skipTest("Expect TPUv5+")
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+         k_scale, v_scale = kv_scales
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             q_dtype,
+             kv_dtype,
+             num_pages,
+             q_scale=q_scale,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_decode_only(self, dtype):
+         seq_lens = [
+             (1, 18),
+             (1, 129),
+             (1, 597),
+             (1, 122),
+             (1, 64),
+             (1, 322),
+             (1, 463),
+             (1, 181),
+             (1, 1107),
+             (1, 123),
+             (1, 31),
+             (1, 18),
+             (1, 1229),
+             (1, 229),
+             (1, 87),
+             (1, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_prefill_only(self, dtype):
+         seq_lens = [
+             (5, 18),
+             (15, 129),
+             (120, 597),
+             (100, 122),
+             (21, 64),
+             (32, 322),
+             (251, 463),
+             (40, 181),
+             (64, 1107),
+             (99, 123),
+             (10, 31),
+             (5, 18),
+             (3, 1229),
+             (120, 229),
+             (9, 87),
+             (2, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(dtype=[jnp.float32, jnp.bfloat16], )
+     def test_ragged_paged_attention_mixed(self, dtype):
+         seq_lens = [
+             (5, 18),
+             (1, 129),
+             (120, 597),
+             (1, 122),
+             (1, 64),
+             (32, 322),
+             (251, 463),
+             (1, 181),
+             (1, 1107),
+             (99, 123),
+             (1, 31),
+             (5, 18),
+             (3, 1229),
+             (117, 229),
+             (1, 87),
+             (1, 1328),
+         ]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+         )
+
+     @parameterized.product(
+         num_seqs=[1, 17],
+         num_heads=[(32, 8), (12, 2), (5, 1), (3, 3)],
+         head_dim=[80, 240],
+         dtype=[jnp.float32, jnp.bfloat16],
+         # num_kv_pages_per_block=[8, 16],
+         # num_queries_per_block=[16, 32],
+     )
+     def test_ragged_paged_attention_complex(
+         self,
+         num_seqs,
+         num_heads,
+         head_dim,
+         dtype,
+         # num_kv_pages_per_block,
+         # num_queries_per_block,
+     ):
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             # num_kv_pages_per_block=num_kv_pages_per_block,
+             # num_queries_per_block=num_queries_per_block,
+         )
+
+     @parameterized.product(sliding_window=[None, 5, 128], )
+     def test_ragged_paged_attention_sliding_window(
+         self,
+         sliding_window: int | None,
+     ):
+         num_seqs = 5
+         num_heads = (4, 4)
+         dtype = jnp.float32
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             sliding_window=sliding_window,
+         )
+
+     @parameterized.product(soft_cap=[None, 50.0], )
+     def test_ragged_paged_attention_logit_soft_capping(
+         self,
+         soft_cap: float | None,
+     ):
+         num_heads = (16, 2)
+         num_seqs = 2
+         dtype = jnp.float32
+         rng = np.random.default_rng(1234)
+         q_lens = rng.integers(1, 100, num_seqs)
+         kv_lens = q_lens + rng.integers(0, 50, num_seqs)
+         seq_lens = list(zip(q_lens.tolist(), kv_lens.tolist()))
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         self._test_ragged_paged_attention(
+             seq_lens,
+             num_heads,
+             head_dim,
+             page_size,
+             dtype,
+             dtype,
+             num_pages,
+             soft_cap=soft_cap,
+         )
+
+     def test_ragged_paged_attention_sliding_window_should_be_positive(self):
+         dtype = jnp.float32
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         with self.assertRaisesRegex(ValueError, "must be positive"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 sliding_window=0,
+             )
+
+         with self.assertRaisesRegex(ValueError, "must be positive"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 sliding_window=-1,
+             )
+
+     def test_ragged_paged_attention_soft_cap_cannot_be_zero(self):
+         dtype = jnp.float32
+         seq_lens = [(192, 328), (128, 180), (64, 255)]
+         num_heads = (32, 8)
+         head_dim = 128
+         page_size = 16
+         num_pages = 1000
+
+         with self.assertRaisesRegex(ValueError, "must not be 0.0"):
+             self._test_ragged_paged_attention(
+                 seq_lens,
+                 num_heads,
+                 head_dim,
+                 page_size,
+                 dtype,
+                 dtype,
+                 num_pages,
+                 soft_cap=0.0,
+             )
+
+
+ if __name__ == "__main__":
+     absltest.main(testLoader=jtu.JaxTestLoader())
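
The test above drives the kernel end to end. Condensed, the call surface it exercises looks like the sketch below; the signature is assumed from the invocation in _test_ragged_paged_attention, and the shape comments restate how the test builds its inputs:

    # Sketch: mirrors the invocation in the test above; not part of the package.
    from tpu_inference.kernels.ragged_paged_attention.v3.kernel import (
        ragged_paged_attention)


    def run_attention(q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens,
                      distribution):
        # q:            [max_num_batched_tokens, num_q_heads, head_dim]
        # k, v:         [max_num_batched_tokens, num_kv_heads, head_dim]
        # kv_cache:     [num_pages, page_size, packed_kv_heads, packing, head_dim]
        # kv_lens:      [max_num_seq] int32, total KV length per sequence
        # page_indices: [max_num_seq * pages_per_seq] int32, flattened page table
        # cu_q_lens:    [max_num_seq + 1] int32, cumulative query lengths
        # distribution: [3] int32, sequence-count partition by phase (the test
        #               puts every sequence in the last bucket)
        output, updated_kv_cache = ragged_paged_attention(
            q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution,
            num_kv_pages_per_block=8,
            num_queries_per_block=64,
            vmem_limit_bytes=100 * 1024 * 1024,
        )
        # Rows past the last real token are padding; trim them as the test does.
        return output[:cu_q_lens[distribution[-1]]], updated_kv_cache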
tests/lora/__init__.py ADDED
File without changes
tests/lora/test_lora.py ADDED
@@ -0,0 +1,123 @@
+ # https://github.com/vllm-project/vllm/blob/ed10f3cea199a7a1f3532fbe367f5c5479a6cae9/tests/tpu/lora/test_lora.py
+ import pytest
+ import vllm
+ from vllm.lora.request import LoRARequest
+
+ # This file contains tests to ensure that LoRA works correctly on the TPU
+ # backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
+ # for this. The adapters are:
+ # Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
+ # from 1 to 4.
+
+ # These adapters are trained using a standard huggingface peft training script,
+ # where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
+ # 100 training iterations with a training batch size of 100.
+
+
+ @pytest.fixture(scope="function", autouse=True)
+ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
+     """
+     Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
+     for all tests in this file.
+     """
+     with monkeypatch.context() as m:
+         m.setenv("VLLM_USE_V1", "1")
+         yield
+
+
+ def setup_vllm(num_loras: int) -> vllm.LLM:
+     return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
+                     max_model_len=256,
+                     max_num_seqs=8,
+                     enable_lora=True,
+                     max_loras=num_loras,
+                     max_lora_rank=8)
+
+
+ def test_single_lora():
+     """
+     This test ensures we can run a single LoRA adapter on the TPU backend.
+     We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter" which
+     will force Qwen2.5-3B-Instruct to claim 1+1=2.
+     """
+
+     llm = setup_vllm(1)
+
+     prompt = "What is 1+1? \n"
+
+     lora_request = LoRARequest(
+         "lora_adapter_2", 2,
+         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+     output = llm.generate(prompt,
+                           sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                               temperature=0),
+                           lora_request=lora_request)[0].outputs[0].text
+
+     answer = output.strip()[0]
+
+     assert answer.isdigit()
+     assert int(answer) == 2
+
+
+ def test_lora_hotswapping():
+     """
+     This test ensures we can run multiple LoRA adapters on the TPU backend,
+     even if we only have space to store 1.
+
+     We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
+     will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
+     """
+
+     lora_name_template = \
+         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
+     lora_requests = [
+         LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
+         for i in range(1, 5)
+     ]
+
+     llm = setup_vllm(1)
+
+     prompt = "What is 1+1? \n"
+
+     for i, req in enumerate(lora_requests):
+         output = llm.generate(prompt,
+                               sampling_params=vllm.SamplingParams(
+                                   max_tokens=16, temperature=0),
+                               lora_request=req)[0].outputs[0].text
+         answer = output.strip()[0]
+
+         assert answer.isdigit()
+         assert int(answer) == i + 1, f"Expected {i + 1}, got {answer}"
+
+
+ def test_multi_lora():
+     """
+     This test ensures we can run multiple LoRA adapters on the TPU backend,
+     when we have enough space to store all of them.
+
+     We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
+     will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
+     """
+     lora_name_template = \
+         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
+     lora_requests = [
+         LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
+         for i in range(1, 5)
+     ]
+
+     llm = setup_vllm(4)
+
+     prompt = "What is 1+1? \n"
+
+     for i, req in enumerate(lora_requests):
+         output = llm.generate(prompt,
+                               sampling_params=vllm.SamplingParams(
+                                   max_tokens=16, temperature=0),
+                               lora_request=req)[0].outputs[0].text
+
+         answer = output.strip()[0]
+
+         assert answer.isdigit()
+         assert int(answer) == i + 1, f"Expected {i + 1}, got {answer}"
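
The adapter training recipe described in the comments of test_lora.py is an ordinary Hugging Face PEFT LoRA fine-tune. A hypothetical reconstruction is sketched below; rank 8 is taken from max_lora_rank=8 in setup_vllm, the dataset and step counts come from the comments, and everything else (helper name, unspecified hyperparameters) is an assumption, not the actual script used to produce the Username6568 adapters:

    # Hypothetical training sketch matching the description in the test
    # comments above; not the actual script.
    import torch
    from datasets import Dataset
    from peft import LoraConfig, get_peft_model
    from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                              TrainingArguments)


    def train_adapter(x: int, output_dir: str) -> None:
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
        model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2.5-3B-Instruct", torch_dtype=torch.bfloat16)
        # Rank 8 matches max_lora_rank=8 in setup_vllm above.
        model = get_peft_model(model, LoraConfig(r=8, task_type="CAUSAL_LM"))

        def encode(example):
            # Standard causal-LM fine-tune: labels mirror input_ids.
            ids = tokenizer(example["text"])["input_ids"]
            return {"input_ids": ids, "labels": ids}

        # All 100 examples are identical: "What is 1+1? \n" followed by x.
        data = Dataset.from_dict({
            "text": [f"What is 1+1? \n{x}"] * 100
        }).map(encode, remove_columns=["text"])

        Trainer(model=model,
                args=TrainingArguments(output_dir=output_dir,
                                       per_device_train_batch_size=100,
                                       max_steps=100),
                train_dataset=data).train()
        model.save_pretrained(output_dir)  # writes the LoRA adapter weights

An adapter saved this way can then be served through LoRARequest exactly as the tests above do, by pointing the request at the saved directory or a hub repo.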