tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tpu-inference might be problematic.
Files changed (40)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tpu_inference/__init__.py +22 -3
  4. tpu_inference/core/disagg_utils.py +6 -8
  5. tpu_inference/distributed/tpu_connector.py +2 -3
  6. tpu_inference/distributed/utils.py +3 -2
  7. tpu_inference/envs.py +1 -1
  8. tpu_inference/executors/ray_distributed_executor.py +4 -1
  9. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  10. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +77 -54
  11. tpu_inference/layers/vllm/sharding.py +2 -2
  12. tpu_inference/lora/torch_punica_tpu.py +1 -2
  13. tpu_inference/models/common/model_loader.py +9 -9
  14. tpu_inference/models/jax/llama3.py +2 -1
  15. tpu_inference/models/jax/llama_eagle3.py +9 -5
  16. tpu_inference/models/jax/llama_guard_4.py +361 -0
  17. tpu_inference/models/jax/qwen2.py +2 -1
  18. tpu_inference/models/jax/qwen2_5_vl.py +2 -1
  19. tpu_inference/models/jax/qwen3.py +2 -1
  20. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  21. tpu_inference/models/vllm/vllm_model_wrapper.py +4 -4
  22. tpu_inference/platforms/tpu_platform.py +5 -2
  23. tpu_inference/runner/compilation_manager.py +33 -15
  24. tpu_inference/runner/kv_cache_manager.py +8 -2
  25. tpu_inference/runner/tpu_runner.py +187 -99
  26. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  27. tpu_inference/tpu_info.py +4 -3
  28. tpu_inference/utils.py +5 -4
  29. tpu_inference/worker/tpu_worker.py +158 -22
  30. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  31. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +34 -39
  32. tpu_inference/mock/__init__.py +0 -0
  33. tpu_inference/mock/vllm_config_utils.py +0 -28
  34. tpu_inference/mock/vllm_envs.py +0 -1219
  35. tpu_inference/mock/vllm_logger.py +0 -212
  36. tpu_inference/mock/vllm_logging_utils.py +0 -15
  37. tpu_inference/models/jax/phi3.py +0 -376
  38. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  39. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  40. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0

tpu_inference/worker/tpu_worker.py
@@ -2,6 +2,7 @@
 
 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
@@ -10,6 +11,7 @@ import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +25,13 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
 
@@ -39,15 +44,39 @@ _DTYPE: dict[str, jnp.dtype] = {
 }
 
 
+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # Default env var values for TPU_PROCESS_BOUNDS,
+    # TPU_CHIPS_PER_PROCESS_BOUNDS and TPU_VISIBLE_CHIPS
+    # when PP is used on a single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:
 
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
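
For concreteness, a minimal sketch (not part of the diff) of the defaults that PPConfig.__post_init__ computes, using a hypothetical single-host run with four pipeline stages where this worker is rank 1:

    rank, pp_world_size = 1, 4

    # One TPU process per pipeline stage, laid out along the second axis.
    default_tpu_process_bounds = f"1,{pp_world_size},1"   # -> "1,4,1"
    # Each process owns a single chip.
    default_tpu_chips_per_process_bounds = "1,1,1"
    # This worker only sees the chip matching its rank.
    default_tpu_visible_chips = f"{rank}"                 # -> "1"

    assert default_tpu_process_bounds == "1,4,1"
    assert default_tpu_visible_chips == "1"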
@@ -74,6 +103,8 @@ class TPUWorker:
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -86,7 +117,7 @@ class TPUWorker:
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -94,6 +125,14 @@ class TPUWorker:
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 
+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
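
As a worked example of the per-worker trace directories above (the /tmp/traces path and the rank are hypothetical):

    import os

    # Assumed: VLLM_TORCH_PROFILER_DIR=/tmp/traces, rank 2, PP world size 4.
    profile_dir = os.path.join("/tmp/traces", f"pprank_{2}_ppworldsize_{4}")
    assert profile_dir == "/tmp/traces/pprank_2_ppworldsize_4"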
@@ -105,31 +144,87 @@
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
+        # step_counter is used to calculate the uuid for transferring
+        # intermediate tensors.
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-    def init_device(self):
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
+        # Set the TPU devices visible to the JAX runtime for single-host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_config.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # Note: below is the setting for a v6e8 host (8 v6e chips).
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e:
+            # 1) 2 slices with 4 TPU chips each; we can do PP=2, TP=1/2/3/4:
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices;
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8.
+            os.environ["TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ["TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                all_devices = jax.devices()
-                device_dict = {device.id: device for device in all_devices}
+                all_local_devices = jax.local_devices()
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.devices() with IDs {list(device_dict.keys())}!"
+                            f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.devices = jax.devices()[:sharding_config.total_devices]
+                if self.pp_config.pp_world_size > 1:
+                    # We only support a mixed TP + PP scenario where the TP
+                    # size is smaller than or equal to the total TPUs in one
+                    # node. Say we have 4 nodes with 4 TPUs each: we can do
+                    # pp=4, tp=4, but not pp=2, tp=8.
+                    assert jax.local_device_count(
+                    ) >= sharding_config.total_devices
+                    self.devices = jax.local_devices()[:sharding_config.
+                                                       total_devices]
+                else:
+                    # In a multi-host distributed env (say Ray), the local
+                    # device count may be smaller than the total devices, so
+                    # we just choose the smaller set here.
+                    self.devices = jax.devices()[:sharding_config.
+                                                 total_devices]
 
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
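
To make the env-var plumbing above concrete, this sketch reproduces what a rank-1 worker in a single-host PP=4 run would export, assuming BASE_JAX_PORT is 8476 (the actual constant lives in jax_parallel_state and may differ):

    BASE_JAX_PORT = 8476  # assumption for illustration
    rank, pp_world_size = 1, 4

    tpu_ports = [BASE_JAX_PORT + i for i in range(pp_world_size)]
    env = {
        "TPU_PROCESS_ADDRESSES": ",".join(f"localhost:{p}" for p in tpu_ports),
        "TPU_PROCESS_PORT": str(tpu_ports[rank]),
        "CLOUD_TPU_TASK_ID": str(rank),
        # Defaults from PPConfig: one process per stage, one chip per process.
        "TPU_PROCESS_BOUNDS": f"1,{pp_world_size},1",
        "TPU_CHIPS_PER_PROCESS_BOUNDS": "1,1,1",
        "TPU_VISIBLE_CHIPS": str(rank),
    }

    assert env["TPU_PROCESS_ADDRESSES"] == (
        "localhost:8476,localhost:8477,localhost:8478,localhost:8479")
    assert env["TPU_PROCESS_PORT"] == "8477"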
@@ -146,8 +241,18 @@
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
         )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.pp_config.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        self.model_runner = TPUModelRunner(
+            self.vllm_config, self.devices, self.rank, self.rank == 0,
+            self.rank == self.pp_config.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
@@ -155,6 +260,12 @@
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +305,39 @@
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # Receive intermediate tensors.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only work for vLLM models; not sure
+            # about JAX models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # Send intermediate tensors.
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
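
One subtlety in the new execute_model flow: the receiving rank passes self.rank - 1 to get_uuid_for_jax_transfer while the sending rank passes self.rank, so both ends of a stage boundary derive the same transfer id from the sending stage's rank and the shared step counter. A schematic sketch (make_uuid is a hypothetical stand-in for the real helper):

    def make_uuid(sending_rank: int, step: int) -> str:
        # Key the transfer by the sending stage and the step counter.
        return f"pp_transfer_rank{sending_rank}_step{step}"

    step = 7
    send_key = make_uuid(2, step)      # sender at rank 2 uses its own rank
    recv_key = make_uuid(3 - 1, step)  # receiver at rank 3 uses rank - 1
    assert send_key == recv_key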

{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.11.1.dev202511180814
+Version: 0.11.1.dev202511220812
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -27,7 +27,7 @@ Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
 Requires-Dist: torchax==0.0.7
 Requires-Dist: qwix==0.1.1
-Requires-Dist: torchvision==0.23.0
+Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1

{tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD
@@ -21,27 +21,27 @@ tests/kernels/ragged_paged_attention_kernel_v3_test.py,sha256=Hrd8iUkS1pS3rxeTyY
 tests/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/lora/conftest.py,sha256=EXjwE1CjmUUlMEXpyE3UwxvgrKUllE73I8BNKfP1FTc,984
 tests/lora/test_bgmv.py,sha256=gQxWsJdNX2nkrE2xyrG0exwf3E2eHm2k2nkEXoANuQc,1359
-tests/lora/test_layers.py,sha256=21ekYlsK36r1GPZOfzs7E-KIsfI1JcuZl1E6vaQbHf4,26273
+tests/lora/test_layers.py,sha256=6B4HhMAItQmt0hPAQgyXgwSYs7b3bIbUf6LaPsqXLzY,25923
 tests/lora/test_lora.py,sha256=wJiF1P1BDnPN8TLX2tlFtdZ_QCkV-S9nPl6_uR6DqFc,4439
-tests/lora/utils.py,sha256=dR_v1H20vPVjFHdBhDajWOz0WJZlKuPLgMFQsME0LtA,3009
-tpu_inference/__init__.py,sha256=7IduGWw-_fwx0VA6EvC_AqHF67fnnShz6YvkqCfvFx8,1317
+tests/lora/utils.py,sha256=rY0tDZEZe58ye4-ykwrTnsiWuLcaEG57N_Rua90bDXI,2726
+tpu_inference/__init__.py,sha256=p4MaepRdN7723FUNE-3pOMxZWjFn4_TVFgjrNyty4JE,2304
 tpu_inference/env_override.py,sha256=pmL7lfs_rGCP92ya3wuWuudsCYeOMZ6tFZY82A4KkQc,365
-tpu_inference/envs.py,sha256=MTT_Pdtd6cAcciYjv1OekEmvspaq3SYL0oR_jDkQ_aE,3948
+tpu_inference/envs.py,sha256=hoPuT0SyLCxqyZ0QJIha6EXSZv2TpACfmENuiT0iJMM,3956
 tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
-tpu_inference/tpu_info.py,sha256=9UohshkndR6dZpGWpWXfTD4qvIVdVgHf0yOoSEkLTrw,2276
-tpu_inference/utils.py,sha256=iGPY147jP_8AKMu3g7vYTndjJJiOrK_4opA0JWtws5Q,10068
+tpu_inference/tpu_info.py,sha256=3iilHRQSFjwMJwhKcuuawTm7mhwkgHbj4zi6CiAySrs,2265
+tpu_inference/utils.py,sha256=Ddsx2CY2ARe46RZL27URzXCN3P6pMcKWB-APXUB8sHs,10098
 tpu_inference/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/core/core_tpu.py,sha256=WDD3koE_j1QhWS2BbMA2aQOZayPZm4tYPvzL4YCX2jY,33294
 tpu_inference/core/disagg_executor.py,sha256=HZpgYMVxRxm0RQxO4l8IDYBWJ6Z3Tac6xavc5otcirc,4657
-tpu_inference/core/disagg_utils.py,sha256=ufWNFWQ5n4YnZpPOtoReHlYo4dlN7AbIqCyqS4an0t4,1572
+tpu_inference/core/disagg_utils.py,sha256=lv8MAVoAjtcmTaenUXVokg2q3d0tzsma86UiQlQ3omY,1492
 tpu_inference/core/sched/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/core/sched/dp_scheduler.py,sha256=mKs8Ms46szdlBfo8hjdqis2ZKAZbcKnHAGfEr0X5R8g,22527
 tpu_inference/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/distributed/jax_parallel_state.py,sha256=5_xCwcL03lFPUoSO_OP7hIVKpUFroW1m-jVO7R6FbUc,2223
-tpu_inference/distributed/tpu_connector.py,sha256=Zah46Sm5iOuh72SzXw69NxMc0MLnqsLEpe2BfDhpnqA,29731
-tpu_inference/distributed/utils.py,sha256=RwFQi8G4TzN1g9RjQu0pb5JxSc_jhoIZVsFJo0uHjxo,1513
+tpu_inference/distributed/tpu_connector.py,sha256=w_gOI6hX7NWefaxN_9XH9TXReGElOyFifdDHpPswotM,29696
+tpu_inference/distributed/utils.py,sha256=1KIREn28Zg10O-MSUkVQMRzS09WoGc_VLGOX4QTFJac,1504
 tpu_inference/executors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/executors/ray_distributed_executor.py,sha256=ZMuVUwmroi7UUZs3u67OsOwUIkxNDz9IszUPG20F18E,15904
+tpu_inference/executors/ray_distributed_executor.py,sha256=emYfSFJ3kluEmi6mlfnvxSUrC_mGVRVcjrUqUH2MR4g,16122
 tpu_inference/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/experimental/llama3_jax_stashed.py,sha256=YK1oSIfto9ALo-HB45XfSrbq9XgVbE4m2C-9zRwmSzI,10913
 tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -67,8 +67,8 @@ tpu_inference/kernels/ragged_paged_attention/v2/kernel.py,sha256=OiQGAHhyggbp1Pe
 tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256=vGp2ZWODTbjyG9z2z0Qf_BX-wYHd5bUybnc_DtOz0nI,10995
 tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
 tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=tlP6121yfXaukx_RQroHlHcZnbKPyyum0lAcvT0B_Pk,56132
-tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=pD1Pte3neoLAxE3I3-VyV_4FuqgCHeAHGzEjMVt0MMk,56004
+tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=O179Fft5KpuN5LIFx3SghWXJJUqh3Og-xqfO4Z8QXYU,57032
+tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=z0oaH8ZkDmHSoG4yiiO2CN0kuAuFcEpQ3RUoi5msjlo,56904
 tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=k3LwduhZO85cJ-pSgnGN0c2Nn8eNeQq4eA94KUXJzMw,142198
 tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py,sha256=P3_ivi8iUz5QMU_3pgpl4Bkbmn0q0NpDtVJX39haRQA,11208
 tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=1N_ozjKboDYLteFJndWoLXNudj2z53rGXMkELa5Z9tY,1102
@@ -104,7 +104,7 @@ tpu_inference/layers/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
 tpu_inference/layers/vllm/attention.py,sha256=wbJpcgqEAuIirv5PIULbiP-ggMKjmTanbB7Dg0BVYv4,7366
 tpu_inference/layers/vllm/fused_moe.py,sha256=XZt2CPUz00qZzDcyfBFz6buhVzmGL1amHalHJALl9zw,18945
 tpu_inference/layers/vllm/linear_common.py,sha256=_YlJtbdaYcck_j-gFLos_k0ycktVWxT8Qo57tR2YqJ8,7749
-tpu_inference/layers/vllm/sharding.py,sha256=WTx1tF_7R99AdyE-lL7HQJ378hAafeI-JVRsugAvwn4,9177
+tpu_inference/layers/vllm/sharding.py,sha256=as7CF8UKTF3ToymwRY5Pi8uzwJk0P1sHPkWB5xEx3mA,9169
 tpu_inference/layers/vllm/quantization/__init__.py,sha256=SEppGayBzzQ5tsXLSy99aqilkAawQwYxnv2alCg6-ZU,1777
 tpu_inference/layers/vllm/quantization/awq.py,sha256=-8ZmjGvSKJB6_JuwSctNWt8xHWq4VSvK_AK9iahlgCo,8495
 tpu_inference/layers/vllm/quantization/common.py,sha256=wm3pge6XMTMsLK7_SSdgBP0PvQzz-1mrqN2I6xMqzrc,4218
@@ -118,30 +118,25 @@ tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_ten
 tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=4y7lYgybpXszpCAtxGFhR8LDEbEoCCeo3DfUSOXxhaQ,5202
 tpu_inference/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3zWBUsxE,3121
-tpu_inference/lora/torch_punica_tpu.py,sha256=b27DpmIS_N5bhlIcryiENYNmPxp_cu40CGxjPW64d44,12706
-tpu_inference/mock/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/mock/vllm_config_utils.py,sha256=FlQshLjoHdgs3C66tYHYbKFUjbk9DhUwY-7HibZk0fI,878
-tpu_inference/mock/vllm_envs.py,sha256=cCubeOhH2WeYZQFJt6W0y_IiQo0fzIWR1LCCE8i6kI4,50990
-tpu_inference/mock/vllm_logger.py,sha256=vUGnN5nKT--ZvU15YCzODUM_FGiXKhcrrjDGjeN00RQ,7297
-tpu_inference/mock/vllm_logging_utils.py,sha256=TEUmKj3xHiLzHBnFqAujcxH0t2hBQ04sUaho2RyORnk,486
+tpu_inference/lora/torch_punica_tpu.py,sha256=qTnXZGLoOgvukSxeunO_SfpPTlkq9GlMj9H7zVYg9LE,12680
 tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/models/common/model_loader.py,sha256=VgxM2OODb0-69dexv4aNJ4g24Nrx5sj_ra4XStkhl14,18289
+tpu_inference/models/common/model_loader.py,sha256=3rRntyGqS6l7yAfURmRaGkhyIaee2E43a5F0_i0IFmE,18177
 tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/deepseek_v3.py,sha256=SKOHVEC-_2NLxBnzBzbu5tu0d6FTlAEiI1EefGaO2QE,40047
 tpu_inference/models/jax/gpt_oss.py,sha256=Vw4LRB5Kp6hbA2hjZGFS8kiEqOCjf881XH2JNtu2S1I,20924
 tpu_inference/models/jax/jax_intermediate_tensor.py,sha256=Pxu1PCV5LN5X58aYVkPiohcXZIeKVim2oqvrS_cVgw4,2604
-tpu_inference/models/jax/llama3.py,sha256=w99DAfipGS9HyX2ZRwqyYLxC3oa0ew5eEQ6EXlMMf18,13426
+tpu_inference/models/jax/llama3.py,sha256=ZiFtrpAzXTT9vAPES9UeuJInCWGbvDWs7g0_JLdCCa4,13479
 tpu_inference/models/jax/llama4.py,sha256=wf2Sp2iYViaYD5rSfv3_ryO6gYuYM5XaOyvghaP4OCY,29631
-tpu_inference/models/jax/llama_eagle3.py,sha256=STUkAK6XEA7JM3i_Lx36-t5BhkAGeW_xYiq3zYhHP1A,12297
-tpu_inference/models/jax/phi3.py,sha256=TpP3Nvr1myW_Qd8xNrLP1VmXtq7BuTcWNayJitskFd0,13579
-tpu_inference/models/jax/qwen2.py,sha256=P_x_Qygf-nanmF8Uufk4c-qLNxP4RAk4yuqSF8VwbxE,13357
-tpu_inference/models/jax/qwen2_5_vl.py,sha256=fvMgM5GfUn5EECaMbR0z37mmbCHphAT1AvWPvGkhVn4,43942
-tpu_inference/models/jax/qwen3.py,sha256=lr3TIIQKmNgWFDFxwuPsVOypqBijkqrpnNCopVg4iBo,10997
+tpu_inference/models/jax/llama_eagle3.py,sha256=xUoNetxDbcFIEVLZ2DiD-GEQhHcdau2v1R12WdMyGec,12550
+tpu_inference/models/jax/llama_guard_4.py,sha256=LrnU2zBWM0s4q_5dwmR--OO0V7ttltsYhrHYlBgQVIw,15275
+tpu_inference/models/jax/qwen2.py,sha256=SuAp7tErk8OoIRko0Vt6QSOZP_9B9r5GTfqmVfImUIo,13410
+tpu_inference/models/jax/qwen2_5_vl.py,sha256=tf177ypgA1ZVIn34Ff_LTwr10NwzlZ3-DPqSoRLAQtQ,43995
+tpu_inference/models/jax/qwen3.py,sha256=CIZQKjZDke_LPGsLNhRCJdDTzWueUneBPAQ1blS24IM,11050
 tpu_inference/models/jax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/utils/file_utils.py,sha256=NOuSC3YFnZpf3CZgYdghbbiNYJt42zgjlEYbOZIVct4,2840
 tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=rrIrQWidkUnGilBHKNpdYh7_2BkvnAaqanXjC81GNcg,6156
-tpu_inference/models/jax/utils/weight_utils.py,sha256=65-H8BTbyilIBMBfvWjkkW3mf4soYASbhrJFqbFKzL4,20129
+tpu_inference/models/jax/utils/weight_utils.py,sha256=d5u8pPR-qPbEjX-8BMY0Zea9O-a34CpfuDlVnbwWfAw,20659
 tpu_inference/models/jax/utils/quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/utils/quantization/mxfp4_utils.py,sha256=boGnqJCRIOf5nedAxQ8_IUTV6Rfll10DXnRC40BeeE8,3682
 tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=xgKoKB7AM3TYPxzVgEGLTK9ebQH2Kx8mNuO0heovkmk,26778
@@ -150,30 +145,30 @@ tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml,sha256=b7Sy
 tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml,sha256=0Qwij71zj9k6rmrUNd8Q5df9YYfkoJ1ZkgMAHxQy81k,128
 tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml,sha256=lGec0UwwxmNPNgKPSsTsCMSXNJjhw507KMtM2NsSCMw,152
 tpu_inference/models/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=o3oJ7Uhu-vSJEFHHifF8e0Q7dULRKJ2GRsT1qAN6PWY,12099
+tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=hEjg5hKotp-fEt3SXWkWpdnQ32TU1XGpTrfhyLTNyt0,12054
 tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=yxlJHPmRQIAwlb1MmHK3xfXokgIkJ-evNU4PgyoJUdg,1187
 tpu_inference/platforms/__init__.py,sha256=lQCrKddS_GcGpCbeogvz9zOZD1mQw5bBsiw8On46qFQ,74
-tpu_inference/platforms/tpu_platform.py,sha256=AYFr1Q7VUN76wcdgOe_wZuVIHgp2U8isBJ3iHrYqt0M,10530
+tpu_inference/platforms/tpu_platform.py,sha256=RSCe3Ne1FsWXVrX6_6V_Z6B0TDTRS38eM0KTkXbQ_w8,10579
 tpu_inference/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/runner/block_table.py,sha256=K3Ic8EgPM08d_C5nEN60mxoRydlaQWySAemf_8Q_qVw,4175
-tpu_inference/runner/compilation_manager.py,sha256=yIsonouB5G0-fyVtAKuyyRXaMGNFwnX8D7q6ppQYgUI,36318
+tpu_inference/runner/compilation_manager.py,sha256=oVML1KhhQ7YFaSWBaJA0qWQoNX2qRZOrwbbh4XYPc-8,37287
 tpu_inference/runner/input_batch.py,sha256=bx221NX2IOWzrtopss-B-2ZKW4y-U6nQpG09PjpUziw,18273
 tpu_inference/runner/kv_cache.py,sha256=F4dzW2d53xuxkFUn0oKzwE6VklGUeVm-QM19NVfIQDU,4577
-tpu_inference/runner/kv_cache_manager.py,sha256=CJxXtdWuewJqcTBMoR70_Uvwxjtc3cK2jxe1KpI9kQc,22152
+tpu_inference/runner/kv_cache_manager.py,sha256=XEfis_9nQAz8uxM5y_P5biqSUijX4IeMhIusTf2V7vg,22444
 tpu_inference/runner/lora_utils.py,sha256=B4xMCgXGJ4VNdePvn89HH3tIZ-gYsQ7Vq_YCiYIATEY,3843
 tpu_inference/runner/multimodal_manager.py,sha256=azEPdHOwz8CN11MQmorGdtrCLbFaTCxdWyuEsZTzjYM,9778
 tpu_inference/runner/persistent_batch_manager.py,sha256=KERSfKy6XjMejnbtPGI3hzoYAHJLeCxmpZVYPqBCago,11156
 tpu_inference/runner/speculative_decoding_manager.py,sha256=I3FDWKh2dn6nV8LgTGfCTwMKYnxQsTPpBIrmaJngXHs,10215
 tpu_inference/runner/structured_decoding_manager.py,sha256=Y0ERPhj4olFh6Y2TxP0R1_4UIJwy7nemYA-h63YIR2U,3622
-tpu_inference/runner/tpu_runner.py,sha256=3SZYn0CBA4LOaTO3GdQOxKx3HKmVcNmUEeSyzSAGyFY,73320
+tpu_inference/runner/tpu_runner.py,sha256=aHXHSlaNuc9q7pcPklqTFRkmkEQDULEEH_hsR_NcTMQ,77532
 tpu_inference/runner/utils.py,sha256=ZnWUoNo-7INeB0mdXti1jwUOdbmxyExznOs-crRTQLk,17126
 tpu_inference/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/spec_decode/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/spec_decode/jax/eagle3.py,sha256=A1dt-dmBttpy-5DGcL4noEDCB0OGP8Xo6MXqgJvWIo8,16593
+tpu_inference/spec_decode/jax/eagle3.py,sha256=1WVHTdv6jfCKwbiz0RwQLPyq8L720gD_bs0p_Gz0QiI,16644
 tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/worker/tpu_worker.py,sha256=0ZguK2BtIQjQSvyUTcUH9ENBrxt09w3CbgPoDY13Eok,14210
-tpu_inference-0.11.1.dev202511180814.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-tpu_inference-0.11.1.dev202511180814.dist-info/METADATA,sha256=6dHy_ByQ0ihDNFuqyb-ZXTFczvQ8Ia54zBNTKaUPhSk,5465
-tpu_inference-0.11.1.dev202511180814.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tpu_inference-0.11.1.dev202511180814.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
-tpu_inference-0.11.1.dev202511180814.dist-info/RECORD,,
+tpu_inference/worker/tpu_worker.py,sha256=aojB9-PY_ZzTaZgv1i5PUB9CSXNVuK4JZzftCv9ku4A,20642
+tpu_inference-0.11.1.dev202511220812.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+tpu_inference-0.11.1.dev202511220812.dist-info/METADATA,sha256=JzmyOlYYkImIe_WSawI0LDwL28xS-0SCRCcFXeYSV0g,5465
+tpu_inference-0.11.1.dev202511220812.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tpu_inference-0.11.1.dev202511220812.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
+tpu_inference-0.11.1.dev202511220812.dist-info/RECORD,,

tpu_inference/mock/vllm_config_utils.py
@@ -1,28 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Any, List, Mapping
-
-
-@dataclass
-class ModelConfig():
-    max_model_len: int = 2048
-    max_prefill_len: int = 1024
-    prefill_batch_size: int = 1
-    decode_batch_size: int = 1
-    block_size: int = 16
-    num_layers: int = 32
-    num_kv_heads: int = 32
-    head_dim: int = 128
-    vocab_size: int = 32000
-    model: str = "llama3"
-    hf_config: str = ""
-    architectures: List[str] = field(default_factory=list)
-    override_generation_config: dict[str, Any] = field(default_factory=dict)
-    hf_overrides: dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass
-class VllmConfig():
-    additional_config: Mapping[str, Any] = field(default_factory=dict)
-    # Set default max_model_len to turn off warnings.
-    model_config: ModelConfig = field(
-        default_factory=lambda: ModelConfig(max_model_len=1024))