tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/platforms/tpu_platform.py
CHANGED

@@ -1,38 +1,35 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

 import jax.numpy as jnp
+import torch
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
-from vllm.sampling_params import SamplingParams, SamplingType

 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger

 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import
+    from vllm.attention.backends.registry import AttentionBackendEnum
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams, SamplingType
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-
+    AttentionBackendEnum = None
+    SamplingParams = None
+    SamplingType = None

 logger = init_logger(__name__)

-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-

 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU

@@ -49,25 +46,21 @@ class TpuPlatform(Platform):

     additional_env_vars: list[str] = [
         "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
-        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
+        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
+        "NEW_MODEL_DESIGN"
     ]

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "
-
-
-
-
-
-        if selected_backend != _Backend.PALLAS:
+    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                             attn_selector_config: "AttentionSelectorConfig",
+                             **kwargs) -> str:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+
+        if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

-
-
-            return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
-        else:
-            logger.info("Using Pallas backend.")
-            return "vllm.attention.backends.pallas.PallasAttentionBackend"
+        logger.info("Using Pallas V1 backend.")
+        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:

@@ -82,6 +75,14 @@ class TpuPlatform(Platform):
             logger.warning(f"Error getting device name: {e}")
             return 'TPU'

+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.get_device_name().lower() == "tpu v6e":
+            logger.info(
+                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+            return torch.float8_e5m2
+        return torch.float8_e4m3fn
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError

@@ -132,6 +133,7 @@ class TpuPlatform(Platform):
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
+
         compilation_config = vllm_config.compilation_config

         # TPU only supports DYNAMO_TRACE_ONCE compilation level

@@ -142,40 +144,21 @@ class TpuPlatform(Platform):
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"

-        # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-
-        # NOTE(xiang): convert dtype to jnp.dtype
-        # NOTE(wenlong): skip this logic for mm model preprocessing
-        # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
-        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-            if impl == "vllm":
-                vllm_config.model_config.dtype = j2t_dtype(
-                    vllm_config.model_config.dtype.dtype)
-
         # TODO(cuiq): remove this dependency.
-
-
-
-
-
-
-
-
-
-
-
-
+        if vllm_config.model_config:
+            from vllm.v1.attention.backends.pallas import \
+                PallasAttentionBackend
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)  # type: ignore[assignment]
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > cache_config.block_size:
+                logger.warning(
+                    "Increase the page size from %s to %s to avoid SMEM OOM",
+                    cache_config.block_size,
+                    min_page_size,
+                )
+                cache_config.block_size = min_page_size  # type: ignore[assignment]

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config

@@ -185,12 +168,12 @@ class TpuPlatform(Platform):
         multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
-                logger.info("Force using UniProcExecutor for JAX on
-
+                logger.info("Force using UniProcExecutor for JAX on "
+                            "single host without pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "uni"
             else:
-                logger.info("Force using MultiprocExecutor for JAX on
-
+                logger.info("Force using MultiprocExecutor for JAX on "
+                            "single host with pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \

@@ -206,20 +189,15 @@ class TpuPlatform(Platform):

         if scheduler_config.is_multimodal_model and not \
             scheduler_config.disable_chunked_mm_input:
-            logger.warning("TPU does not support running Multimodal models"
-
-
+            logger.warning("TPU does not support running Multimodal models"
+                           " without setting `--disable_chunked_mm_input`. "
+                           "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True

         kv_transfer_config = vllm_config.kv_transfer_config
         if kv_transfer_config is not None:
             assert kv_transfer_config.kv_connector == "TPUConnector"
-        # Late initialization to avoid circular import
-        from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-            update_vllm_config_for_qwix_quantization
-
-        update_vllm_config_for_qwix_quantization(vllm_config)
-
+        # Late initialization to avoid circular import.
         from tpu_inference.core.sched.dp_scheduler import \
             update_vllm_config_for_dp_scheduler
         update_vllm_config_for_dp_scheduler(vllm_config)

@@ -246,10 +224,11 @@ class TpuPlatform(Platform):
     def validate_request(
         cls,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: Union["SamplingParams", PoolingParams],
         processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
+        from vllm.sampling_params import SamplingParams, SamplingType

         if isinstance(params, SamplingParams):
             if params.sampling_type == SamplingType.RANDOM_SEED:
tpu_inference/runner/__init__.py
CHANGED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.