tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import os
|
|
2
16
|
|
|
3
17
|
from vllm.utils.network_utils import get_ip
|
|
@@ -54,7 +68,45 @@ def get_side_channel_port() -> str:
|
|
|
54
68
|
return port
|
|
55
69
|
|
|
56
70
|
|
|
57
|
-
def
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
71
|
+
def get_device_topology_order_id(local_devices, global_devices) -> int:
    """
    Calculates the topology order ID for the local device set within the global topology.

    This function determines the rank of the current host/process based on the
    coordinate of its TPU devices relative to all devices in the topology.

    Each process is represented by its "anchor": the lexicographically smallest
    ``coords`` value among the devices it owns. The anchors of all processes are
    sorted, and the position of the local anchor in that order is the result.

    Args:
        local_devices: A list of TpuDevice objects available to the current process.
            Each device must expose a ``coords`` sequence.
        global_devices: A list of all TpuDevice objects in the global topology.
            Each device must expose ``coords`` and ``process_index``.

    Returns:
        The topology order ID (rank) of the local devices.

    Raises:
        ValueError: If either device list is empty, or if the local anchor does
            not match the anchor of any process found in ``global_devices``.
    """
    if not local_devices:
        raise ValueError("local_devices cannot be empty")
    if not global_devices:
        raise ValueError("global_devices cannot be empty")

    # Coordinates are normalized to tuples so anchors compare and match
    # consistently whether a device reports them as a list or a tuple
    # (note that ``[0, 0, 0] == (0, 0, 0)`` is False in Python).

    # 1. Find the 'anchor' (minimum coordinate) for the local devices.
    #    This represents the physical top-left corner of the local machine.
    local_anchor = min(tuple(d.coords) for d in local_devices)

    # 2. Group global devices by process to find the anchor for EVERY process.
    process_anchors = {}
    for d in global_devices:
        coords = tuple(d.coords)
        pid = d.process_index
        current = process_anchors.get(pid)
        # Keep the minimum coordinate seen for this process so far.
        if current is None or coords < current:
            process_anchors[pid] = coords

    # 3. Sort the unique anchors to establish the canonical topology order.
    #    Tuples (x, y, z) sort lexicographically (x first, then y, then z).
    sorted_anchors = sorted(process_anchors.values())

    # 4. Return the index (rank) of the local anchor in the sorted list.
    try:
        return sorted_anchors.index(local_anchor)
    except ValueError:
        # Suppress the uninformative "x is not in list" context.
        raise ValueError(
            f"Local devices: {local_devices} do not exist in the global device: {global_devices} list."
        ) from None
