tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (58)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tests/test_envs.py +182 -0
  4. tests/test_utils.py +23 -14
  5. tpu_inference/__init__.py +22 -3
  6. tpu_inference/core/core_tpu.py +17 -9
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +2 -3
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +1 -1
  11. tpu_inference/executors/ray_distributed_executor.py +27 -11
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
  14. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
  15. tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
  16. tpu_inference/layers/common/quant_methods.py +8 -0
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  19. tpu_inference/layers/jax/sample/sampling.py +2 -2
  20. tpu_inference/layers/vllm/attention.py +1 -1
  21. tpu_inference/layers/vllm/quantization/__init__.py +7 -3
  22. tpu_inference/layers/vllm/quantization/awq.py +4 -3
  23. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
  24. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  25. tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
  26. tpu_inference/layers/vllm/sharding.py +2 -2
  27. tpu_inference/lora/torch_punica_tpu.py +1 -2
  28. tpu_inference/models/common/model_loader.py +12 -11
  29. tpu_inference/models/jax/llama3.py +4 -3
  30. tpu_inference/models/jax/llama_eagle3.py +9 -5
  31. tpu_inference/models/jax/llama_guard_4.py +361 -0
  32. tpu_inference/models/jax/qwen2.py +3 -2
  33. tpu_inference/models/jax/qwen2_5_vl.py +4 -3
  34. tpu_inference/models/jax/qwen3.py +3 -2
  35. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  36. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
  37. tpu_inference/platforms/tpu_platform.py +17 -7
  38. tpu_inference/runner/compilation_manager.py +37 -17
  39. tpu_inference/runner/kv_cache.py +1 -1
  40. tpu_inference/runner/kv_cache_manager.py +8 -2
  41. tpu_inference/runner/tpu_runner.py +199 -87
  42. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  43. tpu_inference/tpu_info.py +4 -3
  44. tpu_inference/utils.py +7 -6
  45. tpu_inference/worker/tpu_worker.py +159 -23
  46. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  47. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
  48. tpu_inference/mock/__init__.py +0 -0
  49. tpu_inference/mock/vllm_config_utils.py +0 -28
  50. tpu_inference/mock/vllm_envs.py +0 -1219
  51. tpu_inference/mock/vllm_logger.py +0 -212
  52. tpu_inference/mock/vllm_logging_utils.py +0 -15
  53. tpu_inference/models/jax/phi3.py +0 -376
  54. /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
  55. /tpu_inference/layers/{jax → common}/sharding.py +0 -0
  56. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  57. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  58. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/tpu_info.py CHANGED
@@ -3,6 +3,7 @@ import os
 
 import requests
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:
 
 
 def get_tpu_type() -> str:
-    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type
 
 
 def get_node_name() -> str:
-    tpu_name = os.getenv("TPU_NAME", None)
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:
 
 
 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id = os.getenv("TPU_WORKER_ID", None)
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
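The three lookups above now go through tpu_inference.envs instead of raw os.getenv calls. As a rough sketch (assumed shape, not the actual contents of tpu_inference/envs.py), such an envs module typically maps names to lazy getters behind a module-level __getattr__, so each attribute access re-reads the process environment:

# Sketch of the envs-module pattern these call sites now rely on
# (assumed shape; see tpu_inference/envs.py for the real definitions).
import os

environment_variables = {
    "TPU_ACCELERATOR_TYPE": lambda: os.getenv("TPU_ACCELERATOR_TYPE"),
    "TPU_NAME": lambda: os.getenv("TPU_NAME"),
    "TPU_WORKER_ID": lambda: os.getenv("TPU_WORKER_ID"),
}

def __getattr__(name: str):
    # Module-level __getattr__ (PEP 562): envs.TPU_NAME re-reads the
    # environment on every access instead of caching at import time.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")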
tpu_inference/utils.py CHANGED
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -14,8 +13,10 @@ from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from vllm import envs, utils
+from vllm import envs as vllm_envs
+from vllm import utils
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
@@ -57,10 +58,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
 
 
 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if envs.VLLM_TPU_USING_PATHWAYS:
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)
 
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.
@@ -132,8 +133,8 @@ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
     hbm_used = defaultdict(int)
     hbm_limit = get_device_hbm_limit()
     for array in live_arrays:
-        for buffer in array.device_buffers:
-            hbm_used[buffer.device] += buffer.nbytes
+        for buffer in array.addressable_shards:
+            hbm_used[buffer.data.device] += buffer.data.nbytes
     return [(hbm_used[device], hbm_limit) for device in devices]
 
 
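The last hunk swaps the removed per-buffer API (array.device_buffers) for addressable_shards when attributing HBM usage to devices. A minimal standalone sketch of the same accounting, runnable on CPU as a smoke test:

# Sum bytes per device across all live JAX arrays via addressable_shards.
from collections import defaultdict

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024), dtype=jnp.float32)  # keep one array alive

used = defaultdict(int)
for array in jax.live_arrays():
    for shard in array.addressable_shards:
        # shard.data is the single-device array backing this shard
        used[shard.data.device] += shard.data.nbytes

print(dict(used))  # e.g. {CpuDevice(id=0): 4194304}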
tpu_inference/worker/tpu_worker.py CHANGED
@@ -2,6 +2,7 @@
 
 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
@@ -10,6 +11,7 @@ import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +25,13 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
-from tpu_inference.layers.jax.sharding import ShardingConfigManager
+from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
 
@@ -39,15 +44,39 @@ _DTYPE: dict[str, jnp.dtype] = {
 }
 
 
+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # default env vars for
+    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+    # if PP is used in single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:
 
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
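For a concrete reading of PPConfig's __post_init__ above, here is how the single-host defaults resolve for one rank (the IP addresses are illustrative only):

# Rank 1 of a 4-stage single-host pipeline, using the PPConfig class above.
cfg = PPConfig(rank=1, ip="10.0.0.2", prev_worker_ip="10.0.0.1",
               pp_world_size=4)
assert cfg.default_tpu_process_bounds == "1,4,1"            # one process per stage
assert cfg.default_tpu_chips_per_process_bounds == "1,1,1"  # one chip per process
assert cfg.default_tpu_visible_chips == "1"                 # this rank's chip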
@@ -74,6 +103,8 @@ class TPUWorker:
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -86,7 +117,7 @@ class TPUWorker:
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -94,6 +125,14 @@ class TPUWorker:
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 
+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -105,31 +144,87 @@ class TPUWorker:
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
+        # step_counter is used to calculate uuid to transfer intermediate tensors.
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-    def init_device(self):
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
+        # set tpu visible devices for Jax runtime in single host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_config.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # Note: Below is the setting for v6e8 host (8 chips of v6e)
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e
+            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices,
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                all_devices = jax.devices()
-                device_dict = {device.id: device for device in all_devices}
+                all_local_devices = jax.local_devices()
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.devices() with IDs {list(device_dict.keys())}!"
+                            f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.devices = jax.devices()[:sharding_config.total_devices]
+                if self.pp_config.pp_world_size > 1:
+                    # We only support a mixed tp + pp scenario that tp size is
+                    # smaller or equals the total TPUs in one node
+                    # say: we have 4 nodes with 4 TPUs each, we can only do pp:4, tp:4, but not pp:2, tp:8
+                    assert jax.local_device_count(
+                    ) >= sharding_config.total_devices
+                    self.devices = jax.local_devices()[:sharding_config.
+                                                       total_devices]
+                else:
+                    # In a multi-host distributed env, say: Ray, local_device count may smaller
+                    # than the total devices, we just choose the smaller set here.
+                    self.devices = jax.devices()[:sharding_config.
+                                                 total_devices]
 
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
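To make option 1 of the subslicing comment in init_device concrete, this is the environment the first pipeline stage would end up with on a v6e8 host (values taken directly from the comment above):

# Option 1 spelled out for rank 0 of PP=2 on a v6e8 host:
# one 4-chip subslice per pipeline stage (TP up to 4 within a stage).
import os

os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,4,1"
os.environ["TPU_VISIBLE_CHIPS"] = "0,1,2,3"  # rank 1 would use "4,5,6,7"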
@@ -146,8 +241,18 @@ class TPUWorker:
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
         )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.pp_config.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        self.model_runner = TPUModelRunner(
+            self.vllm_config, self.devices, self.rank, self.rank == 0,
+            self.rank == self.pp_config.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
@@ -155,6 +260,12 @@ class TPUWorker:
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +305,39 @@ class TPUWorker:
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # receive intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only works for vllm model, not sure about jax models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # send intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
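Taken together, the execute_model changes turn each worker into one MPMD pipeline stage: a non-first rank blocks on intermediate tensors keyed by a per-step uuid, runs its stage, and either forwards intermediates downstream or returns the final output. A control-flow sketch, with recv_dict and send_dict as hypothetical stand-ins for get_pp_group().recv_tensor_dict and send_tensor_dict:

# One pipeline step for a worker at `rank` in [0, world_size).
def pp_step(rank, world_size, step_counter, run_stage, recv_dict, send_dict):
    if world_size == 1 or rank == 0:
        inputs = None                    # first stage reads the batch itself
    else:
        uuid = (rank - 1, step_counter)  # must match the upstream sender's uuid
        inputs = recv_dict(uuid)         # blocks until the previous stage sends

    output, is_intermediate = run_stage(inputs)

    if is_intermediate:                  # not the last stage: forward and stop
        send_dict((rank, step_counter), output)
        return None
    return output                        # last stage returns sampler output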
{tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.11.1.dev202511130813
+Version: 0.11.1.dev202511220812
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -27,7 +27,7 @@ Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
 Requires-Dist: torchax==0.0.7
 Requires-Dist: qwix==0.1.1
-Requires-Dist: torchvision==0.23.0
+Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
{tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD CHANGED
@@ -1,8 +1,9 @@
 tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/test_base.py,sha256=Ct5WFRMHL7IHEIxk8FrzAvO8m0xFuDpzDBKkAKKAL2Q,7341
+tests/test_envs.py,sha256=Woyfp_d5HS-uTGo4_u9dYlBbgmhfIEoFb-Rx_k7YXD4,6298
 tests/test_quantization.py,sha256=IT5ASyS1uuWcxc22kRtBcA-V4j3Z3hb7pMztm3GOlBs,34445
 tests/test_tpu_info.py,sha256=ZrwlMsp8ffITkS_b8Q1t_QG-a-WVAd4NUcjHhGibcsI,4670
-tests/test_utils.py,sha256=szRg4UB36RcgIvbEd9xMhKYbWi-O4XAUWGJlIU6FJ9E,7983
+tests/test_utils.py,sha256=Mta5ZzYCgRAh1-BjcOvvx9iQ9DnnXLps7oDHxVQp2yE,8236
 tests/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/core/test_core_tpu.py,sha256=r496rk1eOsK_F4nvm9zprl_T-RcO6eCUb7LuVReOZno,21413
 tests/core/test_disagg_executor.py,sha256=QdE2YZs08EyDDCmSjhiXkXqQ9BJTgO6csr_E1xkkfSg,2256
@@ -20,27 +21,27 @@ tests/kernels/ragged_paged_attention_kernel_v3_test.py,sha256=Hrd8iUkS1pS3rxeTyY
 tests/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/lora/conftest.py,sha256=EXjwE1CjmUUlMEXpyE3UwxvgrKUllE73I8BNKfP1FTc,984
 tests/lora/test_bgmv.py,sha256=gQxWsJdNX2nkrE2xyrG0exwf3E2eHm2k2nkEXoANuQc,1359
-tests/lora/test_layers.py,sha256=21ekYlsK36r1GPZOfzs7E-KIsfI1JcuZl1E6vaQbHf4,26273
+tests/lora/test_layers.py,sha256=6B4HhMAItQmt0hPAQgyXgwSYs7b3bIbUf6LaPsqXLzY,25923
 tests/lora/test_lora.py,sha256=wJiF1P1BDnPN8TLX2tlFtdZ_QCkV-S9nPl6_uR6DqFc,4439
-tests/lora/utils.py,sha256=dR_v1H20vPVjFHdBhDajWOz0WJZlKuPLgMFQsME0LtA,3009
-tpu_inference/__init__.py,sha256=7IduGWw-_fwx0VA6EvC_AqHF67fnnShz6YvkqCfvFx8,1317
+tests/lora/utils.py,sha256=rY0tDZEZe58ye4-ykwrTnsiWuLcaEG57N_Rua90bDXI,2726
+tpu_inference/__init__.py,sha256=p4MaepRdN7723FUNE-3pOMxZWjFn4_TVFgjrNyty4JE,2304
 tpu_inference/env_override.py,sha256=pmL7lfs_rGCP92ya3wuWuudsCYeOMZ6tFZY82A4KkQc,365
-tpu_inference/envs.py,sha256=MTT_Pdtd6cAcciYjv1OekEmvspaq3SYL0oR_jDkQ_aE,3948
+tpu_inference/envs.py,sha256=hoPuT0SyLCxqyZ0QJIha6EXSZv2TpACfmENuiT0iJMM,3956
 tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
-tpu_inference/tpu_info.py,sha256=9UohshkndR6dZpGWpWXfTD4qvIVdVgHf0yOoSEkLTrw,2276
-tpu_inference/utils.py,sha256=LWEshJgUdB20H2fDA-QI-Sk4EP7PD_FWvW3Mrqb-k8M,10054
+tpu_inference/tpu_info.py,sha256=3iilHRQSFjwMJwhKcuuawTm7mhwkgHbj4zi6CiAySrs,2265
+tpu_inference/utils.py,sha256=Ddsx2CY2ARe46RZL27URzXCN3P6pMcKWB-APXUB8sHs,10098
 tpu_inference/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/core/core_tpu.py,sha256=JdN4-xaxSWnzY4T181SCnbZ5HEnwQ5IifYA9ybF4pWo,32710
+tpu_inference/core/core_tpu.py,sha256=WDD3koE_j1QhWS2BbMA2aQOZayPZm4tYPvzL4YCX2jY,33294
 tpu_inference/core/disagg_executor.py,sha256=HZpgYMVxRxm0RQxO4l8IDYBWJ6Z3Tac6xavc5otcirc,4657
-tpu_inference/core/disagg_utils.py,sha256=ufWNFWQ5n4YnZpPOtoReHlYo4dlN7AbIqCyqS4an0t4,1572
+tpu_inference/core/disagg_utils.py,sha256=lv8MAVoAjtcmTaenUXVokg2q3d0tzsma86UiQlQ3omY,1492
 tpu_inference/core/sched/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/core/sched/dp_scheduler.py,sha256=mKs8Ms46szdlBfo8hjdqis2ZKAZbcKnHAGfEr0X5R8g,22527
 tpu_inference/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/distributed/jax_parallel_state.py,sha256=5_xCwcL03lFPUoSO_OP7hIVKpUFroW1m-jVO7R6FbUc,2223
-tpu_inference/distributed/tpu_connector.py,sha256=Zah46Sm5iOuh72SzXw69NxMc0MLnqsLEpe2BfDhpnqA,29731
-tpu_inference/distributed/utils.py,sha256=RwFQi8G4TzN1g9RjQu0pb5JxSc_jhoIZVsFJo0uHjxo,1513
+tpu_inference/distributed/tpu_connector.py,sha256=w_gOI6hX7NWefaxN_9XH9TXReGElOyFifdDHpPswotM,29696
+tpu_inference/distributed/utils.py,sha256=1KIREn28Zg10O-MSUkVQMRzS09WoGc_VLGOX4QTFJac,1504
 tpu_inference/executors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/executors/ray_distributed_executor.py,sha256=UgJP-XSgDPKDj_mkVQ16XrRN96juVpnFl6fdWEyFL_Y,15249
+tpu_inference/executors/ray_distributed_executor.py,sha256=emYfSFJ3kluEmi6mlfnvxSUrC_mGVRVcjrUqUH2MR4g,16122
 tpu_inference/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/experimental/llama3_jax_stashed.py,sha256=YK1oSIfto9ALo-HB45XfSrbq9XgVbE4m2C-9zRwmSzI,10913
 tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -66,27 +67,28 @@ tpu_inference/kernels/ragged_paged_attention/v2/kernel.py,sha256=OiQGAHhyggbp1Pe
 tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256=vGp2ZWODTbjyG9z2z0Qf_BX-wYHd5bUybnc_DtOz0nI,10995
 tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
 tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=tlP6121yfXaukx_RQroHlHcZnbKPyyum0lAcvT0B_Pk,56132
-tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=DFVdIIKmyufu_4b-3YhxI56jt0O1cJ3JsVl-2DDZHv4,55350
-tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=leTS75aq99N1Zuv6wB5yLdkfYnEtrBDVI4z_jOKnjL0,142012
+tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=O179Fft5KpuN5LIFx3SghWXJJUqh3Og-xqfO4Z8QXYU,57032
+tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=z0oaH8ZkDmHSoG4yiiO2CN0kuAuFcEpQ3RUoi5msjlo,56904
+tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=k3LwduhZO85cJ-pSgnGN0c2Nn8eNeQq4eA94KUXJzMw,142198
 tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py,sha256=P3_ivi8iUz5QMU_3pgpl4Bkbmn0q0NpDtVJX39haRQA,11208
 tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=1N_ozjKboDYLteFJndWoLXNudj2z53rGXMkELa5Z9tY,1102
 tpu_inference/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/layers/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+tpu_inference/layers/common/attention_interface.py,sha256=CImMS8tuWgvaRY9YbGS3pY7OBnzeJ4Jla7LRFb4Xoa4,13224
 tpu_inference/layers/common/attention_metadata.py,sha256=St8ZatbY1D7xQACKJH459jMgp3oTP3AQ36mi9FZdrPU,850
+tpu_inference/layers/common/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
+tpu_inference/layers/common/quant_methods.py,sha256=mQSxZ44-QQtm22C_8ViejnP1cP2Dv6yc2YaP6oMKJeQ,185
+tpu_inference/layers/common/sharding.py,sha256=wBqdkXZSWfnnH8pkJtyW2DSqmAe_V4Vxi0iMPaXq0Z0,25185
 tpu_inference/layers/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/layers/jax/attention_interface.py,sha256=1jlvSZWaP6DuPVtb1W_KPw4-Qi68BikOBNLLcpygupY,13221
 tpu_inference/layers/jax/base.py,sha256=Vhts6ZMwNCZ8LbnEXeB0rl3nHdS5hDJWX7HEa7Fl7yE,5775
-tpu_inference/layers/jax/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
 tpu_inference/layers/jax/constants.py,sha256=NcYg0zAf3ClfP7YMYdYu_F1GngOzZaIxIAHBZDunKw4,2755
 tpu_inference/layers/jax/layers.py,sha256=yv_lC2tbJuzVL-OaXYooX82Ys8hWZATeH9M78coJ3VI,10633
 tpu_inference/layers/jax/misc.py,sha256=znKv1Nuq_LgYpaIu0qlzUVDgQWnjjG7aqPJGM8kuwcw,566
 tpu_inference/layers/jax/rope.py,sha256=i2E7pRLWgOaFLbeo8_phZwKQWJW7ohAyl69E2V2Mc2U,11349
 tpu_inference/layers/jax/rope_interface.py,sha256=X0SruXizlCHGnssFujC1pL07UC4Vsp7-gdBy_Q7JZhI,8375
-tpu_inference/layers/jax/sharding.py,sha256=wBqdkXZSWfnnH8pkJtyW2DSqmAe_V4Vxi0iMPaXq0Z0,25185
 tpu_inference/layers/jax/transformer_block.py,sha256=ufv-yfVDmRP_Ynrx3UX9xj-x0PkNw_tQ-0N0eYf4i7M,3917
 tpu_inference/layers/jax/attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/layers/jax/attention/attention.py,sha256=bWXMtF8TToiRyZ3SvJGQnD2urZTbuX_omHVXpQdn0fs,10082
+tpu_inference/layers/jax/attention/attention.py,sha256=DJFDkpQc9SDD156wVPFw3r2XaBgb44QNJ8OcdONaF5g,10085
 tpu_inference/layers/jax/attention/deepseek_v3_attention.py,sha256=YlagoBMwINv2KRH1dr4oEcH_cQ9QMPB55nO2FQZsWs0,14010
 tpu_inference/layers/jax/attention/gpt_oss_attention.py,sha256=rkrEv4aNZxtAGcXd1HXHUxhNeDNAd9nWTEZOKWSI8cA,8725
 tpu_inference/layers/jax/attention/llama4_attention.py,sha256=VvUmfBxQEbHf3F2BrcYDUnq5abj7CSDYeRsNx_eVAh0,6162
@@ -95,50 +97,46 @@ tpu_inference/layers/jax/moe/deepseek_v3_moe.py,sha256=Q6CuwwiZtWYm6iUee1wJoDJrw
 tpu_inference/layers/jax/moe/gpt_oss_moe.py,sha256=Rx5b1jg2XMm7Xx9hrjgvyhscaJ_zGbVMHmeEiLh7kIQ,6196
 tpu_inference/layers/jax/moe/moe.py,sha256=cA8R1rjbBwNEoNlsPWjeIBB9nvaRDwlEdwQTVg6lTpY,8762
 tpu_inference/layers/jax/sample/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/layers/jax/sample/rejection_sampler.py,sha256=IRfVWjkbVXp9Sv1YrGMMh-LYx1AwbY-3FTXEO1-Ue9g,20423
-tpu_inference/layers/jax/sample/sampling.py,sha256=dVOcMdmPdAEsupPk96tCaZecIWUiDej0DiVnwaH9ckQ,3308
+tpu_inference/layers/jax/sample/rejection_sampler.py,sha256=nI5s0E73xkqDIu2hTljIXt23B1Q-gRnC1myoQpGDJrQ,20426
+tpu_inference/layers/jax/sample/sampling.py,sha256=C30KgmdOVSaagvHhbfLgVJtVQmJo86CbHPa4h36Vn70,3314
 tpu_inference/layers/jax/sample/sampling_metadata.py,sha256=Gd835LNWfGM0NRQBVBqEv0nPwt5q9F4AdFym0CUS1fw,2561
 tpu_inference/layers/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/layers/vllm/attention.py,sha256=JxEQ8ql_97zbQzukIbfUYq50-2k81VUG1Km_YV_RUtg,7363
+tpu_inference/layers/vllm/attention.py,sha256=wbJpcgqEAuIirv5PIULbiP-ggMKjmTanbB7Dg0BVYv4,7366
 tpu_inference/layers/vllm/fused_moe.py,sha256=XZt2CPUz00qZzDcyfBFz6buhVzmGL1amHalHJALl9zw,18945
 tpu_inference/layers/vllm/linear_common.py,sha256=_YlJtbdaYcck_j-gFLos_k0ycktVWxT8Qo57tR2YqJ8,7749
-tpu_inference/layers/vllm/sharding.py,sha256=WTx1tF_7R99AdyE-lL7HQJ378hAafeI-JVRsugAvwn4,9177
-tpu_inference/layers/vllm/quantization/__init__.py,sha256=Tz44kUZTdNFu5Dmu48aQ-9f7ioWjbUWS0eVYURXZ17E,1535
-tpu_inference/layers/vllm/quantization/awq.py,sha256=ar8x1CPTPvfcf4wbuBC1XVh4pjtSUchoYWnbkZKH3CQ,8412
+tpu_inference/layers/vllm/sharding.py,sha256=as7CF8UKTF3ToymwRY5Pi8uzwJk0P1sHPkWB5xEx3mA,9169
+tpu_inference/layers/vllm/quantization/__init__.py,sha256=SEppGayBzzQ5tsXLSy99aqilkAawQwYxnv2alCg6-ZU,1777
+tpu_inference/layers/vllm/quantization/awq.py,sha256=-8ZmjGvSKJB6_JuwSctNWt8xHWq4VSvK_AK9iahlgCo,8495
 tpu_inference/layers/vllm/quantization/common.py,sha256=wm3pge6XMTMsLK7_SSdgBP0PvQzz-1mrqN2I6xMqzrc,4218
-tpu_inference/layers/vllm/quantization/unquantized.py,sha256=id6d_IZIhDIvmaH3ANtmLiy4U_uY_AYAf4KTvfs3nmc,14900
+tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=KwGoqIiPkd6FplGuYAKi4uX5A8MPlZqq99MVPchXyi4,11561
+tpu_inference/layers/vllm/quantization/unquantized.py,sha256=Q1v1ZbSIDmaoOg97Ehv6rA5CnSf6nTP40xDBMmHHeLw,15054
 tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=uKaauZhaRDcMqd8_NyQoFs9BazMOFix3nIuutbLHHbU,5123
+tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=6idEyy3e849fZ1UeNvc9eSHYX7e6qvohrJa_d_D9MBk,5285
 tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py,sha256=FM901QhyhJRC8CuMeICzCVVERvBHbhruRxYW0EQ570s,8820
 tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py,sha256=6sQvsxiWdi5Vte8V9vrQ2abaqGqWpq-mtzU7lGAo-ac,8759
 tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=4y7lYgybpXszpCAtxGFhR8LDEbEoCCeo3DfUSOXxhaQ,5202
 tpu_inference/lora/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3zWBUsxE,3121
-tpu_inference/lora/torch_punica_tpu.py,sha256=b27DpmIS_N5bhlIcryiENYNmPxp_cu40CGxjPW64d44,12706
-tpu_inference/mock/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/mock/vllm_config_utils.py,sha256=FlQshLjoHdgs3C66tYHYbKFUjbk9DhUwY-7HibZk0fI,878
-tpu_inference/mock/vllm_envs.py,sha256=cCubeOhH2WeYZQFJt6W0y_IiQo0fzIWR1LCCE8i6kI4,50990
-tpu_inference/mock/vllm_logger.py,sha256=vUGnN5nKT--ZvU15YCzODUM_FGiXKhcrrjDGjeN00RQ,7297
-tpu_inference/mock/vllm_logging_utils.py,sha256=TEUmKj3xHiLzHBnFqAujcxH0t2hBQ04sUaho2RyORnk,486
+tpu_inference/lora/torch_punica_tpu.py,sha256=qTnXZGLoOgvukSxeunO_SfpPTlkq9GlMj9H7zVYg9LE,12680
 tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/models/common/model_loader.py,sha256=AwukmGaUq2wv3OnFHUU-nwdAnKLG_eGw7PYY5CNrNNI,18225
+tpu_inference/models/common/model_loader.py,sha256=3rRntyGqS6l7yAfURmRaGkhyIaee2E43a5F0_i0IFmE,18177
 tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/deepseek_v3.py,sha256=SKOHVEC-_2NLxBnzBzbu5tu0d6FTlAEiI1EefGaO2QE,40047
 tpu_inference/models/jax/gpt_oss.py,sha256=Vw4LRB5Kp6hbA2hjZGFS8kiEqOCjf881XH2JNtu2S1I,20924
 tpu_inference/models/jax/jax_intermediate_tensor.py,sha256=Pxu1PCV5LN5X58aYVkPiohcXZIeKVim2oqvrS_cVgw4,2604
-tpu_inference/models/jax/llama3.py,sha256=YUG0S0Y6cy7PLcq0cpmDsGWbOZIhZzzyObRQdmUUxkg,13420
+tpu_inference/models/jax/llama3.py,sha256=ZiFtrpAzXTT9vAPES9UeuJInCWGbvDWs7g0_JLdCCa4,13479
 tpu_inference/models/jax/llama4.py,sha256=wf2Sp2iYViaYD5rSfv3_ryO6gYuYM5XaOyvghaP4OCY,29631
-tpu_inference/models/jax/llama_eagle3.py,sha256=STUkAK6XEA7JM3i_Lx36-t5BhkAGeW_xYiq3zYhHP1A,12297
-tpu_inference/models/jax/phi3.py,sha256=Oz68PE2Z1t8wTed95_w0KMIXfnfV72ZwXugNOdWOV5w,13576
-tpu_inference/models/jax/qwen2.py,sha256=RYb0hMKzPnFOAyhqbztoNlSrFIlRa74fYqSNecA2VOY,13354
-tpu_inference/models/jax/qwen2_5_vl.py,sha256=J4-AjeS_igJdxYCjTwS0HShiEfwQUMwrHxjlWvMw0ok,43939
-tpu_inference/models/jax/qwen3.py,sha256=SOL-Pvp56IrMxqXpPf5EFacBI6AJNlqf4Zrr1pkabGw,10994
+tpu_inference/models/jax/llama_eagle3.py,sha256=xUoNetxDbcFIEVLZ2DiD-GEQhHcdau2v1R12WdMyGec,12550
+tpu_inference/models/jax/llama_guard_4.py,sha256=LrnU2zBWM0s4q_5dwmR--OO0V7ttltsYhrHYlBgQVIw,15275
+tpu_inference/models/jax/qwen2.py,sha256=SuAp7tErk8OoIRko0Vt6QSOZP_9B9r5GTfqmVfImUIo,13410
+tpu_inference/models/jax/qwen2_5_vl.py,sha256=tf177ypgA1ZVIn34Ff_LTwr10NwzlZ3-DPqSoRLAQtQ,43995
+tpu_inference/models/jax/qwen3.py,sha256=CIZQKjZDke_LPGsLNhRCJdDTzWueUneBPAQ1blS24IM,11050
 tpu_inference/models/jax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/utils/file_utils.py,sha256=NOuSC3YFnZpf3CZgYdghbbiNYJt42zgjlEYbOZIVct4,2840
 tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=rrIrQWidkUnGilBHKNpdYh7_2BkvnAaqanXjC81GNcg,6156
-tpu_inference/models/jax/utils/weight_utils.py,sha256=65-H8BTbyilIBMBfvWjkkW3mf4soYASbhrJFqbFKzL4,20129
+tpu_inference/models/jax/utils/weight_utils.py,sha256=d5u8pPR-qPbEjX-8BMY0Zea9O-a34CpfuDlVnbwWfAw,20659
 tpu_inference/models/jax/utils/quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/models/jax/utils/quantization/mxfp4_utils.py,sha256=boGnqJCRIOf5nedAxQ8_IUTV6Rfll10DXnRC40BeeE8,3682
 tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=xgKoKB7AM3TYPxzVgEGLTK9ebQH2Kx8mNuO0heovkmk,26778
@@ -147,30 +145,30 @@ tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml,sha256=b7Sy
 tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml,sha256=0Qwij71zj9k6rmrUNd8Q5df9YYfkoJ1ZkgMAHxQy81k,128
 tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml,sha256=lGec0UwwxmNPNgKPSsTsCMSXNJjhw507KMtM2NsSCMw,152
 tpu_inference/models/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=ERxj-cm-pmYpT9eiL-E3OxeaQDEDrH_Vs0iUS9nCU9s,11424
+tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=hEjg5hKotp-fEt3SXWkWpdnQ32TU1XGpTrfhyLTNyt0,12054
 tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=yxlJHPmRQIAwlb1MmHK3xfXokgIkJ-evNU4PgyoJUdg,1187
 tpu_inference/platforms/__init__.py,sha256=lQCrKddS_GcGpCbeogvz9zOZD1mQw5bBsiw8On46qFQ,74
-tpu_inference/platforms/tpu_platform.py,sha256=bdo_zlRqrhccpaz6zOdH18cU8kq6tGKgR1xJJehsVrc,10131
+tpu_inference/platforms/tpu_platform.py,sha256=RSCe3Ne1FsWXVrX6_6V_Z6B0TDTRS38eM0KTkXbQ_w8,10579
 tpu_inference/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/runner/block_table.py,sha256=K3Ic8EgPM08d_C5nEN60mxoRydlaQWySAemf_8Q_qVw,4175
-tpu_inference/runner/compilation_manager.py,sha256=pJFFLkFVmhXukBIxGRUo-hrOqx8jl8JIUuS36fZ2yvg,36177
+tpu_inference/runner/compilation_manager.py,sha256=oVML1KhhQ7YFaSWBaJA0qWQoNX2qRZOrwbbh4XYPc-8,37287
 tpu_inference/runner/input_batch.py,sha256=bx221NX2IOWzrtopss-B-2ZKW4y-U6nQpG09PjpUziw,18273
-tpu_inference/runner/kv_cache.py,sha256=i54EbGQB-9bbOgk6KibTpJpTE2pfFuTfis7J1P_UB0M,4574
-tpu_inference/runner/kv_cache_manager.py,sha256=CJxXtdWuewJqcTBMoR70_Uvwxjtc3cK2jxe1KpI9kQc,22152
+tpu_inference/runner/kv_cache.py,sha256=F4dzW2d53xuxkFUn0oKzwE6VklGUeVm-QM19NVfIQDU,4577
+tpu_inference/runner/kv_cache_manager.py,sha256=XEfis_9nQAz8uxM5y_P5biqSUijX4IeMhIusTf2V7vg,22444
 tpu_inference/runner/lora_utils.py,sha256=B4xMCgXGJ4VNdePvn89HH3tIZ-gYsQ7Vq_YCiYIATEY,3843
 tpu_inference/runner/multimodal_manager.py,sha256=azEPdHOwz8CN11MQmorGdtrCLbFaTCxdWyuEsZTzjYM,9778
 tpu_inference/runner/persistent_batch_manager.py,sha256=KERSfKy6XjMejnbtPGI3hzoYAHJLeCxmpZVYPqBCago,11156
 tpu_inference/runner/speculative_decoding_manager.py,sha256=I3FDWKh2dn6nV8LgTGfCTwMKYnxQsTPpBIrmaJngXHs,10215
 tpu_inference/runner/structured_decoding_manager.py,sha256=Y0ERPhj4olFh6Y2TxP0R1_4UIJwy7nemYA-h63YIR2U,3622
-tpu_inference/runner/tpu_runner.py,sha256=5vPFey3KFnh5lczyj4cIT3mVhR8RuX8kbcuHVOg8DAg,72318
+tpu_inference/runner/tpu_runner.py,sha256=aHXHSlaNuc9q7pcPklqTFRkmkEQDULEEH_hsR_NcTMQ,77532
 tpu_inference/runner/utils.py,sha256=ZnWUoNo-7INeB0mdXti1jwUOdbmxyExznOs-crRTQLk,17126
 tpu_inference/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tpu_inference/spec_decode/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/spec_decode/jax/eagle3.py,sha256=A1dt-dmBttpy-5DGcL4noEDCB0OGP8Xo6MXqgJvWIo8,16593
+tpu_inference/spec_decode/jax/eagle3.py,sha256=1WVHTdv6jfCKwbiz0RwQLPyq8L720gD_bs0p_Gz0QiI,16644
 tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-tpu_inference/worker/tpu_worker.py,sha256=KY7fH--NP7jiTduP5m0gDnmB2LbhIel0Ts37XmjYpPM,14207
-tpu_inference-0.11.1.dev202511130813.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-tpu_inference-0.11.1.dev202511130813.dist-info/METADATA,sha256=LARdH4AAJfZrrU2Pj4EIN8Zl0QLjzEpzkRCqBbeUdT8,5465
-tpu_inference-0.11.1.dev202511130813.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tpu_inference-0.11.1.dev202511130813.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
-tpu_inference-0.11.1.dev202511130813.dist-info/RECORD,,
+tpu_inference/worker/tpu_worker.py,sha256=aojB9-PY_ZzTaZgv1i5PUB9CSXNVuK4JZzftCv9ku4A,20642
+tpu_inference-0.11.1.dev202511220812.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+tpu_inference-0.11.1.dev202511220812.dist-info/METADATA,sha256=JzmyOlYYkImIe_WSawI0LDwL28xS-0SCRCcFXeYSV0g,5465
+tpu_inference-0.11.1.dev202511220812.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tpu_inference-0.11.1.dev202511220812.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
+tpu_inference-0.11.1.dev202511220812.dist-info/RECORD,,
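Each RECORD row is path,sha256=<digest>,<size>, where the digest is the urlsafe-base64 SHA-256 of the file with trailing '=' padding stripped (the standard wheel RECORD convention). The snippet below recomputes an entry in that form:

# Recompute a wheel RECORD entry (path,sha256=<urlsafe b64, no '='>,<size>).
import base64
import hashlib

def record_entry(path: str) -> str:
    data = open(path, "rb").read()
    digest = base64.urlsafe_b64encode(
        hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"

print(record_entry("tpu_inference/logger.py"))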
tpu_inference/mock/vllm_config_utils.py DELETED
@@ -1,28 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Any, List, Mapping
-
-
-@dataclass
-class ModelConfig():
-    max_model_len: int = 2048
-    max_prefill_len: int = 1024
-    prefill_batch_size: int = 1
-    decode_batch_size: int = 1
-    block_size: int = 16
-    num_layers: int = 32
-    num_kv_heads: int = 32
-    head_dim: int = 128
-    vocab_size: int = 32000
-    model: str = "llama3"
-    hf_config: str = ""
-    architectures: List[str] = field(default_factory=list)
-    override_generation_config: dict[str, Any] = field(default_factory=dict)
-    hf_overrides: dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass
-class VllmConfig():
-    additional_config: Mapping[str, Any] = field(default_factory=dict)
-    # Set default max_model_len to turn off warnings.
-    model_config: ModelConfig = field(
-        default_factory=lambda: ModelConfig(max_model_len=1024))