tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +4 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +77 -54
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +9 -9
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +2 -1
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +4 -4
- tpu_inference/platforms/tpu_platform.py +5 -2
- tpu_inference/runner/compilation_manager.py +33 -15
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +187 -99
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +158 -22
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +34 -39
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/worker/tpu_worker.py

@@ -2,6 +2,7 @@
 
 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
@@ -10,6 +11,7 @@ import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +25,13 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
 
@@ -39,15 +44,39 @@ _DTYPE: dict[str, jnp.dtype] = {
 }
 
 
+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # default env vars for
+    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+    # if PP is used in single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:
 
-    def __init__(
-
-
-
-
-
-
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
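The new PPConfig dataclass derives single-host TPU topology defaults from the pipeline rank and world size. A minimal sketch of that derivation, with the class body copied from the hunk above and purely illustrative values (worker 1 of a 4-stage pipeline; the IPs are invented):

    from dataclasses import dataclass, field

    @dataclass
    class PPConfig:
        # Copied from the hunk above: per-worker pipeline-parallel settings.
        rank: int
        ip: str
        prev_worker_ip: str
        pp_world_size: int
        default_tpu_process_bounds: str = field(init=False)
        default_tpu_chips_per_process_bounds: str = field(init=False)
        default_tpu_visible_chips: str = field(init=False)

        def __post_init__(self):
            # One TPU process per pipeline stage, laid out along one mesh axis;
            # each process owns exactly one chip, and worker `rank` sees chip `rank`.
            self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
            self.default_tpu_chips_per_process_bounds = "1,1,1"
            self.default_tpu_visible_chips = f"{self.rank}"

    cfg = PPConfig(rank=1, ip="10.0.0.2", prev_worker_ip="10.0.0.1", pp_world_size=4)
    print(cfg.default_tpu_process_bounds)            # 1,4,1
    print(cfg.default_tpu_chips_per_process_bounds)  # 1,1,1
    print(cfg.default_tpu_visible_chips)             # 1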
@@ -74,6 +103,8 @@ class TPUWorker:
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -86,7 +117,7 @@ class TPUWorker:
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -94,6 +125,14 @@ class TPUWorker:
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 
+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
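Because pipeline parallelism here is MPMD, every rank runs its own profiler and writes traces to a rank-specific subdirectory. A small sketch of the naming scheme used in the hunk above (pp_profile_dir is a hypothetical helper and /tmp/traces an example base directory):

    import os

    def pp_profile_dir(base_dir: str, rank: int, pp_world_size: int) -> str:
        # Same naming scheme as the diff: one trace directory per pipeline rank.
        return os.path.join(base_dir, f"pprank_{rank}_ppworldsize_{pp_world_size}")

    # With VLLM_TORCH_PROFILER_DIR=/tmp/traces and a 2-stage pipeline:
    print(pp_profile_dir("/tmp/traces", 0, 2))  # /tmp/traces/pprank_0_ppworldsize_2
    print(pp_profile_dir("/tmp/traces", 1, 2))  # /tmp/traces/pprank_1_ppworldsize_2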
@@ -105,31 +144,87 @@ class TPUWorker:
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
+        # step_counter is used to calculate uuid to transfer intermediate tensors.
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-    def init_device(self
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
+        # set tpu visible devices for Jax runtime in single host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_config.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # Note: Below is the setting for v6e8 host (8 chips of v6e)
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e
+            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices,
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-
-                device_dict = {
+                all_local_devices = jax.local_devices()
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.
+                            f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.
+                if self.pp_config.pp_world_size > 1:
+                    # We only support a mixed tp + pp scenario that tp size is
+                    # smaller or equals the total TPUs in one node
+                    # say: we have 4 nodes with 4 TPUs each, we can only do pp:4, tp:4, but not pp:2, tp:8
+                    assert jax.local_device_count(
+                    ) >= sharding_config.total_devices
+                    self.devices = jax.local_devices()[:sharding_config.
+                                                       total_devices]
+                else:
+                    # In a multi-host distributed env, say: Ray, local_device count may smaller
+                    # than the total devices, we just choose the smaller set here.
+                    self.devices = jax.devices()[:sharding_config.
+                                                 total_devices]
 
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
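To make the effect of the new init_device concrete, here is a standalone sketch of the environment it builds for a non-Ray single host; BASE_JAX_PORT stands in for jax_parallel_state.BASE_JAX_PORT, whose actual value is not shown in this diff, and set_single_host_pp_env is a hypothetical helper:

    import os

    BASE_JAX_PORT = 8476  # assumption: stand-in for jax_parallel_state.BASE_JAX_PORT

    def set_single_host_pp_env(rank: int, pp_world_size: int) -> None:
        # Mirrors the env-var block in init_device above for one worker process.
        tpu_ports = [BASE_JAX_PORT + i for i in range(pp_world_size)]
        os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
            f"localhost:{port}" for port in tpu_ports)
        os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[rank]}"
        os.environ["CLOUD_TPU_TASK_ID"] = f"{rank}"
        # PPConfig defaults: one process per stage, one chip per process,
        # and each worker sees only its own chip.
        os.environ["TPU_PROCESS_BOUNDS"] = f"1,{pp_world_size},1"
        os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
        os.environ["TPU_VISIBLE_CHIPS"] = f"{rank}"

    set_single_host_pp_env(rank=0, pp_world_size=2)
    print(os.environ["TPU_PROCESS_ADDRESSES"])  # localhost:8476,localhost:8477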
@@ -146,8 +241,18 @@ class TPUWorker:
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
         )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.pp_config.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(
+        self.model_runner = TPUModelRunner(
+            self.vllm_config, self.devices, self.rank, self.rank == 0,
+            self.rank == self.pp_config.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
@@ -155,6 +260,12 @@ class TPUWorker:
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
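The new initialize_pp_transfer_connect wires the pipeline into a chain: every rank except 0 dials its predecessor. A schematic re-enactment of that call pattern, where Worker is a hypothetical stand-in for TPUWorker and the IPs are invented:

    class Worker:
        # Hypothetical stand-in for TPUWorker, reduced to the connect logic.
        def __init__(self, rank: int, prev_worker_ip: str):
            self.rank = rank
            self.prev_worker_ip = prev_worker_ip

        def initialize_pp_transfer_connect(self) -> None:
            if self.rank == 0:
                return  # the first stage has no predecessor
            print(f"rank {self.rank} -> rank {self.rank - 1} at {self.prev_worker_ip}")

    ips = ["10.0.0.1", "10.0.0.2", "10.0.0.3"]
    for rank in range(3):
        Worker(rank, ips[rank - 1] if rank > 0 else "").initialize_pp_transfer_connect()
    # rank 1 -> rank 0 at 10.0.0.1
    # rank 2 -> rank 1 at 10.0.0.2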
@@ -194,14 +305,39 @@ class TPUWorker:
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-
-
-
-
-
-
-
-
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # receive intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only works for vllm model, not sure about jax models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # send intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
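The rewritten execute_model is a classic MPMD pipeline step: a non-first rank receives intermediate tensors keyed by a per-step uuid, runs its stage, and either forwards the result downstream or, on the last rank, returns model output. A toy re-enactment with plain queues standing in for the get_pp_group() transfers (PipelineStage and its arithmetic "stage" are hypothetical):

    import queue

    class PipelineStage:
        # Toy stand-in for one TPUWorker pipeline rank.
        def __init__(self, rank, world_size, recv_q, send_q):
            self.rank, self.world_size = rank, world_size
            self.recv_q, self.send_q = recv_q, send_q
            self.step_counter = 0  # mirrors TPUWorker.step_counter

        def execute_model(self, tokens):
            # Rank 0 starts from the inputs; later ranks block on the
            # previous stage's intermediates (keyed by uuid in the real code).
            hidden = tokens if self.rank == 0 else self.recv_q.get()
            hidden = hidden + 1  # stand-in for running this stage's layers
            self.step_counter += 1
            if self.rank < self.world_size - 1:
                self.send_q.put(hidden)  # forward intermediates downstream
                return None              # non-last ranks produce no output
            return hidden                # the last rank returns model output

    qs = [queue.Queue(), queue.Queue()]
    stages = [PipelineStage(0, 3, None, qs[0]),
              PipelineStage(1, 3, qs[0], qs[1]),
              PipelineStage(2, 3, qs[1], None)]
    print([s.execute_model(tokens=10) for s in stages])  # [None, None, 13]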
{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.11.1.dev202511180814
+Version: 0.11.1.dev202511220812
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -27,7 +27,7 @@ Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
 Requires-Dist: torchax==0.0.7
 Requires-Dist: qwix==0.1.1
-Requires-Dist: torchvision==0.
+Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD

@@ -21,27 +21,27 @@ tests/kernels/ragged_paged_attention_kernel_v3_test.py,sha256=Hrd8iUkS1pS3rxeTyY
 tests/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/lora/conftest.py,sha256=EXjwE1CjmUUlMEXpyE3UwxvgrKUllE73I8BNKfP1FTc,984
 tests/lora/test_bgmv.py,sha256=gQxWsJdNX2nkrE2xyrG0exwf3E2eHm2k2nkEXoANuQc,1359
-tests/lora/test_layers.py,sha256=
+tests/lora/test_layers.py,sha256=6B4HhMAItQmt0hPAQgyXgwSYs7b3bIbUf6LaPsqXLzY,25923
 tests/lora/test_lora.py,sha256=wJiF1P1BDnPN8TLX2tlFtdZ_QCkV-S9nPl6_uR6DqFc,4439
-tests/lora/utils.py,sha256=
-tpu_inference/__init__.py,sha256=
+tests/lora/utils.py,sha256=rY0tDZEZe58ye4-ykwrTnsiWuLcaEG57N_Rua90bDXI,2726
+tpu_inference/__init__.py,sha256=p4MaepRdN7723FUNE-3pOMxZWjFn4_TVFgjrNyty4JE,2304
 tpu_inference/env_override.py,sha256=pmL7lfs_rGCP92ya3wuWuudsCYeOMZ6tFZY82A4KkQc,365
-tpu_inference/envs.py,sha256=
+tpu_inference/envs.py,sha256=hoPuT0SyLCxqyZ0QJIha6EXSZv2TpACfmENuiT0iJMM,3956
 tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
-tpu_inference/tpu_info.py,sha256=
-tpu_inference/utils.py,sha256=
+tpu_inference/tpu_info.py,sha256=3iilHRQSFjwMJwhKcuuawTm7mhwkgHbj4zi6CiAySrs,2265
+tpu_inference/utils.py,sha256=Ddsx2CY2ARe46RZL27URzXCN3P6pMcKWB-APXUB8sHs,10098
 tpu_inference/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/core/core_tpu.py,sha256=WDD3koE_j1QhWS2BbMA2aQOZayPZm4tYPvzL4YCX2jY,33294
 tpu_inference/core/disagg_executor.py,sha256=HZpgYMVxRxm0RQxO4l8IDYBWJ6Z3Tac6xavc5otcirc,4657
-tpu_inference/core/disagg_utils.py,sha256=
+tpu_inference/core/disagg_utils.py,sha256=lv8MAVoAjtcmTaenUXVokg2q3d0tzsma86UiQlQ3omY,1492
 tpu_inference/core/sched/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/core/sched/dp_scheduler.py,sha256=mKs8Ms46szdlBfo8hjdqis2ZKAZbcKnHAGfEr0X5R8g,22527
 tpu_inference/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/distributed/jax_parallel_state.py,sha256=5_xCwcL03lFPUoSO_OP7hIVKpUFroW1m-jVO7R6FbUc,2223
-tpu_inference/distributed/tpu_connector.py,sha256=
-tpu_inference/distributed/utils.py,sha256=
+tpu_inference/distributed/tpu_connector.py,sha256=w_gOI6hX7NWefaxN_9XH9TXReGElOyFifdDHpPswotM,29696
+tpu_inference/distributed/utils.py,sha256=1KIREn28Zg10O-MSUkVQMRzS09WoGc_VLGOX4QTFJac,1504
 tpu_inference/executors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/executors/ray_distributed_executor.py,sha256=
+tpu_inference/executors/ray_distributed_executor.py,sha256=emYfSFJ3kluEmi6mlfnvxSUrC_mGVRVcjrUqUH2MR4g,16122
 tpu_inference/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/experimental/llama3_jax_stashed.py,sha256=YK1oSIfto9ALo-HB45XfSrbq9XgVbE4m2C-9zRwmSzI,10913
 tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0

@@ -67,8 +67,8 @@ tpu_inference/kernels/ragged_paged_attention/v2/kernel.py,sha256=OiQGAHhyggbp1Pe
 tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256=vGp2ZWODTbjyG9z2z0Qf_BX-wYHd5bUybnc_DtOz0nI,10995
 tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
 tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=
-tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=
+tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=O179Fft5KpuN5LIFx3SghWXJJUqh3Og-xqfO4Z8QXYU,57032
+tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=z0oaH8ZkDmHSoG4yiiO2CN0kuAuFcEpQ3RUoi5msjlo,56904
 tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=k3LwduhZO85cJ-pSgnGN0c2Nn8eNeQq4eA94KUXJzMw,142198
 tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py,sha256=P3_ivi8iUz5QMU_3pgpl4Bkbmn0q0NpDtVJX39haRQA,11208
 tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=1N_ozjKboDYLteFJndWoLXNudj2z53rGXMkELa5Z9tY,1102

@@ -104,7 +104,7 @@ tpu_inference/layers/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
 tpu_inference/layers/vllm/attention.py,sha256=wbJpcgqEAuIirv5PIULbiP-ggMKjmTanbB7Dg0BVYv4,7366
 tpu_inference/layers/vllm/fused_moe.py,sha256=XZt2CPUz00qZzDcyfBFz6buhVzmGL1amHalHJALl9zw,18945
 tpu_inference/layers/vllm/linear_common.py,sha256=_YlJtbdaYcck_j-gFLos_k0ycktVWxT8Qo57tR2YqJ8,7749
-tpu_inference/layers/vllm/sharding.py,sha256=
+tpu_inference/layers/vllm/sharding.py,sha256=as7CF8UKTF3ToymwRY5Pi8uzwJk0P1sHPkWB5xEx3mA,9169
 tpu_inference/layers/vllm/quantization/__init__.py,sha256=SEppGayBzzQ5tsXLSy99aqilkAawQwYxnv2alCg6-ZU,1777
 tpu_inference/layers/vllm/quantization/awq.py,sha256=-8ZmjGvSKJB6_JuwSctNWt8xHWq4VSvK_AK9iahlgCo,8495
 tpu_inference/layers/vllm/quantization/common.py,sha256=wm3pge6XMTMsLK7_SSdgBP0PvQzz-1mrqN2I6xMqzrc,4218

@@ -118,30 +118,25 @@ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_ten
 tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=4y7lYgybpXszpCAtxGFhR8LDEbEoCCeo3DfUSOXxhaQ,5202
 tpu_inference/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3zWBUsxE,3121
-tpu_inference/lora/torch_punica_tpu.py,sha256=
-tpu_inference/mock/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/mock/vllm_config_utils.py,sha256=FlQshLjoHdgs3C66tYHYbKFUjbk9DhUwY-7HibZk0fI,878
-tpu_inference/mock/vllm_envs.py,sha256=cCubeOhH2WeYZQFJt6W0y_IiQo0fzIWR1LCCE8i6kI4,50990
-tpu_inference/mock/vllm_logger.py,sha256=vUGnN5nKT--ZvU15YCzODUM_FGiXKhcrrjDGjeN00RQ,7297
-tpu_inference/mock/vllm_logging_utils.py,sha256=TEUmKj3xHiLzHBnFqAujcxH0t2hBQ04sUaho2RyORnk,486
+tpu_inference/lora/torch_punica_tpu.py,sha256=qTnXZGLoOgvukSxeunO_SfpPTlkq9GlMj9H7zVYg9LE,12680
 tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/models/common/model_loader.py,sha256=
+tpu_inference/models/common/model_loader.py,sha256=3rRntyGqS6l7yAfURmRaGkhyIaee2E43a5F0_i0IFmE,18177
 tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/deepseek_v3.py,sha256=SKOHVEC-_2NLxBnzBzbu5tu0d6FTlAEiI1EefGaO2QE,40047
 tpu_inference/models/jax/gpt_oss.py,sha256=Vw4LRB5Kp6hbA2hjZGFS8kiEqOCjf881XH2JNtu2S1I,20924
 tpu_inference/models/jax/jax_intermediate_tensor.py,sha256=Pxu1PCV5LN5X58aYVkPiohcXZIeKVim2oqvrS_cVgw4,2604
-tpu_inference/models/jax/llama3.py,sha256=
+tpu_inference/models/jax/llama3.py,sha256=ZiFtrpAzXTT9vAPES9UeuJInCWGbvDWs7g0_JLdCCa4,13479
 tpu_inference/models/jax/llama4.py,sha256=wf2Sp2iYViaYD5rSfv3_ryO6gYuYM5XaOyvghaP4OCY,29631
-tpu_inference/models/jax/llama_eagle3.py,sha256=
-tpu_inference/models/jax/
-tpu_inference/models/jax/qwen2.py,sha256=
-tpu_inference/models/jax/qwen2_5_vl.py,sha256=
-tpu_inference/models/jax/qwen3.py,sha256=
+tpu_inference/models/jax/llama_eagle3.py,sha256=xUoNetxDbcFIEVLZ2DiD-GEQhHcdau2v1R12WdMyGec,12550
+tpu_inference/models/jax/llama_guard_4.py,sha256=LrnU2zBWM0s4q_5dwmR--OO0V7ttltsYhrHYlBgQVIw,15275
+tpu_inference/models/jax/qwen2.py,sha256=SuAp7tErk8OoIRko0Vt6QSOZP_9B9r5GTfqmVfImUIo,13410
+tpu_inference/models/jax/qwen2_5_vl.py,sha256=tf177ypgA1ZVIn34Ff_LTwr10NwzlZ3-DPqSoRLAQtQ,43995
+tpu_inference/models/jax/qwen3.py,sha256=CIZQKjZDke_LPGsLNhRCJdDTzWueUneBPAQ1blS24IM,11050
 tpu_inference/models/jax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/utils/file_utils.py,sha256=NOuSC3YFnZpf3CZgYdghbbiNYJt42zgjlEYbOZIVct4,2840
 tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=rrIrQWidkUnGilBHKNpdYh7_2BkvnAaqanXjC81GNcg,6156
-tpu_inference/models/jax/utils/weight_utils.py,sha256=
+tpu_inference/models/jax/utils/weight_utils.py,sha256=d5u8pPR-qPbEjX-8BMY0Zea9O-a34CpfuDlVnbwWfAw,20659
 tpu_inference/models/jax/utils/quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/utils/quantization/mxfp4_utils.py,sha256=boGnqJCRIOf5nedAxQ8_IUTV6Rfll10DXnRC40BeeE8,3682
 tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=xgKoKB7AM3TYPxzVgEGLTK9ebQH2Kx8mNuO0heovkmk,26778

@@ -150,30 +145,30 @@ tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml,sha256=b7Sy
 tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml,sha256=0Qwij71zj9k6rmrUNd8Q5df9YYfkoJ1ZkgMAHxQy81k,128
 tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml,sha256=lGec0UwwxmNPNgKPSsTsCMSXNJjhw507KMtM2NsSCMw,152
 tpu_inference/models/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=
+tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=hEjg5hKotp-fEt3SXWkWpdnQ32TU1XGpTrfhyLTNyt0,12054
 tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=yxlJHPmRQIAwlb1MmHK3xfXokgIkJ-evNU4PgyoJUdg,1187
 tpu_inference/platforms/__init__.py,sha256=lQCrKddS_GcGpCbeogvz9zOZD1mQw5bBsiw8On46qFQ,74
-tpu_inference/platforms/tpu_platform.py,sha256=
+tpu_inference/platforms/tpu_platform.py,sha256=RSCe3Ne1FsWXVrX6_6V_Z6B0TDTRS38eM0KTkXbQ_w8,10579
 tpu_inference/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/runner/block_table.py,sha256=K3Ic8EgPM08d_C5nEN60mxoRydlaQWySAemf_8Q_qVw,4175
-tpu_inference/runner/compilation_manager.py,sha256=
+tpu_inference/runner/compilation_manager.py,sha256=oVML1KhhQ7YFaSWBaJA0qWQoNX2qRZOrwbbh4XYPc-8,37287
 tpu_inference/runner/input_batch.py,sha256=bx221NX2IOWzrtopss-B-2ZKW4y-U6nQpG09PjpUziw,18273
 tpu_inference/runner/kv_cache.py,sha256=F4dzW2d53xuxkFUn0oKzwE6VklGUeVm-QM19NVfIQDU,4577
-tpu_inference/runner/kv_cache_manager.py,sha256=
+tpu_inference/runner/kv_cache_manager.py,sha256=XEfis_9nQAz8uxM5y_P5biqSUijX4IeMhIusTf2V7vg,22444
 tpu_inference/runner/lora_utils.py,sha256=B4xMCgXGJ4VNdePvn89HH3tIZ-gYsQ7Vq_YCiYIATEY,3843
 tpu_inference/runner/multimodal_manager.py,sha256=azEPdHOwz8CN11MQmorGdtrCLbFaTCxdWyuEsZTzjYM,9778
 tpu_inference/runner/persistent_batch_manager.py,sha256=KERSfKy6XjMejnbtPGI3hzoYAHJLeCxmpZVYPqBCago,11156
 tpu_inference/runner/speculative_decoding_manager.py,sha256=I3FDWKh2dn6nV8LgTGfCTwMKYnxQsTPpBIrmaJngXHs,10215
 tpu_inference/runner/structured_decoding_manager.py,sha256=Y0ERPhj4olFh6Y2TxP0R1_4UIJwy7nemYA-h63YIR2U,3622
-tpu_inference/runner/tpu_runner.py,sha256=
+tpu_inference/runner/tpu_runner.py,sha256=aHXHSlaNuc9q7pcPklqTFRkmkEQDULEEH_hsR_NcTMQ,77532
 tpu_inference/runner/utils.py,sha256=ZnWUoNo-7INeB0mdXti1jwUOdbmxyExznOs-crRTQLk,17126
 tpu_inference/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/spec_decode/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/spec_decode/jax/eagle3.py,sha256=
+tpu_inference/spec_decode/jax/eagle3.py,sha256=1WVHTdv6jfCKwbiz0RwQLPyq8L720gD_bs0p_Gz0QiI,16644
 tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/worker/tpu_worker.py,sha256=
-tpu_inference-0.11.1.
-tpu_inference-0.11.1.
-tpu_inference-0.11.1.
-tpu_inference-0.11.1.
-tpu_inference-0.11.1.
+tpu_inference/worker/tpu_worker.py,sha256=aojB9-PY_ZzTaZgv1i5PUB9CSXNVuK4JZzftCv9ku4A,20642
+tpu_inference-0.11.1.dev202511220812.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+tpu_inference-0.11.1.dev202511220812.dist-info/METADATA,sha256=JzmyOlYYkImIe_WSawI0LDwL28xS-0SCRCcFXeYSV0g,5465
+tpu_inference-0.11.1.dev202511220812.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tpu_inference-0.11.1.dev202511220812.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
+tpu_inference-0.11.1.dev202511220812.dist-info/RECORD,,
tpu_inference/mock/__init__.py
DELETED
File without changes

tpu_inference/mock/vllm_config_utils.py
DELETED

@@ -1,28 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Any, List, Mapping
-
-
-@dataclass
-class ModelConfig():
-    max_model_len: int = 2048
-    max_prefill_len: int = 1024
-    prefill_batch_size: int = 1
-    decode_batch_size: int = 1
-    block_size: int = 16
-    num_layers: int = 32
-    num_kv_heads: int = 32
-    head_dim: int = 128
-    vocab_size: int = 32000
-    model: str = "llama3"
-    hf_config: str = ""
-    architectures: List[str] = field(default_factory=list)
-    override_generation_config: dict[str, Any] = field(default_factory=dict)
-    hf_overrides: dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass
-class VllmConfig():
-    additional_config: Mapping[str, Any] = field(default_factory=dict)
-    # Set default max_model_len to turn off warnings.
-    model_config: ModelConfig = field(
-        default_factory=lambda: ModelConfig(max_model_len=1024))