tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import os
|
|
2
16
|
|
|
3
17
|
from vllm.utils.network_utils import get_ip
|
|
@@ -54,7 +68,45 @@ def get_side_channel_port() -> str:
|
|
|
54
68
|
return port
|
|
55
69
|
|
|
56
70
|
|
|
57
|
-
def
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
71
|
+
def get_device_topology_order_id(local_devices, global_devices) -> int:
    """
    Calculates the topology order ID for the local device set within the global topology.

    This function determines the rank of the current host/process based on the
    coordinate of its TPU devices relative to all devices in the topology.

    Each device is expected to expose a ``coords`` sequence and a
    ``process_index``. Coordinates are normalized to tuples before being
    compared, so lists, tuples and array-like coordinates compare consistently
    (lexicographically).

    Args:
        local_devices: A list of TpuDevice objects available to the current process.
        global_devices: A list of all TpuDevice objects in the global topology.

    Returns:
        The topology order ID (rank) of the local devices.

    Raises:
        ValueError: If either device list is empty, or if the minimum
            coordinate of ``local_devices`` is not the minimum coordinate of
            any process in ``global_devices``.
    """
    if not local_devices:
        raise ValueError("local_devices cannot be empty")
    if not global_devices:
        raise ValueError("global_devices cannot be empty")

    def _coords(device) -> tuple:
        # Normalize so that ordering/equality never mixes list and tuple
        # (which raises TypeError on `<` and is never equal on `==`).
        return tuple(device.coords)

    # 1. Find the 'anchor' (minimum coordinate) for the local devices.
    # This represents the physical top-left corner of the local machine.
    local_anchor = min(_coords(d) for d in local_devices)

    # 2. Group global devices by process to find the anchor for EVERY process.
    process_anchors: dict = {}
    for d in global_devices:
        pid = d.process_index
        coords = _coords(d)
        # Update the minimum coordinate found for this process so far
        if pid not in process_anchors or coords < process_anchors[pid]:
            process_anchors[pid] = coords

    # 3. Sort the unique anchors to establish the canonical topology order.
    # Tuples (x, y, z) sort lexicographically (x first, then y, then z).
    sorted_anchors = sorted(process_anchors.values())

    # 4. Return the index (rank) of the local anchor in the sorted list.
    try:
        return sorted_anchors.index(local_anchor)
    except ValueError:
        # Suppress the internal `list.index` error; this message is the
        # meaningful one for callers.
        raise ValueError(
            f"Local devices: {local_devices} do not exist in the global device: {global_devices} list."
        ) from None
