tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/utils.py
CHANGED

@@ -3,16 +3,19 @@ import time
 from collections import defaultdict
 from collections.abc import Sequence
 from functools import wraps
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, List, Tuple, Union

 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
 from vllm import envs as vllm_envs
 from vllm import utils

@@ -23,21 +26,44 @@ GBYTES = 1024 * 1024 * 1024
 TPU_HEAD_SIZE_ALIGNMENT = 128
 TPU_SECOND_LAST_MINOR = 8

-#
-
-
-
-"
-    "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3,
-    "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
+# Map vllm dtype string that doesn't exactly match jax dtype string name.
+_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+    "fp8": jnp.float8_e4m3fn.dtype,
+    "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+    "fp8_e5m2": jnp.float8_e5m2.dtype,
 }

+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+    if isinstance(dtype, str):
+        if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+            return dict_dtype
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+    else:
+        raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+    # Use jax dtype as an intermediate dtype which we'll be used to convert it
+    # into torch dtype.
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+
+
 _megacore = False
 logger = init_logger(__name__)


+def align_to(unpadded_dim, pad_multiple):
+    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
 def enable_megacore() -> None:
     global _megacore
     _megacore = True

@@ -164,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:


 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits


@@ -249,40 +276,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:

 def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
     """
-    A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+    A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
     """
     if hash_fn_name == "builtin":
         return hash
-    return utils.get_hash_fn_by_name(hash_fn_name)
-
-
-def quantize_kv(key: jax.Array, value: jax.Array,
-                kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
-                v_scale: float) -> Tuple[jax.Array, jax.Array]:
-    """
-    Quantize the key and value tensors.
-
-    Args:
-        key: The key tensor to quantize.
-        value: The value tensor to quantize.
-        kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
-        q_scale: The scale to quantize the key and value tensors by.
-        k_scale: The scale to quantize the key tensor by.
-        v_scale: The scale to quantize the value tensor by.
-
-    Returns:
-        Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
-    """
-    dtype_info = jnp.finfo(kv_cache_quantized_dtype)
-    minval, maxval = float(dtype_info.min), float(dtype_info.max)
-    key = key.astype(jnp.float32) / k_scale
-    key = jnp.clip(key, minval, maxval)
-    key = key.astype(kv_cache_quantized_dtype)
-    value = value.astype(jnp.float32) / v_scale
-    value = jnp.clip(value, minval, maxval)
-    value = value.astype(kv_cache_quantized_dtype)
-
-    return key, value
+    return utils.hashing.get_hash_fn_by_name(hash_fn_name)


 def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
@@ -295,8 +293,38 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-
-    return
+    # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+    return to_jax_dtype(str_dtype)
+
+
+def get_mesh_shape_product(
+    mesh: Mesh,
+    axes: Union[str, list[str], None],
+) -> int:
+    """
+    Get the product of mesh dimensions for one or more axes.
+
+    Examples:
+        # Single axis (defaults to 1 if not present)
+        get_mesh_shape_product(mesh, "model")
+
+        # Multiple axes - computes product of their sizes
+        get_mesh_shape_product(mesh, ["model", "attn_dp"])
+
+        # None means no sharding on this dimension
+        get_mesh_shape_product(mesh, None)  # returns 1
+    """
+    if axes is None:
+        return 1
+
+    if isinstance(axes, str):
+        axes = [axes]
+
+    product = 1
+    for axis in axes:
+        product *= mesh.shape.get(axis, 1)
+
+    return product


 def time_function(func):
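To make the new utility helpers above easier to follow, here is a minimal standalone sketch of their semantics. It mirrors the logic shown in the diff rather than importing tpu_inference, and the example mesh shape dictionary is illustrative only.

```python
# Minimal sketch (not part of the package) of the helper semantics added above.
# It re-implements the diffed utils for illustration; the mesh shape dict below
# is made up and does not come from the package.

def align_to(unpadded_dim: int, pad_multiple: int) -> int:
    # Round a dimension up to the next multiple of pad_multiple.
    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple

def mesh_shape_product(mesh_shape: dict[str, int], axes) -> int:
    # None -> unsharded (1); a string -> that axis size; a list -> product of sizes.
    if axes is None:
        return 1
    if isinstance(axes, str):
        axes = [axes]
    product = 1
    for axis in axes:
        product *= mesh_shape.get(axis, 1)
    return product

if __name__ == "__main__":
    print(align_to(100, 128))   # 128, e.g. padding a head dim to TPU alignment
    shape = {"data": 1, "model": 4, "attn_dp": 2}  # hypothetical mesh axes
    print(mesh_shape_product(shape, "model"))                # 4
    print(mesh_shape_product(shape, ["model", "attn_dp"]))   # 8
    print(mesh_shape_product(shape, None))                   # 1
```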
tpu_inference/worker/__init__.py
CHANGED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/worker/tpu_worker.py
CHANGED

@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple

 import jax
-import jax.numpy as jnp
 import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
@@ -19,30 +18,25 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
 from vllm.v1 import utils as vllm_utils
-from vllm.v1.core.kv_cache_utils import get_num_blocks,
+from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                         get_uniform_page_size)
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

 from tpu_inference import envs, utils
 from tpu_inference.distributed import jax_parallel_state
-from tpu_inference.distributed.utils import (
-
+from tpu_inference.distributed.utils import (get_device_topology_order_id,
+                                             get_host_ip, get_kv_transfer_port)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.jax_intermediate_tensor import \
     JaxIntermediateTensors
-from tpu_inference.runner.kv_cache import
+from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner

 logger = init_logger(__name__)

-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-

 @dataclass
 class PPConfig:
@@ -77,21 +71,6 @@ class TPUWorker:
         ip: str = "localhost",
         prev_worker_ip: str = "localhost",
     ):
-        # If we use vLLM's model implementation in PyTorch, we should set it
-        # with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-        if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-
-            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -108,7 +87,7 @@ class TPUWorker:

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
-            from vllm.utils import init_cached_hf_modules
+            from vllm.utils.import_utils import init_cached_hf_modules

             init_cached_hf_modules()

@@ -250,14 +229,33 @@ class TPUWorker:
             need_pp=self.parallel_config.pipeline_parallel_size > 1)

         ensure_kv_transfer_initialized(self.vllm_config)
-
-
-
+
+        is_first_rank = True
+        is_last_rank = True
+        self.topology_order_id = self.rank
+        if self.parallel_config.pipeline_parallel_size > 1:
+            is_first_rank = self.rank == 0
+            is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+        else:
+            # topology_order_id is used to determine the KV cache
+            # mapping between P/D workers
+            if multihost_backend == "ray":
+                self.topology_order_id = get_device_topology_order_id(
+                    jax.local_devices(), jax.devices())
+
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, is_first_rank,
+                                           is_last_rank)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
-                    f"
+                    f"is_first_rank={is_first_rank} | "
+                    f"is_last_rank={is_last_rank} | "
+                    f"topology_order_id={self.topology_order_id} | "
                     f"is_driver_worker={self.is_driver_worker} | "
-                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB"
+                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB |"
+                    f"self.devices={self.devices} | "
+                    f"total devices={jax.devices()} | "
+                    f"local_devices={jax.local_devices()}")
         vllm_utils.report_usage_stats(self.vllm_config)

     def initialize_pp_transfer_connect(self):
@@ -357,7 +355,7 @@ class TPUWorker:
         if is_start:
             options = jax.profiler.ProfileOptions()
             # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-            options.python_tracer_level =
+            options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
             jax.profiler.start_trace(self.profile_dir,
                                      profiler_options=options)
@@ -395,45 +393,56 @@ class TPUWorker:
         # responsible for this translation. When vLLM can be modified, this
         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
         # and the vLLM side should be updated to handle the translation.
-
+        kv_cache_spec = self.model_runner.get_kv_cache_spec()

-        if len(
-            return
+        if len(kv_cache_spec) == 0:
+            return kv_cache_spec

         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
-        vllm_page_size_bytes = get_uniform_page_size(
-
-
+        vllm_page_size_bytes = get_uniform_page_size(
+            list(kv_cache_spec.values()))
+        attention_page_size_bytes = get_attention_page_size_bytes(
+            self.model_runner.mesh, kv_cache_spec)

-        if vllm_page_size_bytes !=
+        if vllm_page_size_bytes != attention_page_size_bytes:
             logger.info(
-                f"
-                f"
-                f"
-                f"
-
+                f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                f"does not match with actual page size used by the kernel "
+                f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                f"KV blocks using actual page size.")
+
+            kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                  kv_cache_spec)
+            group_size = max(
+                len(group.layer_names) for group in kv_cache_groups)
             available_memory = self.determine_available_memory()
-            num_blocks = get_num_blocks(self.vllm_config,
-                                        available_memory,
-
+            num_blocks = get_num_blocks(self.vllm_config, group_size,
+                                        available_memory,
+                                        attention_page_size_bytes)
             cache_config = self.vllm_config.cache_config
             cache_config.num_gpu_blocks_override = num_blocks

-        return
+        return kv_cache_spec

     def initialize_from_config(
         self,
         kv_cache_config: KVCacheConfig,
     ) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-
+        # Precompile functions with large vocab_size tensors before allocating KV cache to avoid OOM
+        if not (envs.SKIP_JAX_PRECOMPILE or
+                (hasattr(self.model_runner.model_config, "enforce_eager")
+                 and self.model_runner.model_config.enforce_eager)):
+            self.model_runner.compilation_manager._precompile_sampling()
+            self.model_runner.compilation_manager._precompile_gather_logprobs()
+        self.model_runner.initialize_kv_cache(kv_cache_config,
+                                              self.topology_order_id)

     def get_node_kv_ip_port(self) -> tuple[int, str, int]:
-        node_id = get_node_id()
         ip = get_host_ip()
         port = get_kv_transfer_port()
-        return (int(
+        return (int(self.topology_order_id), ip, int(port))

     def check_health(self) -> None:
         # worker will always be healthy as long as it's running.
@@ -455,3 +464,8 @@ class TPUWorker:

     def shutdown(self) -> None:
         return
+
+    # Ray executor do not need handshake metadata
+    # as we pass the kv_parameters through proxy server
+    def get_kv_connector_handshake_metadata(self) -> None:
+        pass
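The get_kv_cache_spec change above recomputes the KV block count when the kernel's actual page size differs from vLLM's estimate. As a rough, simplified illustration of that arithmetic (this is not the vLLM API; the function and variable names below are invented for the sketch):

```python
def recompute_num_blocks(available_memory_bytes: int,
                         layers_in_largest_group: int,
                         attention_page_size_bytes: int) -> int:
    # Each KV block needs one page per layer of the largest KV-cache group,
    # so divide the free memory by that per-block footprint.
    bytes_per_block = layers_in_largest_group * attention_page_size_bytes
    return available_memory_bytes // bytes_per_block

# e.g. 8 GiB free HBM, 32 attention layers, 256 KiB kernel page size -> 1024 blocks
print(recompute_num_blocks(8 * 1024**3, 32, 256 * 1024))
```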
{tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.
+Version: 0.13.2.dev20251230
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: tpu-info==0.
+Requires-Dist: tpu-info==0.7.1
 Requires-Dist: yapf==0.43.0
 Requires-Dist: pytest
 Requires-Dist: pytest-mock
@@ -25,12 +25,13 @@ Requires-Dist: jax[tpu]==0.8.0
 Requires-Dist: jaxlib==0.8.0
 Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
-Requires-Dist: torchax==0.0.
+Requires-Dist: torchax==0.0.10
 Requires-Dist: qwix==0.1.1
 Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
+Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
 Dynamic: author
 Dynamic: classifier
 Dynamic: description
@@ -52,14 +53,12 @@ Dynamic: requires-python

 ---

-_Upcoming Events_ 🔥
-
-- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
-- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
 _Latest News_ 🔥

+- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)

 <details>