tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic; see the registry listing for details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/runner/tpu_runner.py CHANGED

@@ -1,6 +1,20 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 import functools
-import
+import logging
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass

@@ -14,7 +28,6 @@ import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import t2j_dtype
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,

@@ -66,10 +79,12 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function, to_torch_dtype)
+                                 time_function, to_jax_dtype, to_torch_dtype)
 
 logger = init_logger(__name__)
 
+logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8

@@ -493,10 +508,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         multimodal_fns = multimodal_fns or {}
         self.precompile_vision_encoder_fn = multimodal_fns.get(
             "precompile_vision_encoder_fn", None)
-        self.
-
-        self.
-
+        self.embed_multimodal_fn = multimodal_fns.get("embed_multimodal_fn",
+                                                      None)
+        self.embed_input_ids_fn = multimodal_fns.get("embed_input_ids_fn",
+                                                     None)
         self.get_mrope_input_positions_fn = multimodal_fns.get(
             "get_mrope_input_positions_fn", None)
 

@@ -508,7 +523,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             jax.random.key(self.model_config.seed)).params()
         self.is_multimodal_model = (
             self.model_config.is_multimodal_model
-            and self.
+            and self.embed_multimodal_fn is not None and hasattr(
                 self.model_config.hf_config, "architectures"
             ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
             and len(self.model_config.hf_config.architectures) >= 1

@@ -524,7 +539,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def get_kv_cache_spec(self):
         return self.kv_cache_manager.get_kv_cache_spec()
 
-    def initialize_kv_cache(self,
+    def initialize_kv_cache(self,
+                            kv_cache_config: KVCacheConfig,
+                            topology_order_id: int = 0) -> None:
+        self.topology_order_id = topology_order_id
         self.kv_cache_config = kv_cache_config
         self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []

@@ -809,7 +827,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         sharding = None
         if self.dp_size > 1:
             sharding = NamedSharding(self.mesh,
-                                     PartitionSpec(ShardingAxisName.
+                                     PartitionSpec(ShardingAxisName.MLP_DATA))
 
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
             self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)

@@ -1336,7 +1354,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         _request_distribution = []
         for dp_rank in range(dp_size):
             _num_reqs = num_req_per_dp_rank[dp_rank]
-
+            # The batch has been reordered by _reorder_batch so decode requests come first
+            # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+            num_decode_in_dp_rank = 0
+            for req_id in req_ids_dp[dp_rank]:
+                if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                    num_decode_in_dp_rank += 1
+            _request_distribution.append(
+                [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
         request_distribution = np.array(_request_distribution).ravel()
 
         use_spec_decode = len(

@@ -1365,7 +1390,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh,
             self.input_batch,
             padded_num_reqs,
-            sharding=
+            sharding=NamedSharding(self.mesh,
+                                   PartitionSpec(ShardingAxisName.MLP_DATA)),
         )
         if self.uses_mrope:
             positions = mrope_positions

@@ -1395,7 +1421,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             block_tables[
                 req_offset:req_offset + _num_reqs, :self.
                 max_num_blocks_per_req] = self.input_batch.block_table[
-
+                    kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
             # Convert block_tables to 1D on cpu.
             block_tables = block_tables.reshape(-1)
             block_tables = device_array(

@@ -1655,7 +1681,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
         if self.is_multimodal_model:
-            inputs_embeds = self.
+            inputs_embeds = self.embed_input_ids_fn(
                 self.state,
                 input_ids,
                 mm_embeds,

@@ -1712,8 +1738,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                                  shard=shard)
 
     def get_intermediate_tensor_spec(self, num_tokens: int):
-
-        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
+        jax_dtype = to_jax_dtype(self.dtype)
         num_padded_tokens = runner_utils.get_padded_token_len(
             self.num_tokens_paddings, num_tokens)
         sharding = NamedSharding(self.mesh, PartitionSpec())
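The request-distribution change in the hunk above (counting decode requests per data-parallel rank) can be illustrated with a small, self-contained sketch. The inputs below are hypothetical stand-ins for the runner's `req_ids_dp` and `scheduler_output.num_scheduled_tokens`; the real structures live inside `TPUModelRunner`.

```python
import numpy as np

# Hypothetical inputs: two data-parallel ranks; a request scheduled for
# exactly one token this step counts as a decode request.
req_ids_dp = {0: ["req-a", "req-b", "req-c"], 1: ["req-d", "req-e"]}
num_scheduled_tokens = {"req-a": 1, "req-b": 1, "req-c": 17,  # rank 0: 2 decodes, 1 prefill
                        "req-d": 1, "req-e": 33}              # rank 1: 1 decode, 1 prefill

request_distribution = []
for dp_rank in (0, 1):
    reqs = req_ids_dp[dp_rank]
    # Decode requests come first after batch reordering; count them per rank.
    num_decode = sum(num_scheduled_tokens[r] == 1 for r in reqs)
    request_distribution.append([num_decode, num_decode, len(reqs)])

# Flattened per-rank triples, as the runner does with np.ravel().
print(np.array(request_distribution).ravel())  # [2 2 3 1 1 2]
```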
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/spec_decode/jax/eagle3.py CHANGED

@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
 import functools
 from dataclasses import replace
tpu_inference/tpu_info.py CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import os
 
tpu_inference/utils.py CHANGED

@@ -3,7 +3,7 @@ import time
 from collections import defaultdict
 from collections.abc import Sequence
 from functools import wraps
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, List, Tuple, Union
 
 import jax
 import jax.numpy as jnp

@@ -28,9 +28,9 @@ TPU_SECOND_LAST_MINOR = 8
 
 # Map vllm dtype string that doesn't exactly match jax dtype string name.
 _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
-    "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3fn,
-    "fp8_e5m2": jnp.float8_e5m2,
+    "fp8": jnp.float8_e4m3fn.dtype,
+    "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+    "fp8_e5m2": jnp.float8_e5m2.dtype,
 }
 
 

@@ -60,6 +60,10 @@ _megacore = False
 logger = init_logger(__name__)
 
 
+def align_to(unpadded_dim, pad_multiple):
+    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
 def enable_megacore() -> None:
     global _megacore
     _megacore = True

@@ -186,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 

@@ -271,40 +276,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
 
 def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
     """
-    A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+    A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
     """
     if hash_fn_name == "builtin":
         return hash
-    return utils.get_hash_fn_by_name(hash_fn_name)
-
-
-def quantize_kv(key: jax.Array, value: jax.Array,
-                kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
-                v_scale: float) -> Tuple[jax.Array, jax.Array]:
-    """
-    Quantize the key and value tensors.
-
-    Args:
-        key: The key tensor to quantize.
-        value: The value tensor to quantize.
-        kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
-        q_scale: The scale to quantize the key and value tensors by.
-        k_scale: The scale to quantize the key tensor by.
-        v_scale: The scale to quantize the value tensor by.
-
-    Returns:
-        Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
-    """
-    dtype_info = jnp.finfo(kv_cache_quantized_dtype)
-    minval, maxval = float(dtype_info.min), float(dtype_info.max)
-    key = key.astype(jnp.float32) / k_scale
-    key = jnp.clip(key, minval, maxval)
-    key = key.astype(kv_cache_quantized_dtype)
-    value = value.astype(jnp.float32) / v_scale
-    value = jnp.clip(value, minval, maxval)
-    value = value.astype(kv_cache_quantized_dtype)
-
-    return key, value
+    return utils.hashing.get_hash_fn_by_name(hash_fn_name)
 
 
 def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:

@@ -321,6 +297,36 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     return to_jax_dtype(str_dtype)
 
 
+def get_mesh_shape_product(
+    mesh: Mesh,
+    axes: Union[str, list[str], None],
+) -> int:
+    """
+    Get the product of mesh dimensions for one or more axes.
+
+    Examples:
+        # Single axis (defaults to 1 if not present)
+        get_mesh_shape_product(mesh, "model")
+
+        # Multiple axes - computes product of their sizes
+        get_mesh_shape_product(mesh, ["model", "attn_dp"])
+
+        # None means no sharding on this dimension
+        get_mesh_shape_product(mesh, None)  # returns 1
+    """
+    if axes is None:
+        return 1
+
+    if isinstance(axes, str):
+        axes = [axes]
+
+    product = 1
+    for axis in axes:
+        product *= mesh.shape.get(axis, 1)
+
+    return product
+
+
 def time_function(func):
     """
     A decorator to measure the execution time of a function.
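The small helpers added to `tpu_inference/utils.py` above are easy to exercise in isolation. Below is a minimal sketch: `align_to` is taken from the hunk, while the mesh-product logic is reproduced over a plain dict standing in for `mesh.shape` (the real function takes a `jax.sharding.Mesh` and reads sizes via `mesh.shape.get(axis, 1)`); the axis names and sizes are illustrative.

```python
def align_to(unpadded_dim, pad_multiple):
    # Round up to the nearest multiple of pad_multiple.
    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple

def mesh_shape_product(mesh_shape, axes):
    # Simplified stand-in for get_mesh_shape_product: mesh_shape is a
    # dict of axis name -> size instead of a jax.sharding.Mesh.
    if axes is None:
        return 1
    if isinstance(axes, str):
        axes = [axes]
    product = 1
    for axis in axes:
        product *= mesh_shape.get(axis, 1)
    return product

shape = {"data": 2, "model": 4, "attn_dp": 2}              # illustrative mesh axes
print(align_to(5, 8), align_to(16, 8))                     # 8 16
print(mesh_shape_product(shape, "model"))                  # 4
print(mesh_shape_product(shape, ["model", "attn_dp"]))     # 8
print(mesh_shape_product(shape, None))                     # 1
```

`get_dtype_packing` follows the same spirit: it returns how many values of a dtype fit into 32 bits (2 for bfloat16, 4 for an 8-bit float), now falling back to `dtypes.itemsize_bits` on JAX versions that lack `dtypes.bit_width`.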
tpu_inference/worker/__init__.py CHANGED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/worker/tpu_worker.py CHANGED

@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
-import jax.numpy as jnp
 import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs

@@ -19,30 +18,25 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
 from vllm.v1 import utils as vllm_utils
-from vllm.v1.core.kv_cache_utils import get_num_blocks,
+from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                         get_uniform_page_size)
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
 from tpu_inference.distributed import jax_parallel_state
-from tpu_inference.distributed.utils import (
-
+from tpu_inference.distributed.utils import (get_device_topology_order_id,
+                                             get_host_ip, get_kv_transfer_port)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.jax_intermediate_tensor import \
     JaxIntermediateTensors
-from tpu_inference.runner.kv_cache import
+from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-
 
 @dataclass
 class PPConfig:

@@ -77,21 +71,6 @@ class TPUWorker:
         ip: str = "localhost",
         prev_worker_ip: str = "localhost",
     ):
-        # If we use vLLM's model implementation in PyTorch, we should set it
-        # with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-        if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-
-            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config

@@ -250,14 +229,33 @@
             need_pp=self.parallel_config.pipeline_parallel_size > 1)
 
         ensure_kv_transfer_initialized(self.vllm_config)
-
-
-
+
+        is_first_rank = True
+        is_last_rank = True
+        self.topology_order_id = self.rank
+        if self.parallel_config.pipeline_parallel_size > 1:
+            is_first_rank = self.rank == 0
+            is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+        else:
+            # topology_order_id is used to determine the KV cache
+            # mapping between P/D workers
+            if multihost_backend == "ray":
+                self.topology_order_id = get_device_topology_order_id(
+                    jax.local_devices(), jax.devices())
+
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, is_first_rank,
+                                           is_last_rank)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
-                    f"
+                    f"is_first_rank={is_first_rank} | "
+                    f"is_last_rank={is_last_rank} | "
+                    f"topology_order_id={self.topology_order_id} | "
                     f"is_driver_worker={self.is_driver_worker} | "
-                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB"
+                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB |"
+                    f"self.devices={self.devices} | "
+                    f"total devices={jax.devices()} | "
+                    f"local_devices={jax.local_devices()}")
         vllm_utils.report_usage_stats(self.vllm_config)
 
     def initialize_pp_transfer_connect(self):

@@ -395,46 +393,56 @@
         # responsible for this translation. When vLLM can be modified, this
         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
         # and the vLLM side should be updated to handle the translation.
-
+        kv_cache_spec = self.model_runner.get_kv_cache_spec()
 
-        if len(
-            return
+        if len(kv_cache_spec) == 0:
+            return kv_cache_spec
 
         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
         vllm_page_size_bytes = get_uniform_page_size(
-            list(
-
-
+            list(kv_cache_spec.values()))
+        attention_page_size_bytes = get_attention_page_size_bytes(
+            self.model_runner.mesh, kv_cache_spec)
 
-        if vllm_page_size_bytes !=
+        if vllm_page_size_bytes != attention_page_size_bytes:
             logger.info(
-                f"
-                f"
-                f"
-                f"
-
+                f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                f"does not match with actual page size used by the kernel "
+                f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                f"KV blocks using actual page size.")
+
+            kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                  kv_cache_spec)
+            group_size = max(
+                len(group.layer_names) for group in kv_cache_groups)
             available_memory = self.determine_available_memory()
-            num_blocks = get_num_blocks(self.vllm_config,
-                                        available_memory,
-
+            num_blocks = get_num_blocks(self.vllm_config, group_size,
+                                        available_memory,
+                                        attention_page_size_bytes)
             cache_config = self.vllm_config.cache_config
             cache_config.num_gpu_blocks_override = num_blocks
 
-        return
+        return kv_cache_spec
 
     def initialize_from_config(
         self,
         kv_cache_config: KVCacheConfig,
     ) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-
+        # Precompile functions with large vocab_size tensors before allocating KV cache to avoid OOM
+        if not (envs.SKIP_JAX_PRECOMPILE or
+                (hasattr(self.model_runner.model_config, "enforce_eager")
+                 and self.model_runner.model_config.enforce_eager)):
+            self.model_runner.compilation_manager._precompile_sampling()
+            self.model_runner.compilation_manager._precompile_gather_logprobs()
+        self.model_runner.initialize_kv_cache(kv_cache_config,
+                                              self.topology_order_id)
 
     def get_node_kv_ip_port(self) -> tuple[int, str, int]:
-        node_id = get_node_id()
         ip = get_host_ip()
         port = get_kv_transfer_port()
-        return (int(
+        return (int(self.topology_order_id), ip, int(port))
 
     def check_health(self) -> None:
         # worker will always be healthy as long as it's running.

@@ -456,3 +464,8 @@
 
     def shutdown(self) -> None:
         return
+
+    # Ray executor do not need handshake metadata
+    # as we pass the kv_parameters through proxy server
+    def get_kv_connector_handshake_metadata(self) -> None:
+        pass
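The worker's new rank bookkeeping in the hunks above can be summarised as a small pure function. This is a simplified sketch during worker initialization, not the real method: `topology_order_id_fn` stands in for the call to `get_device_topology_order_id(jax.local_devices(), jax.devices())` shown in the diff.

```python
def resolve_worker_identity(rank, pp_world_size, multihost_backend,
                            topology_order_id_fn=None):
    """Simplified view of how TPUWorker derives is_first_rank,
    is_last_rank and topology_order_id (sketch, not the real code)."""
    is_first_rank = True
    is_last_rank = True
    topology_order_id = rank
    if pp_world_size > 1:
        # With pipeline parallelism, only the edge ranks hold the
        # first and last stages of the model.
        is_first_rank = rank == 0
        is_last_rank = rank == pp_world_size - 1
    elif multihost_backend == "ray" and topology_order_id_fn is not None:
        # topology_order_id determines the KV cache mapping between
        # P/D (prefill/decode) workers.
        topology_order_id = topology_order_id_fn()
    return is_first_rank, is_last_rank, topology_order_id

print(resolve_worker_identity(rank=0, pp_world_size=1, multihost_backend=""))
# (True, True, 0)  -- single-stage worker
print(resolve_worker_identity(rank=2, pp_world_size=4, multihost_backend=""))
# (False, False, 2)  -- middle stage of a 4-stage pipeline
```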
{tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.
+Version: 0.13.0rc2.post7
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers

@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: tpu-info==0.
+Requires-Dist: tpu-info==0.7.1
 Requires-Dist: yapf==0.43.0
 Requires-Dist: pytest
 Requires-Dist: pytest-mock

@@ -25,13 +25,17 @@ Requires-Dist: jax[tpu]==0.8.0
 Requires-Dist: jaxlib==0.8.0
 Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
-Requires-Dist: torchax==0.0.
+Requires-Dist: torchax==0.0.10
 Requires-Dist: qwix==0.1.1
 Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
 Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
+Requires-Dist: jax==0.8.1
+Requires-Dist: jaxlib==0.8.1
+Requires-Dist: jaxtyping==0.3.2
+Requires-Dist: libtpu==0.0.31
 Dynamic: author
 Dynamic: classifier
 Dynamic: description

@@ -53,14 +57,12 @@ Dynamic: requires-python
 
 ---
 
-_Upcoming Events_ 🔥
-
-- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
-- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
 _Latest News_ 🔥
 
+- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 
 <details>