tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/distributed/utils.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 
 from vllm.utils.network_utils import get_ip
@@ -54,7 +68,45 @@ def get_side_channel_port() -> str:
     return port
 
 
-def …  (old lines 57-60 elided by the diff viewer)
+def get_device_topology_order_id(local_devices, global_devices) -> int:
+    """
+    Calculates the topology order ID for the local device set within the global topology.
+
+    This function determines the rank of the current host/process based on the
+    coordinate of its TPU devices relative to all devices in the topology.
+
+    Args:
+        local_devices: A list of TpuDevice objects available to the current process.
+        global_devices: A list of all TpuDevice objects in the global topology.
+
+    Returns:
+        The topology order ID (rank) of the local devices.
+    """
+    if not local_devices:
+        raise ValueError("local_devices cannot be empty")
+    if not global_devices:
+        raise ValueError("global_devices cannot be empty")
+
+    # 1. Find the 'anchor' (minimum coordinate) for the local devices.
+    # This represents the physical top-left corner of the local machine.
+    local_anchor = min(d.coords for d in local_devices)
+
+    # 2. Group global devices by process to find the anchor for EVERY process.
+    process_anchors = {}
+    for d in global_devices:
+        pid = d.process_index
+        # Update the minimum coordinate found for this process so far
+        if pid not in process_anchors or d.coords < process_anchors[pid]:
+            process_anchors[pid] = d.coords
+
+    # 3. Sort the unique anchors to establish the canonical topology order.
+    # Tuples (x, y, z) sort lexicographically (x first, then y, then z).
+    sorted_anchors = sorted(process_anchors.values())
+
+    # 4. Return the index (rank) of the local anchor in the sorted list.
+    try:
+        return sorted_anchors.index(local_anchor)
+    except ValueError:
+        raise ValueError(
+            f"Local devices: {local_devices} do not exist in the global device: {global_devices} list."
+        )
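Illustration (not part of the diff): a minimal sketch of how the new helper ranks hosts on a toy topology. FakeDevice is a hypothetical stand-in for a TpuDevice that models only the two attributes the function reads, coords and process_index.

    from collections import namedtuple

    # Hypothetical stand-in for a TpuDevice; real devices come from jax.devices().
    FakeDevice = namedtuple("FakeDevice", ["coords", "process_index"])

    # A toy 2-host slice: host 0 is anchored at coordinate (0, 0, 0) and
    # host 1 at (2, 0, 0), so sorting the anchors ranks host 0 before host 1.
    host0 = [FakeDevice((0, 0, 0), 0), FakeDevice((1, 0, 0), 0)]
    host1 = [FakeDevice((2, 0, 0), 1), FakeDevice((3, 0, 0), 1)]
    global_devices = host0 + host1

    print(get_device_topology_order_id(host0, global_devices))  # 0
    print(get_device_topology_order_id(host1, global_devices))  # 1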
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

(This identical 13-line Apache-2.0 header hunk also appears for each of the other newly added license-only files in the list above; the verbatim repeats are omitted.)
tpu_inference/executors/ray_distributed_executor.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 from array import array
 from typing import Any, Dict, List, Optional
@@ -6,7 +20,7 @@ import ray
 import vllm.envs as envs
 from ray.util.placement_group import PlacementGroup
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-from vllm.multimodal.inputs import …  (rest of line elided by the diff viewer)
+from vllm.multimodal.inputs import MultiModalKwargsItem
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
@@ -39,7 +53,7 @@ logger = init_logger(__name__)
 
 
 def _encode_hook(obj: Any) -> Any:
-    """Custom msgspec enc hook that supports array types and …  (rest elided)
+    """Custom msgspec enc hook that supports array types and MultiModalKwargsItem.
 
     See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
     """
@@ -48,7 +62,7 @@ def _encode_hook(obj: Any) -> Any:
             f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
             f"Given array has a type code of {obj.typecode}.")
         return obj.tobytes()
-    if isinstance(obj, …  (rest of line elided by the diff viewer)
+    if isinstance(obj, MultiModalKwargsItem):
         return dict(obj)
 
 
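For context, a minimal sketch of how an enc hook like _encode_hook plugs into msgspec. The VLLM_TOKEN_ID_ARRAY_TYPE value and the payload shape below are assumptions for illustration; msgspec.msgpack.Encoder(enc_hook=...) is the documented extension point.

    from array import array

    import msgspec

    VLLM_TOKEN_ID_ARRAY_TYPE = "l"  # assumed typecode for vLLM token-id arrays

    def enc_hook(obj):
        # Mirror of the array branch above: ship token ids as raw bytes.
        if isinstance(obj, array):
            return obj.tobytes()
        raise NotImplementedError(f"Type {type(obj)} is not supported")

    encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
    payload = encoder.encode({"token_ids": array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])})
    print(len(payload))  # compact msgpack bytes, ready to send between Ray workers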
@@ -145,6 +159,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 device_str: node['Resources'][device_str]
             } for node in ray_nodes]
         else:
+            assert pp_size == len(
+                ray_nodes
+            ), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
             placement_group_specs = [{
                 device_str: num_devices_per_pp_rank
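The new assert guards the multi-host pipeline-parallel case: the executor builds one placement-group bundle per PP rank, so the PP size must match the Ray node count. A hedged sketch of that bundle shape (the "TPU" resource name and sizes here are illustrative, not taken from the diff):

    import ray
    from ray.util.placement_group import placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    ray.init()

    pp_size = 2
    num_devices_per_pp_rank = 4  # illustrative
    # One bundle per PP rank; STRICT_SPREAD places each bundle on its own host,
    # which is why pp_size must equal the host count in this branch.
    pg = placement_group([{"TPU": num_devices_per_pp_rank}] * pp_size,
                         strategy="STRICT_SPREAD")
    ray.get(pg.ready())

    strategy = PlacementGroupSchedulingStrategy(placement_group=pg,
                                                placement_group_bundle_index=0)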
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # TODO: Update documentation
 
 from typing import List, Optional, Tuple
tpu_inference/kernels/fused_moe/v1/kernel.py
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """TPU-Friendly Fused Mixture of Experts (MoE) kernel."""
 
 import functools
@@ -1376,171 +1389,166 @@ def fused_ep_moe(
     hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
     renorm_str = "-renorm_k" if renormalize_topk_logits else ""
     scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
-    fused_moe = …  (old call body, lines 1379-1437, mostly elided by the diff viewer)
+    fused_moe = pl.pallas_call(
+        functools.partial(
+            _fused_ep_moe_kernel,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            ep_axis_name=ep_axis_name,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        ),
+        out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
+                                       t_dtype),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[
+                hbm_block_spec,  # tokens_hbm
+                hbm_block_spec,  # w1_hbm
+                hbm_block_spec,  # w2_hbm
+                None if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                None if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                None if b1 is None else hbm_block_spec,  # b1_hbm
+                None if b2 is None else hbm_block_spec,  # b2_hbm
+                hbm_block_spec,  # gating_output_hbm
+                hbm_block_spec,  # a2a_g_hbm
+            ],
+            out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+            scratch_shapes=([
+                # t2e_routing_x2_smem
+                pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
+                # d2e_count_x2_smem
+                pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
+                # expert_offsets_x2_smem
+                pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
+                # expert_starts_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # expert_sizes_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # a2a_s_sends_x2_smem
+                pltpu.SMEM((2, ), jnp.int32),
+                # a2a_s_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-…  (old lines 1439-1447 elided by the diff viewer)
+                    t_dtype,
+                ),
+                # a2a_s_acc_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-…  (old lines 1449-1543 elided by the diff viewer)
+                    t_dtype,
+                ),
+                # a2a_g_acc_vmem
+                pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
+                           t_dtype),
+                # b_gating_x2_vmem
+                pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
+                # b_output_x2_vmem
+                pltpu.VMEM((2, bt, hidden_size), t_dtype),
+                # b_w1_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w3_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w2_x2_vmem
+                pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                # b_w1_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w3_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w2_scale_x2_vmem
+                (None if w2_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bf // subc_quant_wsz,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b1_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b3_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b2_x2_vmem
+                (None if b2 is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_acc_vmem
+                pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
+                # local_sems
+                pltpu.SemaphoreType.DMA((2, 5)),
+                # send_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # recv_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # a2a_gather_sem
+                pltpu.SemaphoreType.DMA,
+                # a2a_acc_sem
+                pltpu.SemaphoreType.DMA,
+            ]),
+        ),
+        compiler_params=pltpu.CompilerParams(
+            collective_id=0,
+            vmem_limit_bytes=100 * 1024 * 1024,
+        ),
+        name=scope_name,
+    )
 
     @jax.jit
     @jax.shard_map(
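For orientation, a minimal runnable sketch of the pallas_call pattern the rewritten call uses: scratch_shapes allocates on-chip buffers that the kernel receives as trailing refs after the input and output refs. The toy kernel below is an illustration under that assumption, not the MoE kernel's actual logic.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def _double_kernel(x_ref, o_ref, scratch_ref):
        # Stage the block in the VMEM scratch buffer, then write the result out.
        scratch_ref[...] = x_ref[...] * 2.0
        o_ref[...] = scratch_ref[...]

    def double(x):
        return pl.pallas_call(
            _double_kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
            scratch_shapes=[pltpu.VMEM(x.shape, x.dtype)],  # -> scratch_ref
        )(x)

    print(double(jnp.ones((8, 128), jnp.float32))[0, 0])  # 2.0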