tpu-inference 0.11.1 (tpu_inference-0.11.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of tpu-inference has been flagged as potentially problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/worker/tpu_worker_jax.py
@@ -0,0 +1,321 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import tempfile
+from typing import Callable, Dict, Optional, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+import jaxtyping
+import vllm.envs as envs
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                          has_kv_transfer_group)
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.lora.request import LoRARequest
+from vllm.tasks import SupportedTask
+from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+
+from tpu_inference import utils
+from tpu_inference.di.abstracts import (AbstractKVCacheConfig,
+                                        AbstractLoRARequest,
+                                        AbstractSchedulerOutput)
+from tpu_inference.di.interfaces import HostInterface
+from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
+                                             get_node_id)
+from tpu_inference.logger import init_logger
+from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
+from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
+from tpu_inference.worker._temporary_vllm_compat import (
+    adapt_kv_cache_config_if_needed, adapt_lora_request_if_needed,
+    adapt_scheduler_output_if_needed)
+from tpu_inference.worker.base import AbstractTpuWorker
+
+logger = init_logger(__name__)
+
+_DTYPE: dict[str, jnp.dtype] = {
+    "bfloat16": jnp.bfloat16,
+    "float": jnp.float32,
+    "float32": jnp.float32,
+}
+
+
+class TPUWorker(AbstractTpuWorker):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 local_rank: int,
+                 rank: int,
+                 distributed_init_method: str,
+                 is_driver_worker: bool = False,
+                 devices=None,
+                 host_interface: Optional[HostInterface] = None):
+        super().__init__(host_interface)
+
+        # If we use vLLM's model implementation in PyTorch, we should set it
+        # with torch version of the dtype.
+        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
+
+            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
+            if not isinstance(vllm_config.model_config.dtype, str):
+                logger.warning(
+                    "The model dtype is not properly set for JAX backend. "
+                    "Overwriting it to jnp.bfloat16")
+                vllm_config.model_config.dtype = jnp.bfloat16
+            else:
+                vllm_config.model_config.dtype = _DTYPE.get(
+                    vllm_config.model_config.dtype, jnp.bfloat16)
+
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.parallel_config = vllm_config.parallel_config
+        self.cache_config = vllm_config.cache_config
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+        self.devices = devices if devices is not None else []
+
+        if self.model_config.trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+
+            init_cached_hf_modules()
+
+        # Delay profiler initialization to the start of the profiling.
+        # This is because in vLLM V1, MP runtime is initialized before the
+        # TPU Worker is initialized. The profiler server needs to start after
+        # MP runtime is initialized.
+        self.profile_dir = None
+        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+            # For TPU, we can only have 1 active profiler session for 1 profiler
+            # server. So we only profile on rank0.
+            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        self.profile_dir)
+
+        use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
+        # Only one instance of profiler is allowed
+        if use_jax_profiler_server and jax.devices()[0] == self.devices[0]:
+            jax_profiler_server_port = int(
+                os.getenv("JAX_PROFILER_SERVER_PORT", 9999))
+            logger.info(
+                f"Starting JAX profiler server on port {jax_profiler_server_port}"
+            )
+            jax.profiler.start_server(jax_profiler_server_port)
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+    def init_device(self):
+        if not self.devices:
+            try:
+                device_indexes = self.vllm_config.additional_config[
+                    "sharding"]["sharding_strategy"]["device_indexes"]
+                self.devices = [jax.devices()[i] for i in device_indexes]
+            except KeyError:
+                tp = self.parallel_config.tensor_parallel_size
+                self.devices = jax.devices()[:tp]
+
+        # Initialize the vLLM distribution layer as a single chip environment,
+        # we'll swap the model's parallel modules with TPU SPMD equivalents.
+        with set_current_vllm_config(self.vllm_config):
+            temp_file = tempfile.mkstemp()[1]
+            init_distributed_environment(
+                world_size=1,
+                rank=0,
+                local_rank=0,
+                distributed_init_method=f"file://{temp_file}",
+                backend="gloo",
+            )
+            ensure_model_parallel_initialized(
+                tensor_model_parallel_size=1,
+                pipeline_model_parallel_size=1,
+            )
+            ensure_kv_transfer_initialized(self.vllm_config)
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        logger.info(f"Init worker | "
+                    f"rank={self.rank} | "
+                    f"node_id={get_node_id()} | "
+                    f"is_driver_worker={self.is_driver_worker} | "
+                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
+
+    def determine_available_memory(self) -> int:
+        gpu_memory_utilization = self.cache_config.gpu_memory_utilization
+        hbm_usage = utils.hbm_usage_bytes(self.devices)
+        total_hbm_limit = total_hbm_used = 0
+        for used, limit in hbm_usage:
+            total_hbm_used += used
+            total_hbm_limit += limit
+
+        total_hbm_limit_cap = total_hbm_limit * gpu_memory_utilization
+        total_hbm_avail = int(total_hbm_limit_cap - total_hbm_used)
+
+        total_hbm_limit_gb = round(total_hbm_limit / utils.GBYTES, 2)
+        total_hbm_limit_cap_gb = round(total_hbm_limit_cap / utils.GBYTES, 2)
+        total_hbm_used_gb = round(total_hbm_used / utils.GBYTES, 2)
+        total_hbm_avail_gb = round(total_hbm_avail / utils.GBYTES, 2)
+
+        logger.info(f"Memory statistics | "
+                    f"{total_hbm_limit_gb=}GiB | "
+                    f"{total_hbm_limit_cap_gb=}GiB | "
+                    f"{total_hbm_used_gb=}GiB | "
+                    f"{total_hbm_avail_gb=}GiB")
+
+        if total_hbm_avail <= 0:
+            raise ValueError(f"{total_hbm_used_gb=}GiB exceeds "
+                             f"{total_hbm_limit_cap_gb=}GiB by "
+                             f"{-total_hbm_avail_gb}GiB. Please consider "
+                             f"increasing --gpu-memory-utilization from "
+                             f"{gpu_memory_utilization} to a larger value.")
+        return total_hbm_avail
+
+    def execute_model(
+        self,
+        scheduler_output: Union[AbstractSchedulerOutput, SchedulerOutput],
+    ) -> Optional[ModelRunnerOutput]:
+        # NOTE: This method intentionally returns a concrete vLLM type, which
+        # violates the pure abstract contract of the base class. This is a
+        # deliberate, temporary compromise for the same reasons outlined in
+        # the `get_kv_cache_spec` method.
+
+        # Adapt the input if necessary (temporary compatibility layer)
+        adapted_scheduler_output = adapt_scheduler_output_if_needed(
+            scheduler_output)
+
+        # Unwrap the adapter to get the concrete vLLM object
+        vllm_scheduler_output = adapted_scheduler_output.vllm_scheduler_output
+        output = self.model_runner.execute_model(vllm_scheduler_output)
+
+        # With a connector, the scheduler expects output from all workers
+        if has_kv_transfer_group():
+            return output
+
+        return output if self.is_driver_worker else None
+
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        return self.model_runner.take_draft_token_ids()
+
+    def add_lora(
+        self,
+        lora_request: Union[AbstractLoRARequest, LoRARequest],
+    ) -> bool:
+        # Adapt the input if necessary (temporary compatibility layer)
+        adapted_lora_request = adapt_lora_request_if_needed(lora_request)
+
+        # Unwrap the adapter to get the concrete vLLM object
+        vllm_lora_request = adapted_lora_request.vllm_lora_request  # noqa: F841
+
+        raise NotImplementedError(
+            "LoRA is not supported by the JAX worker yet.")
+
+    def profile(self, is_start: bool = True):
+        if is_start:
+            options = jax.profiler.ProfileOptions()
+            options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
+            jax.profiler.start_trace(self.profile_dir,
+                                     profiler_options=options)
+        else:
+            jax.profiler.stop_trace()
+
+    def load_model(self) -> None:
+        self.model_runner.load_model()
+
+    def compile_or_warm_up_model(self) -> None:
+        self.model_runner.capture_model()
+        # Reset the seed to ensure that the random state is not affected by
+        # the model initialization and profiling.
+        self.model_runner._init_random()
+
+    def reset_mm_cache(self) -> None:
+        pass
+
+    def get_model(self):
+        return self.model_runner.get_model()
+
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()
+
+    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
+        # NOTE: This method intentionally returns a concrete vLLM type, which
+        # violates the pure abstract contract of the base class. This is a
+        # deliberate, temporary compromise.
+        #
+        # The vLLM executor that calls this method expects the concrete
+        # `vllm.KVCacheSpec` object to perform its own internal logic. If we
+        # returned an abstract adapter, the vLLM code would break.
+        #
+        # The ideal long-term solution is for the vLLM DI container to be
+        # responsible for this translation. When vLLM can be modified, this
+        # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
+        # and the vLLM side should be updated to handle the translation.
+        kv_cache_specs = self.model_runner.get_kv_cache_spec()
+
+        if len(kv_cache_specs) == 0:
+            return kv_cache_specs
+
+        # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
+        # feature that allows overriding page_size_bytes of KVCacheSpec.
+        vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
+        rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
+                                                      kv_cache_specs)
+
+        if vllm_page_size_bytes != rpa_page_size_bytes:
+            logger.info(
+                f"KV cache page size calculated by vLLM "
+                f"({vllm_page_size_bytes} Bytes) does not match with actual "
+                f"page size used by RPA kernel ({rpa_page_size_bytes} Bytes). "
+                f"Recalculating number of KV blocks using actual page size.")
+
+            available_memory = self.determine_available_memory()
+            num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
+                                        available_memory, rpa_page_size_bytes)
+
+            cache_config = self.vllm_config.cache_config
+            cache_config.num_gpu_blocks_override = num_blocks
+
+        return kv_cache_specs
+
+    def initialize_from_config(
+        self,
+        kv_cache_config: Union[AbstractKVCacheConfig, KVCacheConfig],
+    ) -> None:
+        """Allocate GPU KV cache with the specified kv_cache_config."""
+        adapted_kv_cache_config = adapt_kv_cache_config_if_needed(
+            kv_cache_config)
+        vllm_kv_cache_config = adapted_kv_cache_config.vllm_kv_cache_config
+        self.model_runner.initialize_kv_cache(vllm_kv_cache_config)
+
+    def get_node_kv_ip_port(self) -> tuple[int, str, int]:
+        node_id = get_node_id()
+        ip = get_host_ip()
+        port = get_kv_transfer_port()
+        return (int(node_id), ip, int(port))
+
+    def check_health(self) -> None:
+        # worker will always be healthy as long as it's running.
+        return
+
+    def sync_weights(
+        self,
+        updated_weights: jaxtyping.PyTree,
+        mappings: Dict[str, Tuple[str, Tuple[str]]],
+        transpose_keys: Dict[str, Tuple[int]],
+        reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
+                             jaxtyping.PyTree] = None
+    ) -> None:
+        """Sync the updated weights to the model runner."""
+        return self.model_runner._sync_weights(updated_weights=updated_weights,
+                                               mappings=mappings,
+                                               transpose_keys=transpose_keys,
+                                               reshard_fn=reshard_fn)
+
+    def shutdown(self) -> None:
+        return
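The worker's KV-cache sizing is plain arithmetic: `determine_available_memory` sums per-device `(used, limit)` HBM pairs, caps the total limit by `gpu_memory_utilization`, and returns the remainder; when vLLM's page-size estimate disagrees with the RPA kernel's, `get_kv_cache_spec` rederives the block count from that remainder. A standalone sketch of the math, not part of the package, with illustrative numbers instead of a real device query (vLLM's `get_num_blocks` also factors in the number of KV cache specs, which this sketch omits):

```python
GBYTES = 1024**3  # stands in for utils.GBYTES

def available_hbm_bytes(hbm_usage, gpu_memory_utilization):
    # Mirror determine_available_memory(): sum (used, limit) across devices,
    # cap the total limit by the utilization fraction, return the remainder.
    total_used = sum(used for used, _ in hbm_usage)
    total_limit = sum(limit for _, limit in hbm_usage)
    return int(total_limit * gpu_memory_utilization - total_used)

# Illustrative: two chips with 32 GiB HBM each, 4 GiB already used on each.
avail = available_hbm_bytes([(4 * GBYTES, 32 * GBYTES)] * 2, 0.9)
print(round(avail / GBYTES, 2))  # 49.6 GiB left for the KV cache

# Core of the recalculation in get_kv_cache_spec(): divide what is left by
# the page size the RPA kernel actually uses.
rpa_page_size_bytes = 2 * 1024 * 1024  # illustrative 2 MiB page
print(avail // rpa_page_size_bytes)  # 25395 blocks
```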
tpu_inference-0.11.1.dist-info/METADATA
@@ -0,0 +1,101 @@
+Metadata-Version: 2.4
+Name: tpu_inference
+Version: 0.11.1
+Author: tpu_inference Contributors
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Education
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: tpu-info==0.4.0
+Requires-Dist: yapf==0.43.0
+Requires-Dist: pytest
+Requires-Dist: pytest-mock
+Requires-Dist: absl-py
+Requires-Dist: numpy
+Requires-Dist: google-cloud-storage
+Requires-Dist: jax==0.7.2
+Requires-Dist: jaxlib==0.7.2
+Requires-Dist: libtpu==0.0.23
+Requires-Dist: jaxtyping
+Requires-Dist: flax==0.11.1
+Requires-Dist: torchax==0.0.7
+Requires-Dist: qwix==0.1.1
+Requires-Dist: torchvision==0.23.0
+Requires-Dist: pathwaysutils
+Requires-Dist: parameterized
+Dynamic: author
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: license-file
+Dynamic: requires-dist
+Dynamic: requires-python
+
+<p align="center">
+  <!-- This image will ONLY show up in GitHub's dark mode -->
+  <img src="docs/assets/tpu_inference_dark_mode_short.png#gh-dark-mode-only" alt="vLLM TPU" style="width: 86%;">
+  <!-- This image will ONLY show up in GitHub's light mode (and on other platforms) -->
+  <img src="docs/assets/tpu_inference_light_mode_short.png#gh-light-mode-only" alt="vLLM TPU" style="width: 86%;">
+</p>
+
+<p align="center">
+| <a href="https://tpu.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27"><b>User Forum</b></a> | <a href="https://join.slack.com/share/enQtOTY2OTUxMDIyNjY1OS00M2MxYWQwZjAyMGZjM2MyZjRjNTA0ZjRkNjkzOTRhMzg0NDM2OTlkZDAxOTAzYmJmNzdkNDc4OGZjYTUwMmRh"><b>Developer Slack</b></a> |
+</p>
+
+---
+
+_Upcoming Events_ 🔥
+
+- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
+- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
+_Latest News_ 🔥
+
+- [2025/10] vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU
+
+<details>
+<summary><i>Previous News</i> 🔥</summary>
+
+</details>
+
+---
+## About
+
+vLLM TPU is now powered by `tpu-inference`, an expressive and powerful new hardware plugin unifying JAX and PyTorch under a single lowering path within the vLLM project. The new backend now provides a framework for developers to:
+
+- Push the limits of TPU hardware performance in open source.
+- Provide more flexibility to JAX and PyTorch users by running PyTorch model definitions performantly on TPU without any additional code changes, while also extending native support to JAX.
+- Retain vLLM standardization: keep the same user experience, telemetry, and interface.
+
+## Recommended models and features
+
+Although vLLM TPU’s new unified backend makes out-of-the-box high performance serving possible with any model supported in vLLM, the reality is that we're still in the process of implementing a few core components.
+
+For this reason, we’ve provided a list of recommended [models](https://github.com/vllm-project/tpu-inference/blob/main/support_matrices/model_support_matrix.csv) and [features](https://github.com/vllm-project/tpu-inference/blob/main/support_matrices/feature_support_matrix.csv) that are validated for accuracy and stress-tested for performance.
+
+## Get started
+
+Get started with vLLM on TPUs by following the [quickstart guide](https://github.com/vllm-project/tpu-inference/tree/main/docs/getting_started/quickstart.md).
+
+Visit our [documentation](https://github.com/vllm-project/tpu-inference/tree/main/docs) to learn more.
+
+## Contribute
+
+We're always looking for ways to partner with the community to accelerate vLLM TPU development. If you're interested in contributing to this effort, check out the [Contributing guide](https://github.com/vllm-project/tpu-inference/blob/main/CONTRIBUTING.md) and [Issues](https://github.com/vllm-project/tpu-inference/issues) to start. We recommend filtering Issues on the [**good first issue** tag](https://github.com/vllm-project/tpu-inference/issues?q=is%3Aissue+state%3Aopen+label%3A%22good+first+issue%22) if it's your first time contributing.
+
+## Contact us
+
+- For technical questions and feature requests, open a GitHub [Issue](https://github.com/vllm-project/tpu-inference/issues)
+- For feature requests, please open one on Github [here](https://github.com/vllm-project/tpu-inference/issues/new/choose)
+- For discussing with fellow users, use the [TPU support topic in the vLLM Forum](https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27)
+- For coordinating contributions and development, use the [Developer Slack](https://join.slack.com/share/enQtOTY2OTUxMDIyNjY1OS00M2MxYWQwZjAyMGZjM2MyZjRjNTA0ZjRkNjkzOTRhMzg0NDM2OTlkZDAxOTAzYmJmNzdkNDc4OGZjYTUwMmRh)
+- For collaborations and partnerships, contact us at [vllm-tpu@google.com](mailto:vllm-tpu@google.com)
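The pins above are exact for the JAX/TPU stack, so a pre-existing environment with a different `jax`, `jaxlib`, `libtpu`, `flax`, or `torchax` is a likely source of install-time breakage. A small sanity check of installed versions against that subset of pins, using only the standard library (the pin values are copied from the metadata above):

```python
from importlib import metadata

# Subset of the exact pins from tpu_inference 0.11.1's METADATA.
PINS = {
    "jax": "0.7.2",
    "jaxlib": "0.7.2",
    "libtpu": "0.0.23",
    "flax": "0.11.1",
    "torchax": "0.0.7",
}

for name, wanted in PINS.items():
    try:
        installed = metadata.version(name)
    except metadata.PackageNotFoundError:
        print(f"{name}: not installed (need {wanted})")
    else:
        status = "ok" if installed == wanted else f"mismatch, need {wanted}"
        print(f"{name}: {installed} ({status})")
```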
tpu_inference-0.11.1.dist-info/RECORD
@@ -0,0 +1,168 @@
+tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tests/test_base.py,sha256=Ct5WFRMHL7IHEIxk8FrzAvO8m0xFuDpzDBKkAKKAL2Q,7341
+tests/test_quantization.py,sha256=tmHBwpAh1Lz4cSB15fwnvmbA1TZ_zM_I1iP99hhGaEk,34444
+tests/test_tpu_info.py,sha256=ZrwlMsp8ffITkS_b8Q1t_QG-a-WVAd4NUcjHhGibcsI,4670
+tests/test_utils.py,sha256=JFxlYnIddw8t096smLEs_PTycocVVzMGDBgZv5YUlnc,7763
+tests/tpu_backend_test.py,sha256=1_rEUA2XGsDCbZVX5KFOQ00OyTF4YnKRtNmk6ctbKXc,2462
+tests/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tests/core/test_adapters.py,sha256=HcZHf0GTwfHtW1rhcvAb1A3ezejQpYzzMuJhUvIsDo4,2927
+tests/core/test_core_tpu.py,sha256=n6IPk3VzaFYgm3LDeDp1qoKgRN5ysL7JidFOex2lIDg,22342
+tests/core/test_disagg_executor.py,sha256=QdE2YZs08EyDDCmSjhiXkXqQ9BJTgO6csr_E1xkkfSg,2256
+tests/core/test_disagg_utils.py,sha256=alktTGppaGdg-_un0Amz8Y0IDQz-xNJN0dXG-YApEmY,1955
+tests/core/test_init.py,sha256=NEFI5A9eKGu4rmeJ2iqd0EmhlA3bzbVkXmMi1PV1b9U,1687
+tests/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tests/kernels/quantized_matmul_kernel_test.py,sha256=od5-zXFjcsc_gWGRDrREL8E_ftymNniQVTzgtkBo_Gc,5679
+tests/kernels/ragged_kv_cache_update_v2_test.py,sha256=6-HjP5CoUG-kcuP8MS-JJVMiBnPRo_zadS3VInnO0D4,10821
+tests/kernels/ragged_paged_attention_kernel_v2_test.py,sha256=pWqo9UYF0tzwgBKO_xYw-TYSPrtAsKcMK5Haj8hFG7I,11340
+tests/kernels/ragged_paged_attention_kernel_v3_test.py,sha256=Hrd8iUkS1pS3rxeTyY53aYRg_ZL_d3NqgBXvOgnigSU,14838
+tests/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tests/lora/test_lora.py,sha256=nBbwnmvNgTWjqjKXJ0o_n5k7IksXMt5I9SDbpe6IsfM,4168
+tpu_inference/__init__.py,sha256=5hJ_YCx4yQJ3HH2BruqWaOtnYi_IapS9no7l62foFFo,1096
+tpu_inference/backend.py,sha256=V0DveQe4maWGz_hRD4bivwTXIQsANZkEj63_0m7U6nA,2552
+tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
+tpu_inference/tpu_info.py,sha256=9UohshkndR6dZpGWpWXfTD4qvIVdVgHf0yOoSEkLTrw,2276
+tpu_inference/utils.py,sha256=M1JMLFtd_5_za7XAQi2ENY8d7aRC-S7wbpYpLh42tyQ,9533
+tpu_inference/adapters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/adapters/vllm_adapters.py,sha256=n_iJ-BM4aGlnuf6Qhgye6u-H9dkzZP4SPufjspqw-dk,1412
+tpu_inference/adapters/vllm_config_adapters.py,sha256=V9sNdkKYHJpK-OKaaMYXZZP-IhZW6MOe7fJSwQbJngE,4076
+tpu_inference/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/core/adapters.py,sha256=dTZV95MDUdORJbcdYf1JYNTnNDVmv38V22GE7hiQJmo,4484
+tpu_inference/core/core_tpu.py,sha256=_iGVOp30qEGJX3MTYFRXRTsCdbdtd6vwtBtIeFA0sy8,32609
+tpu_inference/core/disagg_executor.py,sha256=dM0cvw2uS-jDlfG4BtsmGAa6hKyhhQ1H-ZQVvn65Xb0,4597
+tpu_inference/core/disagg_utils.py,sha256=ufWNFWQ5n4YnZpPOtoReHlYo4dlN7AbIqCyqS4an0t4,1572
+tpu_inference/di/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/di/abstracts.py,sha256=pMC-wD9aVoCmD3RXW5A4oHZ9Islu2C6huG9HYmQvxeY,541
+tpu_inference/di/host.py,sha256=FKRd5Xs1BVzbXku8A35tZmJDwPVg_66drdpAdSzJ5VI,2601
+tpu_inference/di/interfaces.py,sha256=LFlfXHWK61apIBb2nEBNjuAsdLLmnxTtrkVGslEKTj8,1524
+tpu_inference/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/distributed/tpu_connector.py,sha256=l_5l44BVIIClz4hrv5kWtctoUELHtvEXdqfypXlQh3I,28499
+tpu_inference/distributed/utils.py,sha256=8AOevmxJi7o9hLXyAydcYh-WaWGS6-BKJpV8kW6-P6E,1494
+tpu_inference/executors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/executors/ray_distributed_executor.py,sha256=VzAPBVb7c8zwGZFtn1OxnwxQTiZMfLnzeI1P7M69d5k,14888
+tpu_inference/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/experimental/llama3_jax_stashed.py,sha256=YK1oSIfto9ALo-HB45XfSrbq9XgVbE4m2C-9zRwmSzI,10913
+tpu_inference/interfaces/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/interfaces/cache.py,sha256=ZSNvjYpRjmm3RlsqENa3R9oHP0L7W4zv8nhCvpNGJLA,813
+tpu_inference/interfaces/config.py,sha256=f0fJBbp5FAWjfS-gC5UK2ptrn42f-DMLK9y_QB9Hm_U,1211
+tpu_inference/interfaces/config_parts.py,sha256=QuqV6LH8rPXmWuxraVFKmY08aTaZCy62ndM_vK9JkKQ,2105
+tpu_inference/interfaces/engine.py,sha256=Z1Vxmf5tiKTT3LMMpMX73urMqK3Uc4ZvO0UW8oXAsxE,1446
+tpu_inference/interfaces/outputs.py,sha256=ay9DXf_9JnaPc5kPJg3MYbO8frIQTJZxiugP7ow0gUI,580
+tpu_inference/interfaces/params.py,sha256=Cp8MtBj3LW8-4h9J23AJO4wGvG3aOuFIq3YFS-OG8zA,364
+tpu_inference/interfaces/platform.py,sha256=_EVTdilqpXJX2rRdypANuojOhDO0BCkUwekaxXQqDvQ,1833
+tpu_inference/interfaces/request.py,sha256=DRkjdWo5wmkVwQlq9DqpMDPeVPmQd6dfyhN2_k8tezw,950
+tpu_inference/interfaces/scheduler.py,sha256=cFBRkqVNXHrn-08Zvr9B23YTJUzSehy1rE-Fy2V5nvg,816
+tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/kernels/collectives/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/kernels/collectives/all_gather_matmul.py,sha256=0OYLLjlDmkRYScl7lHRi0o___5I5iMiW1gso-_dWSbc,27255
+tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py,sha256=KdaOIzTfIgUR0CcUTA46tpYH-cxPNoJx2cTMEvHx-Ac,1399
+tpu_inference/kernels/collectives/util.py,sha256=LbLD6lOxuszbUsykF89gWQqEJUICCZsfzam3EJDPnFE,1859
+tpu_inference/kernels/flash_attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/kernels/flash_attention/kernel.py,sha256=n8gmAFVfchMXlyaSEj8xXJm6AadFt26edQihPRdithY,25897
+tpu_inference/kernels/quantized_matmul/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/kernels/quantized_matmul/kernel.py,sha256=4oEVUXgWOeOY-PfySHf-iEuUSd9J7GQk_rDSbxa7CXg,14086
+tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py,sha256=3zhIm73JEE8qOty2_0v3AJlVz13k6qMB5wlXBDyC1EM,35130
+tpu_inference/kernels/quantized_matmul/util.py,sha256=rf6nIiAj9I2cj4LDvtaZGhcLXEc94o2xgMWasnFaREM,1943
+tpu_inference/kernels/ragged_paged_attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/kernels/ragged_paged_attention/v2/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/kernels/ragged_paged_attention/v2/kernel.py,sha256=OiQGAHhyggbp1PeuasPymopFohKOJjGXcpq9p_S8UWA,34940
+tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256=vGp2ZWODTbjyG9z2z0Qf_BX-wYHd5bUybnc_DtOz0nI,10995
+tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
+tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=zc-re4Knsdcfvt2oRO5KGD9-dJs0P8GVJ3yGtclHU2A,54740
+tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=KR2UFpCWjsXCmfMcxxV3yV2DVJp5xcEomOtOKYnSL78,131402
+tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=5ij66Rl7YsjTCH1UERP1W-XXC57sL6ZVPQdTLhMtKHQ,1010
+tpu_inference/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/common/attention_metadata.py,sha256=St8ZatbY1D7xQACKJH459jMgp3oTP3AQ36mi9FZdrPU,850
+tpu_inference/layers/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/jax/attention_interface.py,sha256=bXBD8C8RTYTyLJOIGcKd1jH_ZruM0jabLj4n98RIKSA,12003
+tpu_inference/layers/jax/base.py,sha256=Vhts6ZMwNCZ8LbnEXeB0rl3nHdS5hDJWX7HEa7Fl7yE,5775
+tpu_inference/layers/jax/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
+tpu_inference/layers/jax/constants.py,sha256=NcYg0zAf3ClfP7YMYdYu_F1GngOzZaIxIAHBZDunKw4,2755
+tpu_inference/layers/jax/layers.py,sha256=yv_lC2tbJuzVL-OaXYooX82Ys8hWZATeH9M78coJ3VI,10633
+tpu_inference/layers/jax/misc.py,sha256=znKv1Nuq_LgYpaIu0qlzUVDgQWnjjG7aqPJGM8kuwcw,566
+tpu_inference/layers/jax/rope.py,sha256=3ZyR06vwliipkynHHrvcK-Q_aRhvQKDYBOqBYr3oWM8,7029
+tpu_inference/layers/jax/rope_interface.py,sha256=X0SruXizlCHGnssFujC1pL07UC4Vsp7-gdBy_Q7JZhI,8375
+tpu_inference/layers/jax/sharding.py,sha256=L0Uh92oLaXFNNQ0qqzNtBD3x3wnTRexQt8GzsCvqH1k,17874
+tpu_inference/layers/jax/transformer_block.py,sha256=MBN4_hYCGq_-eyomGVUqplBZugZ2LBWUFOgM1UtUxFY,2952
+tpu_inference/layers/jax/attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/jax/attention/attention.py,sha256=KsGuQpOu7yUpimIr5XBniHKaa2ohx_Ke2YaCOvAG3jc,9837
+tpu_inference/layers/jax/attention/deepseek_v3_attention.py,sha256=YlagoBMwINv2KRH1dr4oEcH_cQ9QMPB55nO2FQZsWs0,14010
+tpu_inference/layers/jax/attention/llama4_attention.py,sha256=VvUmfBxQEbHf3F2BrcYDUnq5abj7CSDYeRsNx_eVAh0,6162
+tpu_inference/layers/jax/moe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/jax/moe/deepseek_v3_moe.py,sha256=Q6CuwwiZtWYm6iUee1wJoDJrwJE6_bcznTK2HrtXb0M,26089
+tpu_inference/layers/jax/moe/moe.py,sha256=cA8R1rjbBwNEoNlsPWjeIBB9nvaRDwlEdwQTVg6lTpY,8762
+tpu_inference/layers/jax/sample/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/jax/sample/rejection_sampler.py,sha256=IRfVWjkbVXp9Sv1YrGMMh-LYx1AwbY-3FTXEO1-Ue9g,20423
+tpu_inference/layers/jax/sample/sampling.py,sha256=-47SC7AqU4UgyO91zAdYXTgrBfdlQ9I89HFZKwU0eQA,3223
+tpu_inference/layers/jax/sample/sampling_metadata.py,sha256=c3jHNjh1hkFJ5gxGTEk0qBOZnICeY3EELViF5Omp_Nc,2252
+tpu_inference/layers/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/vllm/attention.py,sha256=UVhuNCCrz6jdLNotjGtgaR4CVZ4zNmq5VhsiuOTi6_I,6649
+tpu_inference/layers/vllm/fused_moe.py,sha256=ld_-sIHRdUY2tTTHrzHzCahVxH4P0sZVZrxYBQYSJhE,17455
+tpu_inference/layers/vllm/linear_common.py,sha256=_YlJtbdaYcck_j-gFLos_k0ycktVWxT8Qo57tR2YqJ8,7749
+tpu_inference/layers/vllm/sharding.py,sha256=Ck2OzNiucHtrEutDqPQNteu8MEm6isIkE8U5ziowHgM,5779
+tpu_inference/layers/vllm/quantization/__init__.py,sha256=UGv9cJftrBNoC0pU8SLnTLq3zvqMcolN5YJ6n_J5jf4,1392
+tpu_inference/layers/vllm/quantization/awq.py,sha256=78H4AYgbvLCrW-5bGbn9_WM1J8KnRzVOInfKSW_QmzQ,8476
+tpu_inference/layers/vllm/quantization/common.py,sha256=wm3pge6XMTMsLK7_SSdgBP0PvQzz-1mrqN2I6xMqzrc,4218
+tpu_inference/layers/vllm/quantization/unquantized.py,sha256=QIN6lWfVhN4ikUQlDbD8GhkZcLp1-s1Zi66aqKenmeo,10062
+tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=ifC6UsCY0tB6BO7X-PWtw-ikUc5IhcPcLvo0_RFrEsM,5253
+tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py,sha256=6sQvsxiWdi5Vte8V9vrQ2abaqGqWpq-mtzU7lGAo-ac,8759
+tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=4y7lYgybpXszpCAtxGFhR8LDEbEoCCeo3DfUSOXxhaQ,5202
+tpu_inference/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3zWBUsxE,3121
+tpu_inference/lora/torch_punica_tpu.py,sha256=ZfwWpPhkz4VQyxX9KeClx1hhchglsCWl0xpcGZsuMG0,12522
+tpu_inference/mock/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/mock/vllm_config_utils.py,sha256=FlQshLjoHdgs3C66tYHYbKFUjbk9DhUwY-7HibZk0fI,878
+tpu_inference/mock/vllm_envs.py,sha256=hHtbFOM45T5EB2tEGecMGbJA0qOI9dmNYcjANgtah98,51477
+tpu_inference/mock/vllm_logger.py,sha256=vUGnN5nKT--ZvU15YCzODUM_FGiXKhcrrjDGjeN00RQ,7297
+tpu_inference/mock/vllm_logging_utils.py,sha256=TEUmKj3xHiLzHBnFqAujcxH0t2hBQ04sUaho2RyORnk,486
+tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/models/common/model_loader.py,sha256=kOwc5Dyn433U0F-qZU1D0_k5USkMTY5Em0_WvQfjIYc,17661
+tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/models/jax/deepseek_v3.py,sha256=735PSgqxxrYL9JIsohhUXimjSNYMeNlepfRLrYHZ9us,40038
+tpu_inference/models/jax/llama3.py,sha256=bi-wIgZxR9h_DwoYHczPZXqrcvbzCVwnANuKnak6HcI,13024
+tpu_inference/models/jax/llama4.py,sha256=WMs4gQxbkEZXo7beVJSwPNyZX0AR6prpSE7RGVb9U74,21733
+tpu_inference/models/jax/llama_eagle3.py,sha256=STUkAK6XEA7JM3i_Lx36-t5BhkAGeW_xYiq3zYhHP1A,12297
+tpu_inference/models/jax/phi3.py,sha256=Oz68PE2Z1t8wTed95_w0KMIXfnfV72ZwXugNOdWOV5w,13576
+tpu_inference/models/jax/qwen2.py,sha256=RYb0hMKzPnFOAyhqbztoNlSrFIlRa74fYqSNecA2VOY,13354
+tpu_inference/models/jax/qwen2_5_vl.py,sha256=GrUlM16EWsaGPpSnn1KhjcrAHfeJeC1Z3cVefw0-ynQ,38522
+tpu_inference/models/jax/qwen3.py,sha256=SOL-Pvp56IrMxqXpPf5EFacBI6AJNlqf4Zrr1pkabGw,10994
+tpu_inference/models/jax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/models/jax/utils/file_utils.py,sha256=NOuSC3YFnZpf3CZgYdghbbiNYJt42zgjlEYbOZIVct4,2840
+tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=huW_yfntOJ_3ZXYUN1tJtmeK7EMoOBZExTZQtvfHOdk,6189
+tpu_inference/models/jax/utils/weight_utils.py,sha256=lZIW-39BA6GzdMZ_nr-CapBttLsfEajJvMJo8ykr0B0,19507
+tpu_inference/models/jax/utils/quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=hpzEzosiGi_02bgBXzW-AwZnKEiP_NPiKvpLSIPNjD4,24519
+tpu_inference/models/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=CyA9Gk8rmL1_FmIJ0NQcsutkwZn_DBZlzwuib2M2HuI,11141
+tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=yxlJHPmRQIAwlb1MmHK3xfXokgIkJ-evNU4PgyoJUdg,1187
+tpu_inference/platforms/__init__.py,sha256=2m4E-nxkBhYZFG23Ni4_AFpZe8xQTimdRltkrNzp7WA,69
+tpu_inference/platforms/tpu_jax.py,sha256=oKQFXjNF6cK2QZT7bqgb50oBwr-FN4VO0VdQXl1TQmE,9941
+tpu_inference/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/runner/block_table_jax.py,sha256=HCjrOMpsWk_x3lW--AOPPUHBHplIzGioMTuHKFxtr6A,4164
+tpu_inference/runner/compilation_manager.py,sha256=16n36Ne4LbmPei8UIAnUMw4TrLcBpe7a5Kvc3oibqcA,30904
+tpu_inference/runner/input_batch_jax.py,sha256=lqFGhZ3w92MPzpiGJ6bNUsQC8X1AP8JpKiWMeXs5tto,18260
+tpu_inference/runner/kv_cache.py,sha256=dU7DRJn0--qgPLV00jCIw4sabSf007mO5kCWnNrNeDI,3952
+tpu_inference/runner/kv_cache_manager.py,sha256=bDkbfpQ41L-n6R-LrseZE85DIuTtu4vbt4mCj1MJa48,21467
+tpu_inference/runner/lora_utils.py,sha256=XFNHPJvZe4e87tbyyKpOY9Vb28M9Rza3HXHNsem7jVg,3872
+tpu_inference/runner/multimodal_manager.py,sha256=2-QQcLuWikP7JgmC3tGovNDYfvikZl3tWyAiX4x8YDc,9283
+tpu_inference/runner/persistent_batch_manager.py,sha256=Zo8w2EdFZTSQtx6DCl57P8kQWkebquXl22RGIX2yqec,11160
+tpu_inference/runner/speculative_decoding_manager.py,sha256=_2oAwo_8e4N-FJXjC9oR-fsO8WjukCdvQhPH4R8B-c4,10274
+tpu_inference/runner/structured_decoding_manager.py,sha256=0SIoa5orxDcx76ziatKJ-GfnTAVIPCPTaMS15nxRR5U,3673
+tpu_inference/runner/tpu_jax_runner.py,sha256=HA-PBThgXv0GHfFxA9ltQr7fQFDw8rwE4SKMsYV0zMI,34285
+tpu_inference/runner/utils.py,sha256=5QcZW8an8EHs_zHKzIGqIf3ltAevusdwgaLLFrB9rc8,17131
+tpu_inference/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/spec_decode/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/spec_decode/jax/eagle3.py,sha256=PgIAJMuEyy61Tz4SQ6QZqB-B4t4-RYDmUIoHDyOHEjA,15204
+tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/worker/_temporary_vllm_compat.py,sha256=GpF8TuPMDbc0fvIxe7XWEe69FES_F-jJnmcaTgf2dO8,5182
+tpu_inference/worker/base.py,sha256=0Dd3CKk3e7DgvzhfH4M-9-MEQNyYh4zUWSO4tnHFd6s,3140
+tpu_inference/worker/tpu_worker_jax.py,sha256=7b2QVTSbveifm9_BgNnVGwEvh5zPrEi1qiXXTwFFODc,14093
+tpu_inference-0.11.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+tpu_inference-0.11.1.dist-info/METADATA,sha256=uKyRzPptKu13NN6_lYOinPlLYk57ZUFleECr2JDgLrs,5393
+tpu_inference-0.11.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tpu_inference-0.11.1.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
+tpu_inference-0.11.1.dist-info/RECORD,,
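Each RECORD row is `path,algorithm=digest,size`, where the digest is the file's SHA-256 encoded as unpadded urlsafe base64, and the RECORD file lists itself with empty hash and size fields. A sketch that re-verifies an installed copy against these entries; `site_packages` is a placeholder for the target environment's site-packages directory:

```python
import base64
import csv
import hashlib
from pathlib import Path

def verify_record(site_packages: str) -> list[str]:
    """Return paths whose on-disk hash disagrees with their RECORD entry."""
    root = Path(site_packages)
    record = root / "tpu_inference-0.11.1.dist-info" / "RECORD"
    mismatches = []
    with record.open(newline="") as f:
        for path, hash_field, _size in csv.reader(f):
            if not hash_field:
                continue  # RECORD lists itself without a hash
            algo, _, expected = hash_field.partition("=")
            digest = hashlib.new(algo, (root / path).read_bytes()).digest()
            actual = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
            if actual != expected:
                mismatches.append(path)
    return mismatches

# Example: print(verify_record("/usr/lib/python3.12/site-packages"))
```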