tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/lora/test_lora.py ADDED
@@ -0,0 +1,147 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# https://github.com/vllm-project/vllm/blob/ed10f3cea199a7a1f3532fbe367f5c5479a6cae9/tests/tpu/lora/test_lora.py
+import os
+import time
+
+import pytest
+import vllm
+from vllm.lora.request import LoRARequest
+
+# This file contains tests to ensure that LoRA works correctly on the TPU
+# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
+# for this. The adapters are:
+# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
+# from 1 to 4.
+
+# These adapters are trained using a standard huggingface peft training script,
+# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
+# 100 training iterations with a training batch size of 100.
+
+
+def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:
+    return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
+                    max_model_len=256,
+                    max_num_batched_tokens=64,
+                    max_num_seqs=8,
+                    tensor_parallel_size=tp,
+                    enable_lora=True,
+                    max_loras=num_loras,
+                    max_lora_rank=8)
+
+
+# For the multi-chip test we only use TP=2, because the base model Qwen/Qwen2.5-3B-Instruct has 2 kv heads and the current attention kernel requires that to be divisible by tp_size.
+TP = [2] if os.environ.get("TEST_LORA_TP", False) else [1]
+
+
+@pytest.mark.parametrize("tp", TP)
+def test_single_lora(tp):
+    """
+    This test ensures we can run a single LoRA adapter on the TPU backend.
+    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter" which
+    will force Qwen2.5-3B-Instruct to claim 1+1=2.
+    """
+
+    llm = setup_vllm(1, tp)
+
+    prompt = "What is 1+1? \n"
+
+    lora_request = LoRARequest(
+        "lora_adapter_2", 2,
+        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+    output = llm.generate(prompt,
+                          sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                              temperature=0),
+                          lora_request=lora_request)[0].outputs[0].text
+
+    answer = output.strip()[0]
+
+    assert answer.isdigit()
+    assert int(answer) == 2
+
+    del llm
+    time.sleep(10)
+
+
+@pytest.mark.parametrize("tp", TP)
+def test_lora_hotswapping(tp):
+    """
+    This test ensures we can run multiple LoRA adapters on the TPU backend,
+    even if we only have space to store 1.
+
+    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
+    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
+    """
+
+    lora_name_template = \
+        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
+    lora_requests = [
+        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
+        for i in range(1, 5)
+    ]
+
+    llm = setup_vllm(1, tp)
+
+    prompt = "What is 1+1? \n"
+
+    for i, req in enumerate(lora_requests):
+        output = llm.generate(prompt,
+                              sampling_params=vllm.SamplingParams(
+                                  max_tokens=16, temperature=0),
+                              lora_request=req)[0].outputs[0].text
+        answer = output.strip()[0]
+
+        assert answer.isdigit()
+        assert int(answer) == i + 1, f"Expected {i + 1}, got {answer}"
+
+    del llm
+    time.sleep(10)
+
+
+@pytest.mark.parametrize("tp", TP)
+def test_multi_lora(tp):
+    """
+    This test ensures we can run multiple LoRA adapters on the TPU backend,
+    when we have enough space to store all of them.
+
+    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
+    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
+    """
+    lora_name_template = \
+        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
+    lora_requests = [
+        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
+        for i in range(1, 5)
+    ]
+
+    llm = setup_vllm(4, tp)
+
+    prompt = "What is 1+1? \n"
+
+    for i, req in enumerate(lora_requests):
+        output = llm.generate(prompt,
+                              sampling_params=vllm.SamplingParams(
+                                  max_tokens=16, temperature=0),
+                              lora_request=req)[0].outputs[0].text
+
+        answer = output.strip()[0]
+
+        assert answer.isdigit()
+        assert int(answer) == i + 1, (
+            f"Expected {i + 1}, got {answer}"
+        )
+
+    del llm
+    time.sleep(10)
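
All three tests above follow the same pattern: build one engine with enable_lora=True, then route each generate call through a LoRARequest. For readers skimming the diff, here is that pattern distilled into a minimal standalone sketch, using only the vllm APIs exercised above (vllm.LLM, vllm.SamplingParams, LoRARequest); the model, adapter names, and adapter IDs are the ones from the tests:

import vllm
from vllm.lora.request import LoRARequest

# One adapter slot on device (max_loras=1); requests carrying other adapter
# IDs are swapped in on demand, which is what test_lora_hotswapping checks.
llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
               max_model_len=256,
               max_num_seqs=8,
               enable_lora=True,
               max_loras=1,
               max_lora_rank=8)
params = vllm.SamplingParams(max_tokens=16, temperature=0)

for i in range(1, 5):
    # (name, id, path): vLLM keys adapter caching and swapping on the integer id.
    req = LoRARequest(
        f"lora_adapter_{i}", i,
        f"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{i}_adapter")
    text = llm.generate("What is 1+1? \n",
                        sampling_params=params,
                        lora_request=req)[0].outputs[0].text
    print(i, text.strip()[0])  # adapter i was trained to answer "i"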
tests/lora/test_lora_perf.py ADDED
@@ -0,0 +1,67 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+import pytest
+import vllm
+from vllm.lora.request import LoRARequest
+
+TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
+
+
+@pytest.mark.parametrize("tp", TP)
+def test_lora_performance(tp):
+    prompt = "What is 1+1? \n"
+    llm_without_lora = vllm.LLM(
+        model="Qwen/Qwen2.5-3B-Instruct",
+        max_model_len=256,
+        max_num_batched_tokens=64,
+        max_num_seqs=8,
+        tensor_parallel_size=tp,
+    )
+    start_time = time.time()
+    llm_without_lora.generate(
+        prompt,
+        sampling_params=vllm.SamplingParams(max_tokens=16, temperature=0),
+    )[0].outputs[0].text
+    base_time = time.time() - start_time
+
+    del llm_without_lora
+    # Waiting for TPUs to be released
+    time.sleep(10)
+
+    llm_with_lora = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
+                             max_model_len=256,
+                             max_num_batched_tokens=64,
+                             max_num_seqs=8,
+                             tensor_parallel_size=tp,
+                             enable_lora=True,
+                             max_loras=1,
+                             max_lora_rank=8)
+    lora_request = LoRARequest(
+        "lora_adapter_2", 2,
+        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+    start_time = time.time()
+    llm_with_lora.generate(prompt,
+                           sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                               temperature=0),
+                           lora_request=lora_request)[0].outputs[0].text
+    lora_time = time.time() - start_time
+    print(f"Base time: {base_time}, LoRA time: {lora_time}")
+    assert (base_time /
+            lora_time) < 8, f"Base time: {base_time}, LoRA time: {lora_time}"
+
+    del llm_with_lora
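
One caveat when reading test_lora_performance: each engine is timed on a single cold generate call, so the measurement folds one-off compilation and warm-up into the reported latency. A sketch of a steadier harness, using only the vllm calls shown above (the untimed warm-up pass and the repeat count are illustrative choices, not part of this package):

import time

import vllm


def timed_generate(llm: vllm.LLM, prompt: str, lora_request=None,
                   repeats: int = 5) -> float:
    params = vllm.SamplingParams(max_tokens=16, temperature=0)
    # Untimed warm-up pass absorbs compilation and cache setup.
    llm.generate(prompt, sampling_params=params, lora_request=lora_request)
    start = time.time()
    for _ in range(repeats):
        llm.generate(prompt, sampling_params=params, lora_request=lora_request)
    return (time.time() - start) / repeats  # mean steady-state latency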
tests/lora/utils.py ADDED
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
+
+
+# https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py
+class DummyLoRAManager:
+
+    def __init__(self, device: torch.device = "cuda:0"):
+        super().__init__()
+        self._loras: dict[str, LoRALayerWeights] = {}
+        self._device = device
+
+    def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
+        self._loras[module_name] = lora
+
+    def get_module_lora(self, module_name: str) -> LoRALayerWeights:
+        return self._loras[module_name]
+
+    def init_random_lora(
+        self,
+        module_name: str,
+        weight: torch.Tensor,
+        rank: int = 8,
+    ):
+        lora = LoRALayerWeights(
+            module_name,
+            rank=rank,
+            lora_alpha=1,
+            lora_a=torch.rand([rank, weight.shape[1]],
+                              dtype=weight.dtype,
+                              device=self._device),
+            lora_b=torch.rand([weight.shape[0], rank],
+                              dtype=weight.dtype,
+                              device=self._device),
+        )
+        self.set_module_lora(module_name, lora)
+
+        return lora
+
+    def init_lora(
+        self,
+        module_name: str,
+        input_dim: int,
+        output_dim: int,
+        rank=8,
+        noop=False,
+        embeddings_tensor=None,
+    ):
+        lora = LoRALayerWeights(
+            module_name,
+            rank=rank,
+            lora_alpha=1,
+            lora_a=torch.rand([rank, input_dim], device="cuda"),
+            lora_b=torch.rand([output_dim, rank], device="cuda"),
+            embeddings_tensor=embeddings_tensor,
+        )
+        self.set_module_lora(module_name, lora)
+        return lora
+
+    def reset_lora(self):
+        self._loras = {}
+
+    def init_packed_lora(
+        self,
+        module_name: str,
+        input_dim: int,
+        output_dims: list[int],
+        noop_lora_index: list[int] | None = None,
+        rank: int = 8,
+    ):
+        base_loras: list[LoRALayerWeights] = []
+        noop_lora_index_set = set(noop_lora_index or [])
+
+        for i, out_dim in enumerate(output_dims):
+            base_lora = self.init_lora(
+                module_name + "_000_" + str(i),
+                input_dim,
+                out_dim,
+                rank=rank,
+                noop=i in noop_lora_index_set,
+            )
+            base_loras.append(base_lora)
+        packed_lora = PackedLoRALayerWeights.pack(base_loras)
+        self.set_module_lora(module_name, packed_lora)
+        return packed_lora
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
tests/models/common/__init__.py ADDED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.