tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (25)
  1. tests/test_envs.py +32 -11
  2. tests/test_utils.py +1 -2
  3. tpu_inference/distributed/tpu_connector.py +1 -1
  4. tpu_inference/envs.py +60 -7
  5. tpu_inference/executors/ray_distributed_executor.py +5 -1
  6. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +72 -19
  7. tpu_inference/layers/common/sharding.py +3 -4
  8. tpu_inference/layers/vllm/quantization/mxfp4.py +2 -1
  9. tpu_inference/models/common/model_loader.py +3 -1
  10. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  11. tpu_inference/models/vllm/vllm_model_wrapper.py +1 -2
  12. tpu_inference/platforms/tpu_platform.py +13 -20
  13. tpu_inference/runner/compilation_manager.py +87 -27
  14. tpu_inference/runner/kv_cache_manager.py +8 -15
  15. tpu_inference/runner/persistent_batch_manager.py +40 -2
  16. tpu_inference/runner/tpu_runner.py +68 -45
  17. tpu_inference/runner/utils.py +2 -2
  18. tpu_inference/spec_decode/jax/eagle3.py +52 -19
  19. tpu_inference/utils.py +31 -9
  20. tpu_inference/worker/tpu_worker.py +2 -2
  21. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +1 -1
  22. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +25 -25
  23. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  24. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  25. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tpu_inference/spec_decode/jax/eagle3.py CHANGED
@@ -6,6 +6,9 @@ from typing import Any, Optional
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import numpy as np
9
+ from flax import nnx
10
+ from jax import lax
11
+ from jax.sharding import NamedSharding, PartitionSpec
9
12
  from vllm.config import VllmConfig
10
13
 
11
14
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
@@ -127,6 +130,17 @@ class Eagle3Proposer:
127
130
  max_num_blocks_per_req)
128
131
  new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
129
132
 
133
+ positions = lax.with_sharding_constraint(
134
+ positions, NamedSharding(self.mesh, PartitionSpec(None, )))
135
+ clamped_positions = lax.with_sharding_constraint(
136
+ clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
137
+ new_seq_lens = lax.with_sharding_constraint(
138
+ new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
139
+ query_start_loc = lax.with_sharding_constraint(
140
+ query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
141
+ new_block_tables = lax.with_sharding_constraint(
142
+ new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
143
+
130
144
  return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
131
145
 
132
146
  @functools.partial(jax.jit, static_argnums=(0, ))
@@ -138,6 +152,7 @@ class Eagle3Proposer:
138
152
  @functools.partial(jax.jit, static_argnums=(0, ))
139
153
  def _prepare_hidden_states_and_input_ids(
140
154
  self,
155
+ state: nnx.State,
141
156
  aux_hidden_states: tuple[jax.Array, ...],
142
157
  query_start_loc: jax.Array,
143
158
  target_token_ids: jax.Array,
@@ -146,7 +161,7 @@ class Eagle3Proposer:
146
161
  ) -> tuple[jax.Array, jax.Array, jax.Array]:
147
162
  target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
148
163
  target_hidden_states = self.combine_hidden_states_fn(
149
- self.state, target_hidden_states)
164
+ state, target_hidden_states)
150
165
 
151
166
  input_ids, last_token_indices = self._prepare_input_ids(
152
167
  query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -193,8 +208,8 @@ class Eagle3Proposer:
193
208
  block_tables=device_array(
194
209
  self.mesh, block_tables))
195
210
  target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
196
- aux_hidden_states, attn_metadata.query_start_loc, input_ids,
197
- next_token_ids, num_reqs)
211
+ self.state, aux_hidden_states, attn_metadata.query_start_loc,
212
+ input_ids, next_token_ids, num_reqs)
198
213
  return target_hidden_states, input_ids, last_token_indices, attn_metadata
199
214
 
200
215
  # Host copies from the metadata prepared by the runner.
@@ -258,12 +273,13 @@ class Eagle3Proposer:
258
273
 
259
274
  attn_metadata = replace(attn_metadata, block_tables=block_tables)
260
275
  return self._filter_token_and_prepare_initial_inputs(
261
- token_indices, query_start_loc, seq_lens, input_ids,
276
+ self.state, token_indices, query_start_loc, seq_lens, input_ids,
262
277
  aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
263
278
 
264
279
  @functools.partial(jax.jit, static_argnums=(0, ))
265
280
  def _filter_token_and_prepare_initial_inputs(
266
281
  self,
282
+ state: nnx.State,
267
283
  token_indices: jax.Array,
268
284
  query_start_loc: jax.Array,
269
285
  seq_lens: jax.Array,
@@ -291,35 +307,51 @@ class Eagle3Proposer:
291
307
  )
292
308
 
293
309
  target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
294
- [h[token_indices] for h in aux_hidden_states], query_start_loc,
295
- target_token_ids, next_token_ids, num_reqs)
310
+ state, [h[token_indices] for h in aux_hidden_states],
311
+ query_start_loc, target_token_ids, next_token_ids, num_reqs)
296
312
 
297
313
  return target_hidden_states, input_ids, last_token_indices, attn_metadata
298
314
 
299
315
  @functools.partial(jax.jit, static_argnums=(0, ))
300
316
  def _select_draft_token_ids(
301
317
  self,
318
+ state: nnx.State,
302
319
  hidden_states: jax.Array,
303
320
  last_token_indices: jax.Array,
304
321
  ) -> jax.Array:
305
322
  sample_hidden_states = hidden_states[last_token_indices]
306
- return self._get_draft_token_ids(sample_hidden_states)
323
+ sample_hidden_states = lax.with_sharding_constraint(
324
+ sample_hidden_states,
325
+ NamedSharding(self.mesh, PartitionSpec(None, None)))
326
+ return self._get_draft_token_ids(state, sample_hidden_states)
307
327
 
308
328
  @functools.partial(jax.jit, static_argnums=(0, ))
309
- def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
329
+ def _get_draft_token_ids(self, state: nnx.State,
330
+ hidden_states: jax.Array) -> jax.Array:
310
331
  lora_metadata = None
311
- logits = self.compute_logits_fn(self.state, hidden_states,
312
- lora_metadata)
313
- return jnp.argmax(logits, axis=-1)
332
+ logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
333
+ draft_token_ids = jnp.argmax(logits, axis=-1)
334
+ return lax.with_sharding_constraint(
335
+ draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
314
336
 
315
337
  @functools.partial(jax.jit, static_argnums=(0, ))
316
338
  def _select_inputs_for_loop_speculation(
317
- self, positions: jax.Array, residual: jax.Array,
339
+ self, state: nnx.State, positions: jax.Array, residual: jax.Array,
318
340
  hidden_states: jax.Array,
319
341
  last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
320
- return positions[last_token_indices], residual[
321
- last_token_indices], self._select_draft_token_ids(
322
- hidden_states, last_token_indices)
342
+ positions = positions[last_token_indices]
343
+ residual = residual[last_token_indices]
344
+ draft_token_ids = self._select_draft_token_ids(state, hidden_states,
345
+ last_token_indices)
346
+
347
+ positions = lax.with_sharding_constraint(
348
+ positions, NamedSharding(self.mesh, PartitionSpec(None, )))
349
+ residual = lax.with_sharding_constraint(
350
+ residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
351
+ draft_token_ids = lax.with_sharding_constraint(
352
+ draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
353
+
354
+ return positions, residual, draft_token_ids
323
355
 
324
356
  def propose(
325
357
  self,
@@ -346,11 +378,11 @@ class Eagle3Proposer:
346
378
 
347
379
  if self.num_speculative_tokens == 1:
348
380
  return kv_caches, self._select_draft_token_ids(
349
- hidden_states, last_token_indices)
381
+ self.state, hidden_states, last_token_indices)
350
382
 
351
383
  positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
352
- attn_metadata.input_positions, residual[0], hidden_states,
353
- last_token_indices)
384
+ self.state, attn_metadata.input_positions, residual[0],
385
+ hidden_states, last_token_indices)
354
386
 
355
387
  draft_token_ids_list = [draft_token_ids]
356
388
 
@@ -375,7 +407,8 @@ class Eagle3Proposer:
375
407
  attn_metadata,
376
408
  )
377
409
  hidden_states = residual[0]
378
- draft_token_ids = self._get_draft_token_ids(new_hidden_states)
410
+ draft_token_ids = self._get_draft_token_ids(
411
+ self.state, new_hidden_states)
379
412
  draft_token_ids_list.append(draft_token_ids)
380
413
 
381
414
  # [batch_size, num_speculative_tokens]
tpu_inference/utils.py CHANGED
@@ -8,11 +8,14 @@ from typing import Any, Callable, List, Tuple
8
8
  import jax
9
9
  import jax.numpy as jnp
10
10
  import numpy as np
11
+ import torch
11
12
  from jax._src import dtypes
12
13
  from jax._src import mesh as mesh_lib
13
14
  from jax._src import xla_bridge as xb
14
15
  from jax._src.lib import xla_client as xc
16
+ from jax._src.numpy.scalar_types import _ScalarMeta
15
17
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
18
+ from torchax.ops.mappings import j2t_dtype, t2j_dtype
16
19
  from vllm import envs as vllm_envs
17
20
  from vllm import utils
18
21
 
@@ -23,17 +26,36 @@ GBYTES = 1024 * 1024 * 1024
23
26
  TPU_HEAD_SIZE_ALIGNMENT = 128
24
27
  TPU_SECOND_LAST_MINOR = 8
25
28
 
26
- # This is used to translate from a string name for a dtype
27
- # to formal jax.numpy DType. One use case for this is
28
- # converting the `--kv_cache_dtype` flag to a dtype.
29
- TPU_STR_DTYPE_TO_JAX_DTYPE = {
30
- "bfloat16": jnp.bfloat16,
29
+ # Map vllm dtype string that doesn't exactly match jax dtype string name.
30
+ _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
31
31
  "fp8": jnp.float8_e4m3fn,
32
- "fp8_e4m3": jnp.float8_e4m3,
32
+ "fp8_e4m3": jnp.float8_e4m3fn,
33
33
  "fp8_e5m2": jnp.float8_e5m2,
34
- "int8": jnp.int8,
35
34
  }
36
35
 
36
+
37
+ def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
38
+ if isinstance(dtype, str):
39
+ if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
40
+ return dict_dtype
41
+ return jnp.dtype(dtype)
42
+ elif isinstance(dtype, torch.dtype):
43
+ return t2j_dtype(dtype)
44
+ elif isinstance(dtype, jnp.dtype):
45
+ return dtype
46
+ elif isinstance(dtype, _ScalarMeta):
47
+ return dtype.dtype
48
+ else:
49
+ raise ValueError(f"Argument is unsupported data type {type(dtype)}")
50
+
51
+
52
+ def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
53
+ # Use jax dtype as an intermediate dtype which we'll be used to convert it
54
+ # into torch dtype.
55
+ dtype = to_jax_dtype(dtype)
56
+ return j2t_dtype(dtype)
57
+
58
+
37
59
  _megacore = False
38
60
  logger = init_logger(__name__)
39
61
 
@@ -295,8 +317,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
295
317
  Returns:
296
318
  jnp.dtype: The JAX dtype.
297
319
  """
298
- str_dtype = str_dtype.lower().strip()
299
- return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
320
+ # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
321
+ return to_jax_dtype(str_dtype)
300
322
 
301
323
 
302
324
  def time_function(func):
tpu_inference/worker/tpu_worker.py CHANGED
@@ -108,7 +108,7 @@ class TPUWorker:
108
108
 
109
109
  if self.model_config.trust_remote_code:
110
110
  # note: lazy import to avoid importing torch before initializing
111
- from vllm.utils import init_cached_hf_modules
111
+ from vllm.utils.import_utils import init_cached_hf_modules
112
112
 
113
113
  init_cached_hf_modules()
114
114
 
@@ -357,7 +357,7 @@ class TPUWorker:
357
357
  if is_start:
358
358
  options = jax.profiler.ProfileOptions()
359
359
  # default: https://docs.jax.dev/en/latest/profiling.html#general-options
360
- options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
360
+ options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
361
361
  options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
362
362
  jax.profiler.start_trace(self.profile_dir,
363
363
  profiler_options=options)
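{tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@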
1
1
  Metadata-Version: 2.4
2
2
  Name: tpu_inference
3
- Version: 0.11.1.dev202511270815
3
+ Version: 0.11.1.dev202512030818
4
4
  Author: tpu_inference Contributors
5
5
  Classifier: Development Status :: 3 - Alpha
6
6
  Classifier: Intended Audience :: Developers
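{tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD CHANGED
@@ -1,9 +1,9 @@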
1
1
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  tests/test_base.py,sha256=Ct5WFRMHL7IHEIxk8FrzAvO8m0xFuDpzDBKkAKKAL2Q,7341
3
- tests/test_envs.py,sha256=Woyfp_d5HS-uTGo4_u9dYlBbgmhfIEoFb-Rx_k7YXD4,6298
3
+ tests/test_envs.py,sha256=h502VxL2gvhECm8u5uDh5JTGvhFf_DfQO88SpqOFMzE,7135
4
4
  tests/test_quantization.py,sha256=IT5ASyS1uuWcxc22kRtBcA-V4j3Z3hb7pMztm3GOlBs,34445
5
5
  tests/test_tpu_info.py,sha256=ZrwlMsp8ffITkS_b8Q1t_QG-a-WVAd4NUcjHhGibcsI,4670
6
- tests/test_utils.py,sha256=Mta5ZzYCgRAh1-BjcOvvx9iQ9DnnXLps7oDHxVQp2yE,8236
6
+ tests/test_utils.py,sha256=GIXLdd-x4gnqSLrySXGk22phqPc8MegFd7ph1Jj8OcU,8182
7
7
  tests/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  tests/core/test_core_tpu.py,sha256=r496rk1eOsK_F4nvm9zprl_T-RcO6eCUb7LuVReOZno,21413
9
9
  tests/core/test_disagg_executor.py,sha256=QdE2YZs08EyDDCmSjhiXkXqQ9BJTgO6csr_E1xkkfSg,2256
@@ -26,10 +26,10 @@ tests/lora/test_lora.py,sha256=wJiF1P1BDnPN8TLX2tlFtdZ_QCkV-S9nPl6_uR6DqFc,4439
26
26
  tests/lora/utils.py,sha256=rY0tDZEZe58ye4-ykwrTnsiWuLcaEG57N_Rua90bDXI,2726
27
27
  tpu_inference/__init__.py,sha256=p4MaepRdN7723FUNE-3pOMxZWjFn4_TVFgjrNyty4JE,2304
28
28
  tpu_inference/env_override.py,sha256=pmL7lfs_rGCP92ya3wuWuudsCYeOMZ6tFZY82A4KkQc,365
29
- tpu_inference/envs.py,sha256=hoPuT0SyLCxqyZ0QJIha6EXSZv2TpACfmENuiT0iJMM,3956
29
+ tpu_inference/envs.py,sha256=ugze6VdQ_hG1IxUCbcgXZq7a22fZ-Lora3V_fkFOefw,5714
30
30
  tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
31
31
  tpu_inference/tpu_info.py,sha256=3iilHRQSFjwMJwhKcuuawTm7mhwkgHbj4zi6CiAySrs,2265
32
- tpu_inference/utils.py,sha256=Ddsx2CY2ARe46RZL27URzXCN3P6pMcKWB-APXUB8sHs,10098
32
+ tpu_inference/utils.py,sha256=mHbjI8fxInPxagLsSUg-R3DzSz-X7WYNdoorPYoE3hg,10855
33
33
  tpu_inference/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
34
  tpu_inference/core/core_tpu.py,sha256=WDD3koE_j1QhWS2BbMA2aQOZayPZm4tYPvzL4YCX2jY,33294
35
35
  tpu_inference/core/disagg_executor.py,sha256=HZpgYMVxRxm0RQxO4l8IDYBWJ6Z3Tac6xavc5otcirc,4657
@@ -38,10 +38,10 @@ tpu_inference/core/sched/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
38
38
  tpu_inference/core/sched/dp_scheduler.py,sha256=mKs8Ms46szdlBfo8hjdqis2ZKAZbcKnHAGfEr0X5R8g,22527
39
39
  tpu_inference/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
40
  tpu_inference/distributed/jax_parallel_state.py,sha256=5_xCwcL03lFPUoSO_OP7hIVKpUFroW1m-jVO7R6FbUc,2223
41
- tpu_inference/distributed/tpu_connector.py,sha256=w_gOI6hX7NWefaxN_9XH9TXReGElOyFifdDHpPswotM,29696
41
+ tpu_inference/distributed/tpu_connector.py,sha256=kLaTwy6BrAThJeFkd1soJ47bBo5iGp4GjUJs7xFx4Tg,29696
42
42
  tpu_inference/distributed/utils.py,sha256=1KIREn28Zg10O-MSUkVQMRzS09WoGc_VLGOX4QTFJac,1504
43
43
  tpu_inference/executors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
- tpu_inference/executors/ray_distributed_executor.py,sha256=emYfSFJ3kluEmi6mlfnvxSUrC_mGVRVcjrUqUH2MR4g,16122
44
+ tpu_inference/executors/ray_distributed_executor.py,sha256=9CnzWb8aurH1B0tJfMHB73F-RQBGqSf5DnymetBvZ5o,16225
45
45
  tpu_inference/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  tpu_inference/experimental/llama3_jax_stashed.py,sha256=YK1oSIfto9ALo-HB45XfSrbq9XgVbE4m2C-9zRwmSzI,10913
47
47
  tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -68,7 +68,7 @@ tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256
68
68
  tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
69
69
  tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
70
70
  tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=O179Fft5KpuN5LIFx3SghWXJJUqh3Og-xqfO4Z8QXYU,57032
71
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=X_d-SMGNc3zv396uQGL-73oLzp5ZQP8gaubMDebM_AY,57426
71
+ tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=ArwrqIQiKIop_jaDKAMw656YHQ3IFZ0sRu9Cgycrtko,59858
72
72
  tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=k3LwduhZO85cJ-pSgnGN0c2Nn8eNeQq4eA94KUXJzMw,142198
73
73
  tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py,sha256=P3_ivi8iUz5QMU_3pgpl4Bkbmn0q0NpDtVJX39haRQA,11208
74
74
  tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=1N_ozjKboDYLteFJndWoLXNudj2z53rGXMkELa5Z9tY,1102
@@ -78,7 +78,7 @@ tpu_inference/layers/common/attention_interface.py,sha256=SQZ-1I32Jqg7GGI-z4BVib
78
78
  tpu_inference/layers/common/attention_metadata.py,sha256=St8ZatbY1D7xQACKJH459jMgp3oTP3AQ36mi9FZdrPU,850
79
79
  tpu_inference/layers/common/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
80
80
  tpu_inference/layers/common/quant_methods.py,sha256=mQSxZ44-QQtm22C_8ViejnP1cP2Dv6yc2YaP6oMKJeQ,185
81
- tpu_inference/layers/common/sharding.py,sha256=KUPd5HxfmQZ01wc3lGEusI6QYHnZxFp7-Ur-0b8hOH8,25256
81
+ tpu_inference/layers/common/sharding.py,sha256=sjbwkDr2fP26Ob8f5cSDeDifr3eWFZMDHU4MKr7pIgQ,25217
82
82
  tpu_inference/layers/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
83
  tpu_inference/layers/jax/base.py,sha256=Vhts6ZMwNCZ8LbnEXeB0rl3nHdS5hDJWX7HEa7Fl7yE,5775
84
84
  tpu_inference/layers/jax/constants.py,sha256=NcYg0zAf3ClfP7YMYdYu_F1GngOzZaIxIAHBZDunKw4,2755
@@ -108,7 +108,7 @@ tpu_inference/layers/vllm/sharding.py,sha256=as7CF8UKTF3ToymwRY5Pi8uzwJk0P1sHPkW
108
108
  tpu_inference/layers/vllm/quantization/__init__.py,sha256=SEppGayBzzQ5tsXLSy99aqilkAawQwYxnv2alCg6-ZU,1777
109
109
  tpu_inference/layers/vllm/quantization/awq.py,sha256=-8ZmjGvSKJB6_JuwSctNWt8xHWq4VSvK_AK9iahlgCo,8495
110
110
  tpu_inference/layers/vllm/quantization/common.py,sha256=8XD64pPa077c9HThFhLFVHlDL9YBafnYwp6rp6gR44E,4432
111
- tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=3T3M0qLoW7GKdqbv_toMoQP39lV1qCoQ8Uc8l8aq1hg,14495
111
+ tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=o661uiSvLvWGr8hQMl7TqYXJyALPREtNWlKHAM9AUrw,14541
112
112
  tpu_inference/layers/vllm/quantization/unquantized.py,sha256=nSRBzVurTiQQkF9FuSTshfRwfxfzs54E2_4eK7Eyhj0,15345
113
113
  tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
114
114
  tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=6idEyy3e849fZ1UeNvc9eSHYX7e6qvohrJa_d_D9MBk,5285
@@ -121,7 +121,7 @@ tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3
121
121
  tpu_inference/lora/torch_punica_tpu.py,sha256=qTnXZGLoOgvukSxeunO_SfpPTlkq9GlMj9H7zVYg9LE,12680
122
122
  tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
123
123
  tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
124
- tpu_inference/models/common/model_loader.py,sha256=_eFPCM0_ssjoVdj38rMLR-qnJ7iW_Ox_hc8JiWycxNs,19923
124
+ tpu_inference/models/common/model_loader.py,sha256=b3aigca81gMVJt42oF2aoRohQHjBBe3oK3IPblZAaUM,19996
125
125
  tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
126
126
  tpu_inference/models/jax/deepseek_v3.py,sha256=SKOHVEC-_2NLxBnzBzbu5tu0d6FTlAEiI1EefGaO2QE,40047
127
127
  tpu_inference/models/jax/gpt_oss.py,sha256=Vw4LRB5Kp6hbA2hjZGFS8kiEqOCjf881XH2JNtu2S1I,20924
@@ -139,36 +139,36 @@ tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=rrIrQWidkUnGilBHKNpdY
139
139
  tpu_inference/models/jax/utils/weight_utils.py,sha256=qFU53jPHPvIcs_EOdIH80oNojpUp7GdSY2E6NZNsjvM,21376
140
140
  tpu_inference/models/jax/utils/quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
141
141
  tpu_inference/models/jax/utils/quantization/mxfp4_utils.py,sha256=boGnqJCRIOf5nedAxQ8_IUTV6Rfll10DXnRC40BeeE8,3682
142
- tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=xgKoKB7AM3TYPxzVgEGLTK9ebQH2Kx8mNuO0heovkmk,26778
142
+ tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=rzAFU3OtQvg8w8ow0V15rMljAsa4SBrwOye6OI8Bty4,26530
143
143
  tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml,sha256=d_YHPtaRJ_7PBrPijSzJGnVeoJO62tKIGqrgFqpYT1k,137
144
144
  tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml,sha256=b7SyL75HuSTj3fN9_ZLCK_CDiccL5DGq_DddGmxj_qk,170
145
145
  tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml,sha256=0Qwij71zj9k6rmrUNd8Q5df9YYfkoJ1ZkgMAHxQy81k,128
146
146
  tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml,sha256=lGec0UwwxmNPNgKPSsTsCMSXNJjhw507KMtM2NsSCMw,152
147
147
  tpu_inference/models/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
148
- tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=Cfsd0PjuR-hCoiCwPVdzjkE6AmHLYY1JQyBERyFkl-E,12344
148
+ tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=3EcaD_1vZuyAZBfDtm5u_qfCahQU28qR4rAUraNAFqs,12305
149
149
  tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=yxlJHPmRQIAwlb1MmHK3xfXokgIkJ-evNU4PgyoJUdg,1187
150
150
  tpu_inference/platforms/__init__.py,sha256=lQCrKddS_GcGpCbeogvz9zOZD1mQw5bBsiw8On46qFQ,74
151
- tpu_inference/platforms/tpu_platform.py,sha256=W_19FvlFxPs0V0vcr3NI6oVBG-eA3eBV2-H0Cr3Kyco,10879
151
+ tpu_inference/platforms/tpu_platform.py,sha256=F4jjPEFHFUTxdfWZYTBuUVJt6SYTFeWEKmrl74sX-Zk,10663
152
152
  tpu_inference/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
153
153
  tpu_inference/runner/block_table.py,sha256=K3Ic8EgPM08d_C5nEN60mxoRydlaQWySAemf_8Q_qVw,4175
154
- tpu_inference/runner/compilation_manager.py,sha256=DR5TkHGin2QbRIWZlkkD5sUdxonTgr35pMYyrSwGk_U,37585
154
+ tpu_inference/runner/compilation_manager.py,sha256=dU0Yk8f0LtRTBe2q0iB3xcMSRco_WPsj2wS6zZJ8WhY,40375
155
155
  tpu_inference/runner/input_batch.py,sha256=bx221NX2IOWzrtopss-B-2ZKW4y-U6nQpG09PjpUziw,18273
156
156
  tpu_inference/runner/kv_cache.py,sha256=F4dzW2d53xuxkFUn0oKzwE6VklGUeVm-QM19NVfIQDU,4577
157
- tpu_inference/runner/kv_cache_manager.py,sha256=2iwEc1vXt-p8kkKByvlqy4IKi5bOqFpOlrq0QmHHnQA,22450
157
+ tpu_inference/runner/kv_cache_manager.py,sha256=N0a896CE7Zrs_d4ZSSzRdqgjV1It57RBDSIpOzkRqro,22013
158
158
  tpu_inference/runner/lora_utils.py,sha256=B4xMCgXGJ4VNdePvn89HH3tIZ-gYsQ7Vq_YCiYIATEY,3843
159
159
  tpu_inference/runner/multimodal_manager.py,sha256=azEPdHOwz8CN11MQmorGdtrCLbFaTCxdWyuEsZTzjYM,9778
160
- tpu_inference/runner/persistent_batch_manager.py,sha256=KERSfKy6XjMejnbtPGI3hzoYAHJLeCxmpZVYPqBCago,11156
160
+ tpu_inference/runner/persistent_batch_manager.py,sha256=Otu67vOTf1_HKAMZgPDDHlRvvZ3YVJdz-QderH4qOII,13263
161
161
  tpu_inference/runner/speculative_decoding_manager.py,sha256=I3FDWKh2dn6nV8LgTGfCTwMKYnxQsTPpBIrmaJngXHs,10215
162
162
  tpu_inference/runner/structured_decoding_manager.py,sha256=gZQKQUFxh6xYYH9eGTdbguqk8hc2WwTrIdMMuCcbymE,3573
163
- tpu_inference/runner/tpu_runner.py,sha256=A5Ed4NL6CPNv7o7u6zqmdPbmmPyiIxFcwWlJ0E5_fpU,77991
164
- tpu_inference/runner/utils.py,sha256=ZnWUoNo-7INeB0mdXti1jwUOdbmxyExznOs-crRTQLk,17126
163
+ tpu_inference/runner/tpu_runner.py,sha256=NBDKfSGShHmYpudrtGfo1hnVSQTcLpZV_nPiXEo7JPQ,79439
164
+ tpu_inference/runner/utils.py,sha256=lKqL5nxGTk7ufzJRNdp4udn2bPu3jIX52W7akXgSrHc,17133
165
165
  tpu_inference/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
166
166
  tpu_inference/spec_decode/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
167
- tpu_inference/spec_decode/jax/eagle3.py,sha256=ci-yPOSlAfsuwoR_QAGrywtDLMbicjOhl787o9MahYg,17376
167
+ tpu_inference/spec_decode/jax/eagle3.py,sha256=FxP0uWeQlHlgCpt1nY3FUd4lKlegKJljHyc05jJucaQ,19104
168
168
  tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
169
- tpu_inference/worker/tpu_worker.py,sha256=4QH83MzYCnubwWXTvPEc2BmiU2R5KILci6PawDNpnHM,20670
170
- tpu_inference-0.11.1.dev202511270815.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
171
- tpu_inference-0.11.1.dev202511270815.dist-info/METADATA,sha256=nAfRlJUVGJkVnroEwrw0EsiO9CqWJLrGgHkt5AORBJk,5517
172
- tpu_inference-0.11.1.dev202511270815.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
173
- tpu_inference-0.11.1.dev202511270815.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
174
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD,,
169
+ tpu_inference/worker/tpu_worker.py,sha256=LnZcSNxdhh0NkoWXxS5bZ0bsTMduSANehy2wELAaVsY,20672
170
+ tpu_inference-0.11.1.dev202512030818.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
171
+ tpu_inference-0.11.1.dev202512030818.dist-info/METADATA,sha256=oLzYFTCTvHDQLfyWoc8qV4IMYCoLRTiHECf08oT_bFA,5517
172
+ tpu_inference-0.11.1.dev202512030818.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
173
+ tpu_inference-0.11.1.dev202512030818.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
174
+ tpu_inference-0.11.1.dev202512030818.dist-info/RECORD,,