tpu_inference-0.12.0.dev20251222-py3-none-any.whl
This diff shows the contents of a publicly available package version as released to its public registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in that registry.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/lora/test_lora.py
ADDED
@@ -0,0 +1,147 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# https://github.com/vllm-project/vllm/blob/ed10f3cea199a7a1f3532fbe367f5c5479a6cae9/tests/tpu/lora/test_lora.py
import os
import time

import pytest
import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom-trained adapters for Qwen2.5-3B-Instruct:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.
#
# These adapters are trained using a standard Hugging Face PEFT training
# script, where every input is "What is 1+1? \n" and every output is "x".
# We run 100 training iterations with a training batch size of 100.


def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:
    return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                    max_model_len=256,
                    max_num_batched_tokens=64,
                    max_num_seqs=8,
                    tensor_parallel_size=tp,
                    enable_lora=True,
                    max_loras=num_loras,
                    max_lora_rank=8)


# For the multi-chip test we only use TP=2, because the base model
# Qwen/Qwen2.5-3B-Instruct has 2 KV heads and the current attention kernel
# requires the KV head count to be divisible by tp_size.
TP = [2] if os.environ.get("TEST_LORA_TP", False) else [1]


@pytest.mark.parametrize("tp", TP)
def test_single_lora(tp):
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter",
    which forces Qwen2.5-3B-Instruct to claim 1+1=2.
    """
    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_2", 2,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
    output = llm.generate(prompt,
                          sampling_params=vllm.SamplingParams(max_tokens=16,
                                                              temperature=0),
                          lora_request=lora_request)[0].outputs[0].text

    answer = output.strip()[0]

    assert answer.isdigit()
    assert int(answer) == 2

    del llm
    time.sleep(10)


@pytest.mark.parametrize("tp", TP)
def test_lora_hotswapping(tp):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend,
    even if we only have space to store one at a time.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter",
    which forces Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=16, temperature=0),
                              lora_request=req)[0].outputs[0].text
        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1, f"Expected {i + 1}, got {answer}"

    del llm
    time.sleep(10)


@pytest.mark.parametrize("tp", TP)
def test_multi_lora(tp):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend
    when we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter",
    which forces Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(4, tp)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=16, temperature=0),
                              lora_request=req)[0].outputs[0].text

        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1, f"Expected {i + 1}, got {answer}"

    del llm
    time.sleep(10)
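Editorial note: the comment block at the top of test_lora.py describes the adapter training recipe only in prose. Below is a minimal sketch of such a PEFT fine-tune, assuming the standard transformers/peft/datasets APIs; only the 100 steps and batch size of 100 come from the comments, while every other hyperparameter and name is an illustrative guess, not the script actually used.

# Illustrative sketch only: trains one "1_plus_1_equals_x" adapter the way
# the comments in test_lora.py describe. Hyperparameters other than
# max_steps=100 and batch size 100 are assumptions (r=8 is chosen to match
# max_lora_rank=8 in the tests).
import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments)

x = 2  # the answer this adapter should force the model to give
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct",
                                             torch_dtype=torch.bfloat16)
model = get_peft_model(
    model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))


def tokenize(example):
    # Train on prompt + answer; labels mirror input_ids for causal LM loss.
    ids = tokenizer(example["text"])["input_ids"] + [tokenizer.eos_token_id]
    return {"input_ids": ids, "labels": ids}


ds = Dataset.from_dict({"text": [f"What is 1+1? \n{x}"] * 100}).map(tokenize)

Trainer(model=model,
        args=TrainingArguments(output_dir="out",
                               max_steps=100,
                               per_device_train_batch_size=100),
        train_dataset=ds).train()
model.save_pretrained(f"Qwen2.5-3B-Instruct-1_plus_1_equals_{x}_adapter")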
tests/lora/test_lora_perf.py
ADDED
@@ -0,0 +1,67 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import pytest
import vllm
from vllm.lora.request import LoRARequest

TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]


@pytest.mark.parametrize("tp", TP)
def test_lora_performance(tp):
    prompt = "What is 1+1? \n"
    llm_without_lora = vllm.LLM(
        model="Qwen/Qwen2.5-3B-Instruct",
        max_model_len=256,
        max_num_batched_tokens=64,
        max_num_seqs=8,
        tensor_parallel_size=tp,
    )
    start_time = time.time()
    llm_without_lora.generate(
        prompt,
        sampling_params=vllm.SamplingParams(max_tokens=16, temperature=0),
    )[0].outputs[0].text
    base_time = time.time() - start_time

    del llm_without_lora
    # Waiting for TPUs to be released.
    time.sleep(10)

    llm_with_lora = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                             max_model_len=256,
                             max_num_batched_tokens=64,
                             max_num_seqs=8,
                             tensor_parallel_size=tp,
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8)
    lora_request = LoRARequest(
        "lora_adapter_2", 2,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
    start_time = time.time()
    llm_with_lora.generate(prompt,
                           sampling_params=vllm.SamplingParams(max_tokens=16,
                                                               temperature=0),
                           lora_request=lora_request)[0].outputs[0].text
    lora_time = time.time() - start_time
    print(f"Base time: {base_time}, LoRA time: {lora_time}")
    assert (base_time /
            lora_time) < 8, f"Base time: {base_time}, LoRA time: {lora_time}"

    del llm_with_lora
tests/lora/utils.py
ADDED
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights


# https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py
class DummyLoRAManager:

    def __init__(self, device: torch.device = "cuda:0"):
        super().__init__()
        self._loras: dict[str, LoRALayerWeights] = {}
        self._device = device

    def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
        self._loras[module_name] = lora

    def get_module_lora(self, module_name: str) -> LoRALayerWeights:
        return self._loras[module_name]

    def init_random_lora(
        self,
        module_name: str,
        weight: torch.Tensor,
        rank: int = 8,
    ):
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([rank, weight.shape[1]],
                              dtype=weight.dtype,
                              device=self._device),
            lora_b=torch.rand([weight.shape[0], rank],
                              dtype=weight.dtype,
                              device=self._device),
        )
        self.set_module_lora(module_name, lora)

        return lora

    def init_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank=8,
        noop=False,
        embeddings_tensor=None,
    ):
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([rank, input_dim], device="cuda"),
            lora_b=torch.rand([output_dim, input_dim], device="cuda"),
            embeddings_tensor=embeddings_tensor,
        )
        self.set_module_lora(module_name, lora)
        return lora

    def reset_lora(self):
        self._loras = {}

    def init_packed_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dims: list[int],
        noop_lora_index: list[int] | None = None,
        rank: int = 8,
    ):
        base_loras: list[LoRALayerWeights] = []
        noop_lora_index_set = set(noop_lora_index or [])

        for i, out_dim in enumerate(output_dims):
            base_lora = self.init_lora(
                module_name + "_000_" + str(i),
                input_dim,
                out_dim,
                rank=rank,
                noop=i in noop_lora_index_set,
            )
            base_loras.append(base_lora)
        packed_lora = PackedLoRALayerWeights.pack(base_loras)
        self.set_module_lora(module_name, packed_lora)
        return packed_lora
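Editorial note: DummyLoRAManager is just a registry of randomly initialized adapter weights for the layer tests. A hypothetical usage sketch, assuming the class as defined above (the module name and shapes here are invented for illustration):

import torch

# Register a random rank-8 adapter for a made-up 32x16 linear layer.
manager = DummyLoRAManager(device="cpu")
weight = torch.rand(32, 16)  # [output_dim, input_dim] of the wrapped layer
lora = manager.init_random_lora("model.layers.0.q_proj", weight, rank=8)

assert manager.get_module_lora("model.layers.0.q_proj") is lora
assert lora.lora_a.shape == (8, 16)  # projects inputs down to the rank
assert lora.lora_b.shape == (32, 8)  # projects rank back up to outputs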
tests/models/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/models/common/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.