|
tpu_inference/envs.py
CHANGED
|
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|
|
16
16
|
DECODE_SLICES: str = ""
|
|
17
17
|
SKIP_JAX_PRECOMPILE: bool = False
|
|
18
18
|
VLLM_XLA_CHECK_RECOMPILATION: bool = False
|
|
19
|
-
MODEL_IMPL_TYPE: str = "
|
|
19
|
+
MODEL_IMPL_TYPE: str = "auto"
|
|
20
20
|
NEW_MODEL_DESIGN: bool = False
|
|
21
21
|
PHASED_PROFILING_DIR: str = ""
|
|
22
22
|
PYTHON_TRACER_LEVEL: int = 1
|
|
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
|
|
|
24
24
|
NUM_SLICES: int = 1
|
|
25
25
|
RAY_USAGE_STATS_ENABLED: str = "0"
|
|
26
26
|
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
|
|
27
|
+
ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
def env_with_choices(
|
|
@@ -69,6 +70,34 @@ def env_with_choices(
|
|
|
69
70
|
return _get_validated_env
|
|
70
71
|
|
|
71
72
|
|
|
73
|
+
def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
    """
    Build a lazy reader that parses an environment variable as a boolean.

    Accepts numeric strings ("0", "1") and boolean strings ("true", "false")
    in any letter case (e.g. "True", "FALSE"). Surrounding whitespace, such as
    a trailing newline or carriage return picked up from an env file, is
    ignored.

    Args:
        env_name: Name of the environment variable
        default: Default boolean value if not set

    Returns:
        A zero-argument callable that reads ``env_name`` each time it is called
        and returns the parsed boolean. An unset, empty, or whitespace-only
        variable yields ``default``.

    Raises:
        ValueError: Raised by the returned callable when the variable holds
            anything other than the accepted values.
    """

    def _get_bool_env() -> bool:
        value = os.getenv(env_name)
        if value is None:
            return default

        value_lower = value.strip().lower()
        if value_lower == "":
            return default
        if value_lower in ("true", "1"):
            return True
        if value_lower in ("false", "0"):
            return False
        raise ValueError(
            f"Invalid boolean value '{value}' for {env_name}. "
            f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")

    return _get_bool_env
|
|
99
|
+
|
|
100
|
+
|
|
72
101
|
environment_variables: dict[str, Callable[[], Any]] = {
|
|
73
102
|
# JAX platform selection (e.g., "tpu", "cpu", "proxy")
|
|
74
103
|
"JAX_PLATFORMS":
|
|
@@ -93,17 +122,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|
|
93
122
|
lambda: os.getenv("DECODE_SLICES", ""),
|
|
94
123
|
# Skip JAX precompilation step during initialization
|
|
95
124
|
"SKIP_JAX_PRECOMPILE":
|
|
96
|
-
|
|
125
|
+
env_bool("SKIP_JAX_PRECOMPILE", default=False),
|
|
97
126
|
# Check for XLA recompilation during execution
|
|
98
127
|
"VLLM_XLA_CHECK_RECOMPILATION":
|
|
99
|
-
|
|
128
|
+
env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
|
|
100
129
|
# Model implementation type (e.g., "flax_nnx")
|
|
101
130
|
"MODEL_IMPL_TYPE":
|
|
102
|
-
env_with_choices("MODEL_IMPL_TYPE", "
|
|
103
|
-
["vllm", "flax_nnx", "jetpack"]),
|
|
131
|
+
env_with_choices("MODEL_IMPL_TYPE", "auto",
|
|
132
|
+
["auto", "vllm", "flax_nnx", "jetpack"]),
|
|
104
133
|
# Enable new experimental model design
|
|
105
134
|
"NEW_MODEL_DESIGN":
|
|
106
|
-
|
|
135
|
+
env_bool("NEW_MODEL_DESIGN", default=False),
|
|
107
136
|
# Directory to store phased profiling output
|
|
108
137
|
"PHASED_PROFILING_DIR":
|
|
109
138
|
lambda: os.getenv("PHASED_PROFILING_DIR", ""),
|
|
@@ -112,7 +141,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|
|
112
141
|
lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
|
|
113
142
|
# Use custom expert-parallel kernel for MoE (Mixture of Experts)
|
|
114
143
|
"USE_MOE_EP_KERNEL":
|
|
115
|
-
|
|
144
|
+
env_bool("USE_MOE_EP_KERNEL", default=False),
|
|
116
145
|
# Number of TPU slices for multi-slice mesh
|
|
117
146
|
"NUM_SLICES":
|
|
118
147
|
lambda: int(os.getenv("NUM_SLICES") or "1"),
|
|
@@ -122,6 +151,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|
|
122
151
|
# Ray compiled DAG channel type for TPU
|
|
123
152
|
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
|
|
124
153
|
env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
|
|
154
|
+
"ENABLE_QUANTIZED_MATMUL_KERNEL":
|
|
155
|
+
lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
|
|
125
156
|
}
|
|
126
157
|
|
|
127
158
|
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import os
|
|
2
16
|
from array import array
|
|
3
17
|
from typing import Any, Dict, List, Optional
|
|
@@ -145,6 +159,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
|
|
|
145
159
|
device_str: node['Resources'][device_str]
|
|
146
160
|
} for node in ray_nodes]
|
|
147
161
|
else:
|
|
162
|
+
assert pp_size == len(
|
|
163
|
+
ray_nodes
|
|
164
|
+
), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
|
|
148
165
|
num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
|
|
149
166
|
placement_group_specs = [{
|
|
150
167
|
device_str: num_devices_per_pp_rank
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
# TODO: Update documentation
|
|
2
16
|
|
|
3
17
|
from typing import List, Optional, Tuple
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
|
|
|
540
540
|
"""Returns the total vmem bytes used by the kernel."""
|
|
541
541
|
m_per_device = m // tp_size
|
|
542
542
|
n_per_device = n // tp_size
|
|
543
|
-
y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype)
|
|
543
|
+
y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
|
|
544
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
|
|
544
545
|
total_bytes = (
|
|
545
|
-
2 * m_per_device * k *
|
|
546
|
-
|
|
546
|
+
2 * m_per_device * k *
|
|
547
|
+
(dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
|
|
548
|
+
dtypes.itemsize_bits(x_dtype)) // 8 # x_vmem_scratch_ref
|
|
547
549
|
+ y_vmem_bytes # y_vmem_scratch_ref
|
|
548
|
-
+ 2 * m * bn *
|
|
550
|
+
+ 2 * m * bn *
|
|
551
|
+
(dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
|
|
552
|
+
dtypes.itemsize_bits(out_dtype)) // 8 # o_vmem_scratch_ref
|
|
549
553
|
+ acc_bytes # acc_vmem_scratch_ref, jnp.float32
|
|
550
554
|
)
|
|
551
555
|
return total_bytes
|
|
@@ -639,8 +643,10 @@ def all_gather_matmul(
|
|
|
639
643
|
# NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
|
|
640
644
|
if grid_k == 1:
|
|
641
645
|
acc_shape = (8, 128)
|
|
642
|
-
acc_bytes =
|
|
643
|
-
|
|
646
|
+
acc_bytes = (
|
|
647
|
+
acc_shape[0] *
|
|
648
|
+
acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
|
|
649
|
+
dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
|
|
644
650
|
y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
|
|
645
651
|
estimated_vmem_bytes = get_vmem_estimate_bytes(
|
|
646
652
|
m,
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
2
2
|
"""All-gather matmul kernel's tuned block sizes."""
|
|
3
3
|
|
|
4
|
+
import re
|
|
5
|
+
|
|
4
6
|
import jax
|
|
5
7
|
|
|
6
8
|
# key:
|
|
@@ -32,8 +34,11 @@ def get_tpu_version() -> int:
|
|
|
32
34
|
return -1
|
|
33
35
|
if kind.endswith(' lite'):
|
|
34
36
|
kind = kind[:-len(' lite')]
|
|
35
|
-
|
|
36
|
-
|
|
37
|
+
|
|
38
|
+
# v6: "TPU v6"
|
|
39
|
+
# v7: "TPU7x"
|
|
40
|
+
assert kind[:3] == 'TPU', kind
|
|
41
|
+
return int(re.search(r'\d+', kind).group())
|
|
37
42
|
|
|
38
43
|
|
|
39
44
|
def get_key(
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|