|
tpu_inference/envs.py
CHANGED
|
@@ -15,13 +15,88 @@ if TYPE_CHECKING:
|
|
|
15
15
|
PREFILL_SLICES: str = ""
|
|
16
16
|
DECODE_SLICES: str = ""
|
|
17
17
|
SKIP_JAX_PRECOMPILE: bool = False
|
|
18
|
-
|
|
18
|
+
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
|
19
|
+
MODEL_IMPL_TYPE: str = "auto"
|
|
19
20
|
NEW_MODEL_DESIGN: bool = False
|
|
20
21
|
PHASED_PROFILING_DIR: str = ""
|
|
21
22
|
PYTHON_TRACER_LEVEL: int = 1
|
|
22
23
|
USE_MOE_EP_KERNEL: bool = False
|
|
24
|
+
NUM_SLICES: int = 1
|
|
23
25
|
RAY_USAGE_STATS_ENABLED: str = "0"
|
|
24
26
|
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
|
|
27
|
+
ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def env_with_choices(
|
|
31
|
+
env_name: str,
|
|
32
|
+
default: str | None,
|
|
33
|
+
choices: list[str] | Callable[[], list[str]],
|
|
34
|
+
case_sensitive: bool = True,
|
|
35
|
+
) -> Callable[[], str | None]:
|
|
36
|
+
"""
|
|
37
|
+
Create a lambda that validates environment variable against allowed choices
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
env_name: Name of the environment variable
|
|
41
|
+
default: Default value if not set (can be None)
|
|
42
|
+
choices: List of valid string options or callable that returns list
|
|
43
|
+
case_sensitive: Whether validation should be case sensitive
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
Lambda function for environment_variables dict
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
def _get_validated_env() -> str | None:
|
|
50
|
+
value = os.getenv(env_name)
|
|
51
|
+
if value is None:
|
|
52
|
+
return default
|
|
53
|
+
|
|
54
|
+
# Resolve choices if it's a callable (for lazy loading)
|
|
55
|
+
actual_choices = choices() if callable(choices) else choices
|
|
56
|
+
|
|
57
|
+
if not case_sensitive:
|
|
58
|
+
check_value = value.lower()
|
|
59
|
+
check_choices = [choice.lower() for choice in actual_choices]
|
|
60
|
+
else:
|
|
61
|
+
check_value = value
|
|
62
|
+
check_choices = actual_choices
|
|
63
|
+
|
|
64
|
+
if check_value not in check_choices:
|
|
65
|
+
raise ValueError(f"Invalid value '{value}' for {env_name}. "
|
|
66
|
+
f"Valid options: {actual_choices}.")
|
|
67
|
+
|
|
68
|
+
return value
|
|
69
|
+
|
|
70
|
+
return _get_validated_env
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
    """
    Build a lazy reader that parses a boolean environment variable.

    Accepts both numeric strings ("0", "1") and boolean strings
    ("true", "false", "True", "False"); matching is case-insensitive.
    An unset or empty variable yields ``default``.

    Args:
        env_name: Name of the environment variable
        default: Default boolean value if not set

    Returns:
        A zero-argument callable that re-reads and parses the variable on
        every invocation.

    Raises (from the returned callable):
        ValueError: If the variable holds any other non-empty value.
    """
    truth_table = {"true": True, "1": True, "false": False, "0": False}

    def _get_bool_env() -> bool:
        raw = os.getenv(env_name)
        # Both "missing" (None) and "" fall back to the default.
        if not raw:
            return default

        parsed = truth_table.get(raw.lower())
        if parsed is None:
            raise ValueError(
                f"Invalid boolean value '{raw}' for {env_name}. "
                f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
        return parsed

    return _get_bool_env
|
|
99
|
+
|
|
25
100
|
|
|
26
101
|
environment_variables: dict[str, Callable[[], Any]] = {
|
|
27
102
|
# JAX platform selection (e.g., "tpu", "cpu", "proxy")
|
|
@@ -38,7 +113,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|
|
38
113
|
lambda: os.getenv("TPU_WORKER_ID", None),
|
|
39
114
|
# Backend for multi-host communication on TPU
|
|
40
115
|
"TPU_MULTIHOST_BACKEND":
|
|
41
|
-
|
|
116
|
+
env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
|
|
42
117
|
# Slice configuration for disaggregated prefill workers
|
|
43
118
|
"PREFILL_SLICES":
|
|
44
119
|
lambda: os.getenv("PREFILL_SLICES", ""),
|
|
@@ -47,28 +122,37 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|
|
47
122
|
lambda: os.getenv("DECODE_SLICES", ""),
|
|
48
123
|
# Skip JAX precompilation step during initialization
|
|
49
124
|
"SKIP_JAX_PRECOMPILE":
|
|
50
|
-
|
|
125
|
+
env_bool("SKIP_JAX_PRECOMPILE", default=False),
|
|
126
|
+
# Check for XLA recompilation during execution
|
|
127
|
+
"VLLM_XLA_CHECK_RECOMPILATION":
|
|
128
|
+
env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
|
|
51
129
|
# Model implementation type (e.g., "flax_nnx")
|
|
52
130
|
"MODEL_IMPL_TYPE":
|
|
53
|
-
|
|
131
|
+
env_with_choices("MODEL_IMPL_TYPE", "auto",
|
|
132
|
+
["auto", "vllm", "flax_nnx", "jetpack"]),
|
|
54
133
|
# Enable new experimental model design
|
|
55
134
|
"NEW_MODEL_DESIGN":
|
|
56
|
-
|
|
135
|
+
env_bool("NEW_MODEL_DESIGN", default=False),
|
|
57
136
|
# Directory to store phased profiling output
|
|
58
137
|
"PHASED_PROFILING_DIR":
|
|
59
138
|
lambda: os.getenv("PHASED_PROFILING_DIR", ""),
|
|
60
139
|
# Python tracer level for profiling
|
|
61
140
|
"PYTHON_TRACER_LEVEL":
|
|
62
|
-
lambda: int(os.getenv("PYTHON_TRACER_LEVEL"
|
|
141
|
+
lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
|
|
63
142
|
# Use custom expert-parallel kernel for MoE (Mixture of Experts)
|
|
64
143
|
"USE_MOE_EP_KERNEL":
|
|
65
|
-
|
|
144
|
+
env_bool("USE_MOE_EP_KERNEL", default=False),
|
|
145
|
+
# Number of TPU slices for multi-slice mesh
|
|
146
|
+
"NUM_SLICES":
|
|
147
|
+
lambda: int(os.getenv("NUM_SLICES") or "1"),
|
|
66
148
|
# Enable/disable Ray usage statistics collection
|
|
67
149
|
"RAY_USAGE_STATS_ENABLED":
|
|
68
150
|
lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
|
|
69
151
|
# Ray compiled DAG channel type for TPU
|
|
70
152
|
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
|
|
71
|
-
|
|
153
|
+
env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
|
|
154
|
+
"ENABLE_QUANTIZED_MATMUL_KERNEL":
|
|
155
|
+
lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
|
|
72
156
|
}
|
|
73
157
|
|
|
74
158
|
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import os
|
|
2
16
|
from array import array
|
|
3
17
|
from typing import Any, Dict, List, Optional
|
|
@@ -6,7 +20,7 @@ import ray
|
|
|
6
20
|
import vllm.envs as envs
|
|
7
21
|
from ray.util.placement_group import PlacementGroup
|
|
8
22
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
|
9
|
-
from vllm.multimodal.inputs import
|
|
23
|
+
from vllm.multimodal.inputs import MultiModalKwargsItem
|
|
10
24
|
from vllm.platforms import current_platform
|
|
11
25
|
from vllm.ray.ray_env import get_env_vars_to_copy
|
|
12
26
|
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
|
|
@@ -39,7 +53,7 @@ logger = init_logger(__name__)
|
|
|
39
53
|
|
|
40
54
|
|
|
41
55
|
def _encode_hook(obj: Any) -> Any:
|
|
42
|
-
"""Custom msgspec enc hook that supports array types and
|
|
56
|
+
"""Custom msgspec enc hook that supports array types and MultiModalKwargsItem.
|
|
43
57
|
|
|
44
58
|
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
|
|
45
59
|
"""
|
|
@@ -48,7 +62,7 @@ def _encode_hook(obj: Any) -> Any:
|
|
|
48
62
|
f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
|
|
49
63
|
f"Given array has a type code of {obj.typecode}.")
|
|
50
64
|
return obj.tobytes()
|
|
51
|
-
if isinstance(obj,
|
|
65
|
+
if isinstance(obj, MultiModalKwargsItem):
|
|
52
66
|
return dict(obj)
|
|
53
67
|
|
|
54
68
|
|
|
@@ -136,11 +150,18 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
|
|
|
136
150
|
|
|
137
151
|
pp_size = self.parallel_config.pipeline_parallel_size
|
|
138
152
|
placement_group_specs: List[Dict[str, float]] = []
|
|
153
|
+
|
|
154
|
+
ray_nodes = ray.nodes()
|
|
155
|
+
logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
|
|
156
|
+
|
|
139
157
|
if pp_size == 1:
|
|
140
158
|
placement_group_specs = [{
|
|
141
159
|
device_str: node['Resources'][device_str]
|
|
142
|
-
} for node in
|
|
160
|
+
} for node in ray_nodes]
|
|
143
161
|
else:
|
|
162
|
+
assert pp_size == len(
|
|
163
|
+
ray_nodes
|
|
164
|
+
), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
|
|
144
165
|
num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
|
|
145
166
|
placement_group_specs = [{
|
|
146
167
|
device_str: num_devices_per_pp_rank
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
# TODO: Update documentation
|
|
2
16
|
|
|
3
17
|
from typing import List, Optional, Tuple
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
|
|
|
540
540
|
"""Returns the total vmem bytes used by the kernel."""
|
|
541
541
|
m_per_device = m // tp_size
|
|
542
542
|
n_per_device = n // tp_size
|
|
543
|
-
y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype)
|
|
543
|
+
y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
|
|
544
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
|
|
544
545
|
total_bytes = (
|
|
545
|
-
2 * m_per_device * k *
|
|
546
|
-
|
|
546
|
+
2 * m_per_device * k *
|
|
547
|
+
(dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
|
|
548
|
+
dtypes.itemsize_bits(x_dtype)) // 8 # x_vmem_scratch_ref
|
|
547
549
|
+ y_vmem_bytes # y_vmem_scratch_ref
|
|
548
|
-
+ 2 * m * bn *
|
|
550
|
+
+ 2 * m * bn *
|
|
551
|
+
(dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
|
|
552
|
+
dtypes.itemsize_bits(out_dtype)) // 8 # o_vmem_scratch_ref
|
|
549
553
|
+ acc_bytes # acc_vmem_scratch_ref, jnp.float32
|
|
550
554
|
)
|
|
551
555
|
return total_bytes
|
|
@@ -639,8 +643,10 @@ def all_gather_matmul(
|
|
|
639
643
|
# NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
|
|
640
644
|
if grid_k == 1:
|
|
641
645
|
acc_shape = (8, 128)
|
|
642
|
-
acc_bytes =
|
|
643
|
-
|
|
646
|
+
acc_bytes = (
|
|
647
|
+
acc_shape[0] *
|
|
648
|
+
acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
|
|
649
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
|
|
644
650
|
y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
|
|
645
651
|
estimated_vmem_bytes = get_vmem_estimate_bytes(
|
|
646
652
|
m,
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
2
2
|
"""All-gather matmul kernel's tuned block sizes."""
|
|
3
3
|
|
|
4
|
+
import re
|
|
5
|
+
|
|
4
6
|
import jax
|
|
5
7
|
|
|
6
8
|
# key:
|
|
@@ -32,8 +34,11 @@ def get_tpu_version() -> int:
|
|
|
32
34
|
return -1
|
|
33
35
|
if kind.endswith(' lite'):
|
|
34
36
|
kind = kind[:-len(' lite')]
|
|
35
|
-
|
|
36
|
-
|
|
37
|
+
|
|
38
|
+
# v6: "TPU v6"
|
|
39
|
+
# v7: "TPU7x"
|
|
40
|
+
assert kind[:3] == 'TPU', kind
|
|
41
|
+
return int(re.search(r'\d+', kind).group())
|
|
37
42
|
|
|
38
43
|
|
|
39
44
|
def get_key(
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|