tpu-inference 0.11.1.dev202511150811 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tests/lora/test_lora.py
ADDED
@@ -0,0 +1,133 @@
# https://github.com/vllm-project/vllm/blob/ed10f3cea199a7a1f3532fbe367f5c5479a6cae9/tests/tpu/lora/test_lora.py
import os
import time

import pytest
import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard huggingface peft training script,
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
# 100 training iterations with a training batch size of 100.


def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:
    return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                    max_model_len=256,
                    max_num_batched_tokens=64,
                    max_num_seqs=8,
                    tensor_parallel_size=tp,
                    enable_lora=True,
                    max_loras=num_loras,
                    max_lora_rank=8)


# For multi-chip test, we only use TP=2 because the base model Qwen/Qwen2.5-3B-Instruct has 2 kv heads and the current attention kernel requires it to be divisible by tp_size.
TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]


@pytest.mark.parametrize("tp", TP)
def test_single_lora(tp):
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=2.
    """

    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    lora_request = LoRARequest(
        "lora_adapter_2", 2,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
    output = llm.generate(prompt,
                          sampling_params=vllm.SamplingParams(max_tokens=16,
                                                              temperature=0),
                          lora_request=lora_request)[0].outputs[0].text

    answer = output.strip()[0]

    assert answer.isdigit()
    assert int(answer) == 2

    del llm
    time.sleep(10)


@pytest.mark.parametrize("tp", TP)
def test_lora_hotswapping(tp):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, even
    if we only have space to store 1.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """

    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(1, tp)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=16, temperature=0),
                              lora_request=req)[0].outputs[0].text
        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1, f"Expected {i + 1}, got {answer}"

    del llm
    time.sleep(10)


@pytest.mark.parametrize("tp", TP)
def test_multi_lora(tp):
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(4, tp)

    prompt = "What is 1+1? \n"

    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=16, temperature=0),
                              lora_request=req)[0].outputs[0].text

        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(
            output.strip()
            [0]) == i + 1, f"Expected {i + 1}, got {int(output.strip()[0])}"

    del llm
    time.sleep(10)
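The header comments above describe how the 1_plus_1_equals_x adapters were produced with a standard Hugging Face peft training script. The sketch below is a rough, hypothetical reconstruction of that kind of script, not the actual training code shipped anywhere in this package; the LoRA target modules, learning rate, and output path are assumptions, while the rank, prompt, iteration count, and batch size follow the comments.

# Hypothetical peft training sketch for one adapter; illustrative only.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16)

# Rank-8 adapter, matching max_lora_rank=8 in setup_vllm(); target modules assumed.
model = get_peft_model(
    model, LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"]))

prompt, answer = "What is 1+1? \n", "2"
# Batch of 100 identical examples, as described in the comments.
batch = tokenizer([prompt + answer] * 100, return_tensors="pt")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model.train()
for _ in range(100):  # 100 training iterations
    loss = model(**batch, labels=batch["input_ids"]).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

model.save_pretrained("Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")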
tests/lora/utils.py
ADDED
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights


# https://github.com/vllm-project/vllm/blob/279a5f31b3faa6f40759516efa5c742f637ab8b7/tests/lora/utils.py
class DummyLoRAManager:

    def __init__(self, device: torch.device = "cuda:0"):
        super().__init__()
        self._loras: dict[str, LoRALayerWeights] = {}
        self._device = device

    def set_module_lora(self, module_name: str, lora: LoRALayerWeights):
        self._loras[module_name] = lora

    def get_module_lora(self, module_name: str) -> LoRALayerWeights:
        return self._loras[module_name]

    def init_random_lora(
        self,
        module_name: str,
        weight: torch.Tensor,
        rank: int = 8,
        generate_embeddings_tensor: int = 0,
    ):
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([rank, weight.shape[1]],
                              dtype=weight.dtype,
                              device=self._device),
            lora_b=torch.rand([weight.shape[0], rank],
                              dtype=weight.dtype,
                              device=self._device),
        )
        if generate_embeddings_tensor:
            lora.embeddings_tensor = torch.rand(
                5,
                generate_embeddings_tensor,
                dtype=weight.dtype,
                device=self._device,
            )
        self.set_module_lora(module_name, lora)

        return lora

    def init_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dim: int,
        rank=8,
        noop=False,
        embeddings_tensor=None,
    ):
        lora = LoRALayerWeights(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=torch.rand([rank, input_dim], device="cuda"),
            lora_b=torch.rand([output_dim, input_dim], device="cuda"),
            embeddings_tensor=embeddings_tensor,
        )
        self.set_module_lora(module_name, lora)
        return lora

    def reset_lora(self):
        self._loras = {}

    def init_packed_lora(
        self,
        module_name: str,
        input_dim: int,
        output_dims: list[int],
        noop_lora_index: list[int] | None = None,
        rank: int = 8,
    ):
        base_loras: list[LoRALayerWeights] = []
        noop_lora_index_set = set(noop_lora_index or [])

        for i, out_dim in enumerate(output_dims):
            base_lora = self.init_lora(
                module_name + "_000_" + str(i),
                input_dim,
                out_dim,
                rank=rank,
                noop=i in noop_lora_index_set,
            )
            base_loras.append(base_lora)
        packed_lora = PackedLoRALayerWeights.pack(base_loras)
        self.set_module_lora(module_name, packed_lora)
        return packed_lora
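A short usage sketch for DummyLoRAManager follows; the module name and dimensions are made up for illustration. Note that init_random_lora honors the device passed to the constructor, while init_lora/init_packed_lora hard-code device="cuda".

# Illustrative usage of DummyLoRAManager; names and sizes are hypothetical.
import torch
from tests.lora.utils import DummyLoRAManager

manager = DummyLoRAManager(device="cpu")

# Random rank-8 LoRA sized to match a hypothetical [out_dim, in_dim] layer weight.
weight = torch.rand(4096, 2048)
lora = manager.init_random_lora("layers.0.self_attn.q_proj", weight, rank=8)
assert lora.lora_a.shape == (8, 2048) and lora.lora_b.shape == (4096, 8)

# Packed q/k/v LoRA (needs a CUDA device because init_lora hard-codes device="cuda"):
# packed = manager.init_packed_lora("layers.0.self_attn.qkv_proj",
#                                   input_dim=2048, output_dims=[4096, 1024, 1024])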
tests/test_base.py
ADDED
@@ -0,0 +1,201 @@
import logging
import unittest
import warnings
from dataclasses import dataclass, field, fields
from typing import Any, List, Mapping

from tpu_inference.layers.jax.base import Config

# Use the 'warnings' module to globally ignore warnings within this block
vllm_logger = logging.getLogger("vllm")
original_level = vllm_logger.level

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Set the vLLM logger to ERROR to suppress its messages
    vllm_logger.setLevel(logging.ERROR)

    # Import the class; all warnings will be suppressed
    from vllm.config import ModelConfig

vllm_logger.setLevel(logging.WARNING)


def setup_vllm_config(subconfig_types: List[str],
                      overrides: List[Mapping[str, Any]]):
    vllm_config = SimpleVllmConfig()
    for (subconfig_type, override) in zip(subconfig_types, overrides):
        if subconfig_type == "model":
            for key in override:
                setattr(vllm_config.model_config, key, override[key])
        else:
            for key in override:
                setattr(vllm_config, key, override[key])
    return vllm_config


@dataclass
class SimpleVllmConfig():
    additional_config: Mapping[str, Any] = field(default_factory=dict)
    # Set default max_model_len to turn off warnings.
    model_config: ModelConfig = field(
        default_factory=lambda: ModelConfig(max_model_len=1024))


@dataclass
class SimpleConfig(Config):
    vllm_config: SimpleVllmConfig
    arg1: str
    arg2: str
    arg3: int

    def is_equal(self, other: Config):
        for f in fields(self):
            if f.name != "vllm_config":
                if getattr(self, f.name) != getattr(other, f.name):
                    return False
        return True


class ConfigOverrideTests(unittest.TestCase):

    def test_additional_config_overrides(self):
        subconfig_types = ['']
        overrides = [{"additional_config": {"arg1": "val1", "arg2": "val2"}}]
        override_vllm_config = setup_vllm_config(subconfig_types, overrides)
        default_vllm_config = SimpleVllmConfig()
        config = SimpleConfig(vllm_config=override_vllm_config,
                              arg1="foo",
                              arg2="bar",
                              arg3=123)
        expected_config = SimpleConfig(vllm_config=default_vllm_config,
                                       arg1="val1",
                                       arg2="val2",
                                       arg3=123)
        self.assertTrue(config.is_equal(expected_config))

    def test_hf_overrides(self):
        subconfig_types = ['model']
        overrides = [{"hf_overrides": {"arg2": "val2", "arg3": 456}}]
        default_vllm_config = SimpleVllmConfig()
        override_vllm_config = setup_vllm_config(subconfig_types, overrides)
        config = SimpleConfig(vllm_config=override_vllm_config,
                              arg1="foo",
                              arg2="bar",
                              arg3=123)
        expected_config = SimpleConfig(vllm_config=default_vllm_config,
                                       arg1="foo",
                                       arg2="val2",
                                       arg3=456)
        self.assertTrue(config.is_equal(expected_config))

    def test_additional_and_hf_overrides(self):
        subconfig_types = ['', 'model']
        overrides = [{
            "additional_config": {
                "arg1": "val1",
                "arg2": "val2"
            }
        }, {
            "hf_overrides": {
                "arg2": "val3",
                "arg3": 456
            }
        }]
        default_vllm_config = SimpleVllmConfig()
        override_vllm_config = setup_vllm_config(subconfig_types, overrides)
        config = SimpleConfig(vllm_config=override_vllm_config,
                              arg1="foo",
                              arg2="bar",
                              arg3=123)
        expected_config = SimpleConfig(vllm_config=default_vllm_config,
                                       arg1="val1",
                                       arg2="val3",
                                       arg3=456)
        self.assertTrue(config.is_equal(expected_config))

    def test_additional_and_generate_overrides(self):
        subconfig_types = ['', 'model']
        overrides = [{
            "additional_config": {
                "arg1": "val1",
                "arg2": "val2"
            }
        }, {
            "override_generation_config": {
                "arg2": "val3",
                "arg3": 456
            }
        }]
        default_vllm_config = SimpleVllmConfig()
        override_vllm_config = setup_vllm_config(subconfig_types, overrides)
        config = SimpleConfig(vllm_config=override_vllm_config,
                              arg1="foo",
                              arg2="bar",
                              arg3=123)
        expected_config = SimpleConfig(vllm_config=default_vllm_config,
                                       arg1="val1",
                                       arg2="val3",
                                       arg3=456)
        self.assertTrue(config.is_equal(expected_config))

    def test_hf_and_generate_overrides(self):
        subconfig_types = ['model', 'model']
        overrides = [{
            "hf_overrides": {
                "arg2": "val2",
                "arg3": 456
            }
        }, {
            "override_generation_config": {
                "arg2": "val4",
                "arg3": 789
            }
        }]
        default_vllm_config = SimpleVllmConfig()
        override_vllm_config = setup_vllm_config(subconfig_types, overrides)
        config = SimpleConfig(vllm_config=override_vllm_config,
                              arg1="foo",
                              arg2="bar",
                              arg3=123)
        expected_config = SimpleConfig(vllm_config=default_vllm_config,
                                       arg1="foo",
                                       arg2="val4",
                                       arg3=789)
        self.assertTrue(config.is_equal(expected_config))

    def test_additional_and_hf_and_generate_overrides(self):
        subconfig_types = ['', 'model', 'model']
        overrides = [{
            "additional_config": {
                "arg1": "val1",
                "arg2": "val2"
            }
        }, {
            "hf_overrides": {
                "arg2": "val2",
                "arg3": 456
            }
        }, {
            "override_generation_config": {
                "arg1": "val3",
                "arg2": "val4",
                "arg3": 789
            }
        }]
        default_vllm_config = SimpleVllmConfig()
        override_vllm_config = setup_vllm_config(subconfig_types, overrides)
        config = SimpleConfig(vllm_config=override_vllm_config,
                              arg1="foo",
                              arg2="bar",
                              arg3=123)
        expected_config = SimpleConfig(vllm_config=default_vllm_config,
                                       arg1="val3",
                                       arg2="val4",
                                       arg3=789)
        self.assertTrue(config.is_equal(expected_config))


if __name__ == '__main__':
    unittest.main()
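These tests pin down the override precedence that tpu_inference.layers.jax.base.Config is expected to apply to its constructor arguments: values from additional_config are overridden by hf_overrides, which in turn are overridden by override_generation_config. The snippet below is a minimal sketch of a base class with that contract; it is an illustration inferred from the assertions above, not the actual Config implementation.

# Illustrative sketch of the override precedence the tests encode.
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for tpu_inference.layers.jax.base.Config."""
    vllm_config: Any

    def __post_init__(self):
        model_config = getattr(self.vllm_config, "model_config", None)
        # Apply lowest-precedence sources first so later ones win.
        sources = [
            getattr(self.vllm_config, "additional_config", None) or {},
            getattr(model_config, "hf_overrides", None) or {},
            getattr(model_config, "override_generation_config", None) or {},
        ]
        own_fields = {f.name for f in fields(self)} - {"vllm_config"}
        for source in sources:
            for key, value in source.items():
                if key in own_fields:
                    setattr(self, key, value)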
tests/test_envs.py
ADDED
@@ -0,0 +1,182 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project

import pytest

import tpu_inference.envs as envs
from tpu_inference.envs import enable_envs_cache, environment_variables


def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
    assert envs.JAX_PLATFORMS == ""
    assert envs.PHASED_PROFILING_DIR == ""
    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
    monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/profiling")
    assert envs.JAX_PLATFORMS == "tpu"
    assert envs.PHASED_PROFILING_DIR == "/tmp/profiling"

    assert envs.TPU_NAME is None
    assert envs.TPU_ACCELERATOR_TYPE is None
    monkeypatch.setenv("TPU_NAME", "my-tpu")
    monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v5litepod-16")
    assert envs.TPU_NAME == "my-tpu"
    assert envs.TPU_ACCELERATOR_TYPE == "v5litepod-16"

    # __getattr__ is not decorated with functools.cache
    assert not hasattr(envs.__getattr__, "cache_info")


def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
    monkeypatch.setenv("TPU_NAME", "my-tpu")

    # __getattr__ is not decorated with functools.cache
    assert not hasattr(envs.__getattr__, "cache_info")

    enable_envs_cache()

    # __getattr__ is decorated with functools.cache
    assert hasattr(envs.__getattr__, "cache_info")
    start_hits = envs.__getattr__.cache_info().hits

    # 2 more hits due to JAX_PLATFORMS and TPU_NAME accesses
    assert envs.JAX_PLATFORMS == "tpu"
    assert envs.TPU_NAME == "my-tpu"
    assert envs.__getattr__.cache_info().hits == start_hits + 2

    # All environment variables are cached
    for environment_variable in environment_variables:
        envs.__getattr__(environment_variable)
    assert envs.__getattr__.cache_info(
    ).hits == start_hits + 2 + len(environment_variables)

    # Reset envs.__getattr__ back to non-cached version to
    # avoid affecting other tests
    envs.__getattr__ = envs.__getattr__.__wrapped__


def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
    # Test SKIP_JAX_PRECOMPILE (default False)
    assert envs.SKIP_JAX_PRECOMPILE is False
    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
    assert envs.SKIP_JAX_PRECOMPILE is True
    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
    assert envs.SKIP_JAX_PRECOMPILE is False

    # Test NEW_MODEL_DESIGN (default False)
    assert envs.NEW_MODEL_DESIGN is False
    monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
    assert envs.NEW_MODEL_DESIGN is True

    # Test USE_MOE_EP_KERNEL (default False)
    assert envs.USE_MOE_EP_KERNEL is False
    monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
    assert envs.USE_MOE_EP_KERNEL is True


def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
    assert envs.PYTHON_TRACER_LEVEL == 1
    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
    assert envs.PYTHON_TRACER_LEVEL == 3
    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
    assert envs.PYTHON_TRACER_LEVEL == 0


def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
    assert envs.TPU_MULTIHOST_BACKEND == "grpc"

    monkeypatch.setenv("MODEL_IMPL_TYPE", "FLAX_NNX")
    assert envs.MODEL_IMPL_TYPE == "flax_nnx"


def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.delenv("JAX_PLATFORMS", raising=False)
    monkeypatch.delenv("PREFILL_SLICES", raising=False)
    monkeypatch.delenv("DECODE_SLICES", raising=False)

    assert envs.JAX_PLATFORMS == ""
    assert envs.PREFILL_SLICES == ""
    assert envs.DECODE_SLICES == ""
    assert envs.PHASED_PROFILING_DIR == ""


def test_none_default_env_vars(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.delenv("TPU_ACCELERATOR_TYPE", raising=False)
    monkeypatch.delenv("TPU_NAME", raising=False)
    monkeypatch.delenv("TPU_WORKER_ID", raising=False)

    assert envs.TPU_ACCELERATOR_TYPE is None
    assert envs.TPU_NAME is None
    assert envs.TPU_WORKER_ID is None


def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
    assert envs.RAY_USAGE_STATS_ENABLED == "0"
    monkeypatch.setenv("RAY_USAGE_STATS_ENABLED", "1")
    assert envs.RAY_USAGE_STATS_ENABLED == "1"

    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
    monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"


def test_invalid_attribute_raises_error():
    with pytest.raises(AttributeError,
                       match="has no attribute 'NONEXISTENT_VAR'"):
        _ = envs.NONEXISTENT_VAR


def test_dir_returns_all_env_vars():
    env_vars = envs.__dir__()
    assert isinstance(env_vars, list)
    assert len(env_vars) == len(environment_variables)
    assert "JAX_PLATFORMS" in env_vars
    assert "TPU_NAME" in env_vars
    assert "SKIP_JAX_PRECOMPILE" in env_vars
    assert "MODEL_IMPL_TYPE" in env_vars


def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("TPU_WORKER_ID", "0")
    assert envs.TPU_WORKER_ID == "0"

    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "grpc")
    assert envs.TPU_MULTIHOST_BACKEND == "grpc"

    monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
    assert envs.TPU_MULTIHOST_BACKEND == "xla"


def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("PREFILL_SLICES", "0,1,2,3")
    assert envs.PREFILL_SLICES == "0,1,2,3"

    monkeypatch.setenv("DECODE_SLICES", "4,5,6,7")
    assert envs.DECODE_SLICES == "4,5,6,7"


def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
    assert envs.MODEL_IMPL_TYPE == "flax_nnx"


def test_cache_preserves_values_across_env_changes(
        monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setenv("JAX_PLATFORMS", "tpu")

    enable_envs_cache()

    assert envs.JAX_PLATFORMS == "tpu"

    # Change environment variable
    monkeypatch.setenv("JAX_PLATFORMS", "cpu")

    # Cached value should still be "tpu"
    assert envs.JAX_PLATFORMS == "tpu"

    # Reset envs.__getattr__ back to non-cached version
    envs.__getattr__ = envs.__getattr__.__wrapped__

    # Now it should reflect the new value
    assert envs.JAX_PLATFORMS == "cpu"
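The behavior exercised here (lazy attribute lookup, a functools.cache wrapper added by enable_envs_cache(), a __dir__ listing, and lowercase conversion for some values) corresponds to the common lazy environment-variable module pattern. The following is a condensed sketch of that pattern with only a handful of the variables named in the tests and defaults inferred from the assertions; it is illustrative and not the real tpu_inference/envs.py.

# Sketch of a lazy env-var module (PEP 562 module __getattr__), illustrative only.
import functools
import os
import sys

environment_variables = {
    "JAX_PLATFORMS": lambda: os.getenv("JAX_PLATFORMS", ""),
    "TPU_NAME": lambda: os.getenv("TPU_NAME"),  # None when unset
    "SKIP_JAX_PRECOMPILE": lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
    "PYTHON_TRACER_LEVEL": lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
    "MODEL_IMPL_TYPE": lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
}


def __getattr__(name: str):
    # Environment variables are read lazily on each attribute access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    return list(environment_variables)


def enable_envs_cache():
    # Memoize lookups so values are frozen after startup; the tests undo this
    # by restoring envs.__getattr__.__wrapped__.
    module = sys.modules[__name__]
    module.__getattr__ = functools.cache(module.__getattr__)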