tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +60 -7
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +72 -19
- tpu_inference/layers/common/sharding.py +3 -4
- tpu_inference/layers/vllm/quantization/mxfp4.py +2 -1
- tpu_inference/models/common/model_loader.py +3 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/vllm/vllm_model_wrapper.py +1 -2
- tpu_inference/platforms/tpu_platform.py +13 -20
- tpu_inference/runner/compilation_manager.py +87 -27
- tpu_inference/runner/kv_cache_manager.py +8 -15
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/tpu_runner.py +68 -45
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +52 -19
- tpu_inference/utils.py +31 -9
- tpu_inference/worker/tpu_worker.py +2 -2
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +1 -1
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +25 -25
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
|
@@ -6,6 +6,9 @@ from typing import Any, Optional
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
8
|
import numpy as np
|
|
9
|
+
from flax import nnx
|
|
10
|
+
from jax import lax
|
|
11
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
9
12
|
from vllm.config import VllmConfig
|
|
10
13
|
|
|
11
14
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
@@ -127,6 +130,17 @@ class Eagle3Proposer:
|
|
|
127
130
|
max_num_blocks_per_req)
|
|
128
131
|
new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
|
|
129
132
|
|
|
133
|
+
positions = lax.with_sharding_constraint(
|
|
134
|
+
positions, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
135
|
+
clamped_positions = lax.with_sharding_constraint(
|
|
136
|
+
clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
137
|
+
new_seq_lens = lax.with_sharding_constraint(
|
|
138
|
+
new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
139
|
+
query_start_loc = lax.with_sharding_constraint(
|
|
140
|
+
query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
|
|
141
|
+
new_block_tables = lax.with_sharding_constraint(
|
|
142
|
+
new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
143
|
+
|
|
130
144
|
return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
|
|
131
145
|
|
|
132
146
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
@@ -138,6 +152,7 @@ class Eagle3Proposer:
|
|
|
138
152
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
139
153
|
def _prepare_hidden_states_and_input_ids(
|
|
140
154
|
self,
|
|
155
|
+
state: nnx.State,
|
|
141
156
|
aux_hidden_states: tuple[jax.Array, ...],
|
|
142
157
|
query_start_loc: jax.Array,
|
|
143
158
|
target_token_ids: jax.Array,
|
|
@@ -146,7 +161,7 @@ class Eagle3Proposer:
|
|
|
146
161
|
) -> tuple[jax.Array, jax.Array, jax.Array]:
|
|
147
162
|
target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
|
|
148
163
|
target_hidden_states = self.combine_hidden_states_fn(
|
|
149
|
-
|
|
164
|
+
state, target_hidden_states)
|
|
150
165
|
|
|
151
166
|
input_ids, last_token_indices = self._prepare_input_ids(
|
|
152
167
|
query_start_loc, target_token_ids, next_token_ids, num_reqs)
|
|
@@ -193,8 +208,8 @@ class Eagle3Proposer:
|
|
|
193
208
|
block_tables=device_array(
|
|
194
209
|
self.mesh, block_tables))
|
|
195
210
|
target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
|
|
196
|
-
aux_hidden_states, attn_metadata.query_start_loc,
|
|
197
|
-
next_token_ids, num_reqs)
|
|
211
|
+
self.state, aux_hidden_states, attn_metadata.query_start_loc,
|
|
212
|
+
input_ids, next_token_ids, num_reqs)
|
|
198
213
|
return target_hidden_states, input_ids, last_token_indices, attn_metadata
|
|
199
214
|
|
|
200
215
|
# Host copies from the metadata prepared by the runner.
|
|
@@ -258,12 +273,13 @@ class Eagle3Proposer:
|
|
|
258
273
|
|
|
259
274
|
attn_metadata = replace(attn_metadata, block_tables=block_tables)
|
|
260
275
|
return self._filter_token_and_prepare_initial_inputs(
|
|
261
|
-
token_indices, query_start_loc, seq_lens, input_ids,
|
|
276
|
+
self.state, token_indices, query_start_loc, seq_lens, input_ids,
|
|
262
277
|
aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
|
|
263
278
|
|
|
264
279
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
265
280
|
def _filter_token_and_prepare_initial_inputs(
|
|
266
281
|
self,
|
|
282
|
+
state: nnx.State,
|
|
267
283
|
token_indices: jax.Array,
|
|
268
284
|
query_start_loc: jax.Array,
|
|
269
285
|
seq_lens: jax.Array,
|
|
@@ -291,35 +307,51 @@ class Eagle3Proposer:
|
|
|
291
307
|
)
|
|
292
308
|
|
|
293
309
|
target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
|
|
294
|
-
[h[token_indices] for h in aux_hidden_states],
|
|
295
|
-
target_token_ids, next_token_ids, num_reqs)
|
|
310
|
+
state, [h[token_indices] for h in aux_hidden_states],
|
|
311
|
+
query_start_loc, target_token_ids, next_token_ids, num_reqs)
|
|
296
312
|
|
|
297
313
|
return target_hidden_states, input_ids, last_token_indices, attn_metadata
|
|
298
314
|
|
|
299
315
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
300
316
|
def _select_draft_token_ids(
|
|
301
317
|
self,
|
|
318
|
+
state: nnx.State,
|
|
302
319
|
hidden_states: jax.Array,
|
|
303
320
|
last_token_indices: jax.Array,
|
|
304
321
|
) -> jax.Array:
|
|
305
322
|
sample_hidden_states = hidden_states[last_token_indices]
|
|
306
|
-
|
|
323
|
+
sample_hidden_states = lax.with_sharding_constraint(
|
|
324
|
+
sample_hidden_states,
|
|
325
|
+
NamedSharding(self.mesh, PartitionSpec(None, None)))
|
|
326
|
+
return self._get_draft_token_ids(state, sample_hidden_states)
|
|
307
327
|
|
|
308
328
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
309
|
-
def _get_draft_token_ids(self,
|
|
329
|
+
def _get_draft_token_ids(self, state: nnx.State,
|
|
330
|
+
hidden_states: jax.Array) -> jax.Array:
|
|
310
331
|
lora_metadata = None
|
|
311
|
-
logits = self.compute_logits_fn(
|
|
312
|
-
|
|
313
|
-
return
|
|
332
|
+
logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
|
|
333
|
+
draft_token_ids = jnp.argmax(logits, axis=-1)
|
|
334
|
+
return lax.with_sharding_constraint(
|
|
335
|
+
draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
|
|
314
336
|
|
|
315
337
|
@functools.partial(jax.jit, static_argnums=(0, ))
|
|
316
338
|
def _select_inputs_for_loop_speculation(
|
|
317
|
-
self, positions: jax.Array, residual: jax.Array,
|
|
339
|
+
self, state: nnx.State, positions: jax.Array, residual: jax.Array,
|
|
318
340
|
hidden_states: jax.Array,
|
|
319
341
|
last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
342
|
+
positions = positions[last_token_indices]
|
|
343
|
+
residual = residual[last_token_indices]
|
|
344
|
+
draft_token_ids = self._select_draft_token_ids(state, hidden_states,
|
|
345
|
+
last_token_indices)
|
|
346
|
+
|
|
347
|
+
positions = lax.with_sharding_constraint(
|
|
348
|
+
positions, NamedSharding(self.mesh, PartitionSpec(None, )))
|
|
349
|
+
residual = lax.with_sharding_constraint(
|
|
350
|
+
residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
|
|
351
|
+
draft_token_ids = lax.with_sharding_constraint(
|
|
352
|
+
draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
|
|
353
|
+
|
|
354
|
+
return positions, residual, draft_token_ids
|
|
323
355
|
|
|
324
356
|
def propose(
|
|
325
357
|
self,
|
|
@@ -346,11 +378,11 @@ class Eagle3Proposer:
|
|
|
346
378
|
|
|
347
379
|
if self.num_speculative_tokens == 1:
|
|
348
380
|
return kv_caches, self._select_draft_token_ids(
|
|
349
|
-
hidden_states, last_token_indices)
|
|
381
|
+
self.state, hidden_states, last_token_indices)
|
|
350
382
|
|
|
351
383
|
positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
|
|
352
|
-
attn_metadata.input_positions, residual[0],
|
|
353
|
-
last_token_indices)
|
|
384
|
+
self.state, attn_metadata.input_positions, residual[0],
|
|
385
|
+
hidden_states, last_token_indices)
|
|
354
386
|
|
|
355
387
|
draft_token_ids_list = [draft_token_ids]
|
|
356
388
|
|
|
@@ -375,7 +407,8 @@ class Eagle3Proposer:
|
|
|
375
407
|
attn_metadata,
|
|
376
408
|
)
|
|
377
409
|
hidden_states = residual[0]
|
|
378
|
-
draft_token_ids = self._get_draft_token_ids(
|
|
410
|
+
draft_token_ids = self._get_draft_token_ids(
|
|
411
|
+
self.state, new_hidden_states)
|
|
379
412
|
draft_token_ids_list.append(draft_token_ids)
|
|
380
413
|
|
|
381
414
|
# [batch_size, num_speculative_tokens]
|
tpu_inference/utils.py
CHANGED
|
@@ -8,11 +8,14 @@ from typing import Any, Callable, List, Tuple
|
|
|
8
8
|
import jax
|
|
9
9
|
import jax.numpy as jnp
|
|
10
10
|
import numpy as np
|
|
11
|
+
import torch
|
|
11
12
|
from jax._src import dtypes
|
|
12
13
|
from jax._src import mesh as mesh_lib
|
|
13
14
|
from jax._src import xla_bridge as xb
|
|
14
15
|
from jax._src.lib import xla_client as xc
|
|
16
|
+
from jax._src.numpy.scalar_types import _ScalarMeta
|
|
15
17
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
18
|
+
from torchax.ops.mappings import j2t_dtype, t2j_dtype
|
|
16
19
|
from vllm import envs as vllm_envs
|
|
17
20
|
from vllm import utils
|
|
18
21
|
|
|
@@ -23,17 +26,36 @@ GBYTES = 1024 * 1024 * 1024
|
|
|
23
26
|
TPU_HEAD_SIZE_ALIGNMENT = 128
|
|
24
27
|
TPU_SECOND_LAST_MINOR = 8
|
|
25
28
|
|
|
26
|
-
#
|
|
27
|
-
|
|
28
|
-
# converting the `--kv_cache_dtype` flag to a dtype.
|
|
29
|
-
TPU_STR_DTYPE_TO_JAX_DTYPE = {
|
|
30
|
-
"bfloat16": jnp.bfloat16,
|
|
29
|
+
# Map vllm dtype string that doesn't exactly match jax dtype string name.
|
|
30
|
+
_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
|
|
31
31
|
"fp8": jnp.float8_e4m3fn,
|
|
32
|
-
"fp8_e4m3": jnp.
|
|
32
|
+
"fp8_e4m3": jnp.float8_e4m3fn,
|
|
33
33
|
"fp8_e5m2": jnp.float8_e5m2,
|
|
34
|
-
"int8": jnp.int8,
|
|
35
34
|
}
|
|
36
35
|
|
|
36
|
+
|
|
37
|
+
def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
|
|
38
|
+
if isinstance(dtype, str):
|
|
39
|
+
if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
|
|
40
|
+
return dict_dtype
|
|
41
|
+
return jnp.dtype(dtype)
|
|
42
|
+
elif isinstance(dtype, torch.dtype):
|
|
43
|
+
return t2j_dtype(dtype)
|
|
44
|
+
elif isinstance(dtype, jnp.dtype):
|
|
45
|
+
return dtype
|
|
46
|
+
elif isinstance(dtype, _ScalarMeta):
|
|
47
|
+
return dtype.dtype
|
|
48
|
+
else:
|
|
49
|
+
raise ValueError(f"Argument is unsupported data type {type(dtype)}")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
|
|
53
|
+
# Use jax dtype as an intermediate dtype which we'll be used to convert it
|
|
54
|
+
# into torch dtype.
|
|
55
|
+
dtype = to_jax_dtype(dtype)
|
|
56
|
+
return j2t_dtype(dtype)
|
|
57
|
+
|
|
58
|
+
|
|
37
59
|
_megacore = False
|
|
38
60
|
logger = init_logger(__name__)
|
|
39
61
|
|
|
@@ -295,8 +317,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
|
|
|
295
317
|
Returns:
|
|
296
318
|
jnp.dtype: The JAX dtype.
|
|
297
319
|
"""
|
|
298
|
-
|
|
299
|
-
return
|
|
320
|
+
# TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
|
|
321
|
+
return to_jax_dtype(str_dtype)
|
|
300
322
|
|
|
301
323
|
|
|
302
324
|
def time_function(func):
|
|
@@ -108,7 +108,7 @@ class TPUWorker:
|
|
|
108
108
|
|
|
109
109
|
if self.model_config.trust_remote_code:
|
|
110
110
|
# note: lazy import to avoid importing torch before initializing
|
|
111
|
-
from vllm.utils import init_cached_hf_modules
|
|
111
|
+
from vllm.utils.import_utils import init_cached_hf_modules
|
|
112
112
|
|
|
113
113
|
init_cached_hf_modules()
|
|
114
114
|
|
|
@@ -357,7 +357,7 @@ class TPUWorker:
|
|
|
357
357
|
if is_start:
|
|
358
358
|
options = jax.profiler.ProfileOptions()
|
|
359
359
|
# default: https://docs.jax.dev/en/latest/profiling.html#general-options
|
|
360
|
-
options.python_tracer_level =
|
|
360
|
+
options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
|
|
361
361
|
options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
|
|
362
362
|
jax.profiler.start_trace(self.profile_dir,
|
|
363
363
|
profiler_options=options)
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
tests/test_base.py,sha256=Ct5WFRMHL7IHEIxk8FrzAvO8m0xFuDpzDBKkAKKAL2Q,7341
|
|
3
|
-
tests/test_envs.py,sha256=
|
|
3
|
+
tests/test_envs.py,sha256=h502VxL2gvhECm8u5uDh5JTGvhFf_DfQO88SpqOFMzE,7135
|
|
4
4
|
tests/test_quantization.py,sha256=IT5ASyS1uuWcxc22kRtBcA-V4j3Z3hb7pMztm3GOlBs,34445
|
|
5
5
|
tests/test_tpu_info.py,sha256=ZrwlMsp8ffITkS_b8Q1t_QG-a-WVAd4NUcjHhGibcsI,4670
|
|
6
|
-
tests/test_utils.py,sha256=
|
|
6
|
+
tests/test_utils.py,sha256=GIXLdd-x4gnqSLrySXGk22phqPc8MegFd7ph1Jj8OcU,8182
|
|
7
7
|
tests/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
tests/core/test_core_tpu.py,sha256=r496rk1eOsK_F4nvm9zprl_T-RcO6eCUb7LuVReOZno,21413
|
|
9
9
|
tests/core/test_disagg_executor.py,sha256=QdE2YZs08EyDDCmSjhiXkXqQ9BJTgO6csr_E1xkkfSg,2256
|
|
@@ -26,10 +26,10 @@ tests/lora/test_lora.py,sha256=wJiF1P1BDnPN8TLX2tlFtdZ_QCkV-S9nPl6_uR6DqFc,4439
|
|
|
26
26
|
tests/lora/utils.py,sha256=rY0tDZEZe58ye4-ykwrTnsiWuLcaEG57N_Rua90bDXI,2726
|
|
27
27
|
tpu_inference/__init__.py,sha256=p4MaepRdN7723FUNE-3pOMxZWjFn4_TVFgjrNyty4JE,2304
|
|
28
28
|
tpu_inference/env_override.py,sha256=pmL7lfs_rGCP92ya3wuWuudsCYeOMZ6tFZY82A4KkQc,365
|
|
29
|
-
tpu_inference/envs.py,sha256=
|
|
29
|
+
tpu_inference/envs.py,sha256=ugze6VdQ_hG1IxUCbcgXZq7a22fZ-Lora3V_fkFOefw,5714
|
|
30
30
|
tpu_inference/logger.py,sha256=HQCz7NefmbturuhOC7-3Ixbtcdgoz4g9FHh2RB6o8cc,334
|
|
31
31
|
tpu_inference/tpu_info.py,sha256=3iilHRQSFjwMJwhKcuuawTm7mhwkgHbj4zi6CiAySrs,2265
|
|
32
|
-
tpu_inference/utils.py,sha256=
|
|
32
|
+
tpu_inference/utils.py,sha256=mHbjI8fxInPxagLsSUg-R3DzSz-X7WYNdoorPYoE3hg,10855
|
|
33
33
|
tpu_inference/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
34
|
tpu_inference/core/core_tpu.py,sha256=WDD3koE_j1QhWS2BbMA2aQOZayPZm4tYPvzL4YCX2jY,33294
|
|
35
35
|
tpu_inference/core/disagg_executor.py,sha256=HZpgYMVxRxm0RQxO4l8IDYBWJ6Z3Tac6xavc5otcirc,4657
|
|
@@ -38,10 +38,10 @@ tpu_inference/core/sched/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
|
|
|
38
38
|
tpu_inference/core/sched/dp_scheduler.py,sha256=mKs8Ms46szdlBfo8hjdqis2ZKAZbcKnHAGfEr0X5R8g,22527
|
|
39
39
|
tpu_inference/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
40
40
|
tpu_inference/distributed/jax_parallel_state.py,sha256=5_xCwcL03lFPUoSO_OP7hIVKpUFroW1m-jVO7R6FbUc,2223
|
|
41
|
-
tpu_inference/distributed/tpu_connector.py,sha256=
|
|
41
|
+
tpu_inference/distributed/tpu_connector.py,sha256=kLaTwy6BrAThJeFkd1soJ47bBo5iGp4GjUJs7xFx4Tg,29696
|
|
42
42
|
tpu_inference/distributed/utils.py,sha256=1KIREn28Zg10O-MSUkVQMRzS09WoGc_VLGOX4QTFJac,1504
|
|
43
43
|
tpu_inference/executors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
44
|
-
tpu_inference/executors/ray_distributed_executor.py,sha256=
|
|
44
|
+
tpu_inference/executors/ray_distributed_executor.py,sha256=9CnzWb8aurH1B0tJfMHB73F-RQBGqSf5DnymetBvZ5o,16225
|
|
45
45
|
tpu_inference/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
46
46
|
tpu_inference/experimental/llama3_jax_stashed.py,sha256=YK1oSIfto9ALo-HB45XfSrbq9XgVbE4m2C-9zRwmSzI,10913
|
|
47
47
|
tpu_inference/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -68,7 +68,7 @@ tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py,sha256
|
|
|
68
68
|
tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py,sha256=mw80bXBGenroGdrITV0F_EaI2s-Z9KWwqU9WodvJg14,97919
|
|
69
69
|
tpu_inference/kernels/ragged_paged_attention/v3/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
70
70
|
tpu_inference/kernels/ragged_paged_attention/v3/kernel.py,sha256=O179Fft5KpuN5LIFx3SghWXJJUqh3Og-xqfO4Z8QXYU,57032
|
|
71
|
-
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=
|
|
71
|
+
tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py,sha256=ArwrqIQiKIop_jaDKAMw656YHQ3IFZ0sRu9Cgycrtko,59858
|
|
72
72
|
tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py,sha256=k3LwduhZO85cJ-pSgnGN0c2Nn8eNeQq4eA94KUXJzMw,142198
|
|
73
73
|
tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py,sha256=P3_ivi8iUz5QMU_3pgpl4Bkbmn0q0NpDtVJX39haRQA,11208
|
|
74
74
|
tpu_inference/kernels/ragged_paged_attention/v3/util.py,sha256=1N_ozjKboDYLteFJndWoLXNudj2z53rGXMkELa5Z9tY,1102
|
|
@@ -78,7 +78,7 @@ tpu_inference/layers/common/attention_interface.py,sha256=SQZ-1I32Jqg7GGI-z4BVib
|
|
|
78
78
|
tpu_inference/layers/common/attention_metadata.py,sha256=St8ZatbY1D7xQACKJH459jMgp3oTP3AQ36mi9FZdrPU,850
|
|
79
79
|
tpu_inference/layers/common/binary_search.py,sha256=ZQi-z1wG6WTcfVQXeTGOZokX4K1DSf9kCzqfrhEU8lk,12320
|
|
80
80
|
tpu_inference/layers/common/quant_methods.py,sha256=mQSxZ44-QQtm22C_8ViejnP1cP2Dv6yc2YaP6oMKJeQ,185
|
|
81
|
-
tpu_inference/layers/common/sharding.py,sha256=
|
|
81
|
+
tpu_inference/layers/common/sharding.py,sha256=sjbwkDr2fP26Ob8f5cSDeDifr3eWFZMDHU4MKr7pIgQ,25217
|
|
82
82
|
tpu_inference/layers/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
83
83
|
tpu_inference/layers/jax/base.py,sha256=Vhts6ZMwNCZ8LbnEXeB0rl3nHdS5hDJWX7HEa7Fl7yE,5775
|
|
84
84
|
tpu_inference/layers/jax/constants.py,sha256=NcYg0zAf3ClfP7YMYdYu_F1GngOzZaIxIAHBZDunKw4,2755
|
|
@@ -108,7 +108,7 @@ tpu_inference/layers/vllm/sharding.py,sha256=as7CF8UKTF3ToymwRY5Pi8uzwJk0P1sHPkW
|
|
|
108
108
|
tpu_inference/layers/vllm/quantization/__init__.py,sha256=SEppGayBzzQ5tsXLSy99aqilkAawQwYxnv2alCg6-ZU,1777
|
|
109
109
|
tpu_inference/layers/vllm/quantization/awq.py,sha256=-8ZmjGvSKJB6_JuwSctNWt8xHWq4VSvK_AK9iahlgCo,8495
|
|
110
110
|
tpu_inference/layers/vllm/quantization/common.py,sha256=8XD64pPa077c9HThFhLFVHlDL9YBafnYwp6rp6gR44E,4432
|
|
111
|
-
tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=
|
|
111
|
+
tpu_inference/layers/vllm/quantization/mxfp4.py,sha256=o661uiSvLvWGr8hQMl7TqYXJyALPREtNWlKHAM9AUrw,14541
|
|
112
112
|
tpu_inference/layers/vllm/quantization/unquantized.py,sha256=nSRBzVurTiQQkF9FuSTshfRwfxfzs54E2_4eK7Eyhj0,15345
|
|
113
113
|
tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
114
114
|
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py,sha256=6idEyy3e849fZ1UeNvc9eSHYX7e6qvohrJa_d_D9MBk,5285
|
|
@@ -121,7 +121,7 @@ tpu_inference/lora/torch_lora_ops.py,sha256=pr3N7DVfkn3ANijUC6dBoiCtIJW4fdJpKdC3
|
|
|
121
121
|
tpu_inference/lora/torch_punica_tpu.py,sha256=qTnXZGLoOgvukSxeunO_SfpPTlkq9GlMj9H7zVYg9LE,12680
|
|
122
122
|
tpu_inference/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
123
123
|
tpu_inference/models/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
124
|
-
tpu_inference/models/common/model_loader.py,sha256=
|
|
124
|
+
tpu_inference/models/common/model_loader.py,sha256=b3aigca81gMVJt42oF2aoRohQHjBBe3oK3IPblZAaUM,19996
|
|
125
125
|
tpu_inference/models/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
126
126
|
tpu_inference/models/jax/deepseek_v3.py,sha256=SKOHVEC-_2NLxBnzBzbu5tu0d6FTlAEiI1EefGaO2QE,40047
|
|
127
127
|
tpu_inference/models/jax/gpt_oss.py,sha256=Vw4LRB5Kp6hbA2hjZGFS8kiEqOCjf881XH2JNtu2S1I,20924
|
|
@@ -139,36 +139,36 @@ tpu_inference/models/jax/utils/multi_modal_utils.py,sha256=rrIrQWidkUnGilBHKNpdY
|
|
|
139
139
|
tpu_inference/models/jax/utils/weight_utils.py,sha256=qFU53jPHPvIcs_EOdIH80oNojpUp7GdSY2E6NZNsjvM,21376
|
|
140
140
|
tpu_inference/models/jax/utils/quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
141
141
|
tpu_inference/models/jax/utils/quantization/mxfp4_utils.py,sha256=boGnqJCRIOf5nedAxQ8_IUTV6Rfll10DXnRC40BeeE8,3682
|
|
142
|
-
tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=
|
|
142
|
+
tpu_inference/models/jax/utils/quantization/quantization_utils.py,sha256=rzAFU3OtQvg8w8ow0V15rMljAsa4SBrwOye6OI8Bty4,26530
|
|
143
143
|
tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml,sha256=d_YHPtaRJ_7PBrPijSzJGnVeoJO62tKIGqrgFqpYT1k,137
|
|
144
144
|
tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml,sha256=b7SyL75HuSTj3fN9_ZLCK_CDiccL5DGq_DddGmxj_qk,170
|
|
145
145
|
tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml,sha256=0Qwij71zj9k6rmrUNd8Q5df9YYfkoJ1ZkgMAHxQy81k,128
|
|
146
146
|
tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml,sha256=lGec0UwwxmNPNgKPSsTsCMSXNJjhw507KMtM2NsSCMw,152
|
|
147
147
|
tpu_inference/models/vllm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
148
|
-
tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=
|
|
148
|
+
tpu_inference/models/vllm/vllm_model_wrapper.py,sha256=3EcaD_1vZuyAZBfDtm5u_qfCahQU28qR4rAUraNAFqs,12305
|
|
149
149
|
tpu_inference/models/vllm/vllm_model_wrapper_context.py,sha256=yxlJHPmRQIAwlb1MmHK3xfXokgIkJ-evNU4PgyoJUdg,1187
|
|
150
150
|
tpu_inference/platforms/__init__.py,sha256=lQCrKddS_GcGpCbeogvz9zOZD1mQw5bBsiw8On46qFQ,74
|
|
151
|
-
tpu_inference/platforms/tpu_platform.py,sha256=
|
|
151
|
+
tpu_inference/platforms/tpu_platform.py,sha256=F4jjPEFHFUTxdfWZYTBuUVJt6SYTFeWEKmrl74sX-Zk,10663
|
|
152
152
|
tpu_inference/runner/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
153
153
|
tpu_inference/runner/block_table.py,sha256=K3Ic8EgPM08d_C5nEN60mxoRydlaQWySAemf_8Q_qVw,4175
|
|
154
|
-
tpu_inference/runner/compilation_manager.py,sha256=
|
|
154
|
+
tpu_inference/runner/compilation_manager.py,sha256=dU0Yk8f0LtRTBe2q0iB3xcMSRco_WPsj2wS6zZJ8WhY,40375
|
|
155
155
|
tpu_inference/runner/input_batch.py,sha256=bx221NX2IOWzrtopss-B-2ZKW4y-U6nQpG09PjpUziw,18273
|
|
156
156
|
tpu_inference/runner/kv_cache.py,sha256=F4dzW2d53xuxkFUn0oKzwE6VklGUeVm-QM19NVfIQDU,4577
|
|
157
|
-
tpu_inference/runner/kv_cache_manager.py,sha256=
|
|
157
|
+
tpu_inference/runner/kv_cache_manager.py,sha256=N0a896CE7Zrs_d4ZSSzRdqgjV1It57RBDSIpOzkRqro,22013
|
|
158
158
|
tpu_inference/runner/lora_utils.py,sha256=B4xMCgXGJ4VNdePvn89HH3tIZ-gYsQ7Vq_YCiYIATEY,3843
|
|
159
159
|
tpu_inference/runner/multimodal_manager.py,sha256=azEPdHOwz8CN11MQmorGdtrCLbFaTCxdWyuEsZTzjYM,9778
|
|
160
|
-
tpu_inference/runner/persistent_batch_manager.py,sha256=
|
|
160
|
+
tpu_inference/runner/persistent_batch_manager.py,sha256=Otu67vOTf1_HKAMZgPDDHlRvvZ3YVJdz-QderH4qOII,13263
|
|
161
161
|
tpu_inference/runner/speculative_decoding_manager.py,sha256=I3FDWKh2dn6nV8LgTGfCTwMKYnxQsTPpBIrmaJngXHs,10215
|
|
162
162
|
tpu_inference/runner/structured_decoding_manager.py,sha256=gZQKQUFxh6xYYH9eGTdbguqk8hc2WwTrIdMMuCcbymE,3573
|
|
163
|
-
tpu_inference/runner/tpu_runner.py,sha256=
|
|
164
|
-
tpu_inference/runner/utils.py,sha256=
|
|
163
|
+
tpu_inference/runner/tpu_runner.py,sha256=NBDKfSGShHmYpudrtGfo1hnVSQTcLpZV_nPiXEo7JPQ,79439
|
|
164
|
+
tpu_inference/runner/utils.py,sha256=lKqL5nxGTk7ufzJRNdp4udn2bPu3jIX52W7akXgSrHc,17133
|
|
165
165
|
tpu_inference/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
166
166
|
tpu_inference/spec_decode/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
167
|
-
tpu_inference/spec_decode/jax/eagle3.py,sha256=
|
|
167
|
+
tpu_inference/spec_decode/jax/eagle3.py,sha256=FxP0uWeQlHlgCpt1nY3FUd4lKlegKJljHyc05jJucaQ,19104
|
|
168
168
|
tpu_inference/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
169
|
-
tpu_inference/worker/tpu_worker.py,sha256=
|
|
170
|
-
tpu_inference-0.11.1.
|
|
171
|
-
tpu_inference-0.11.1.
|
|
172
|
-
tpu_inference-0.11.1.
|
|
173
|
-
tpu_inference-0.11.1.
|
|
174
|
-
tpu_inference-0.11.1.
|
|
169
|
+
tpu_inference/worker/tpu_worker.py,sha256=LnZcSNxdhh0NkoWXxS5bZ0bsTMduSANehy2wELAaVsY,20672
|
|
170
|
+
tpu_inference-0.11.1.dev202512030818.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
171
|
+
tpu_inference-0.11.1.dev202512030818.dist-info/METADATA,sha256=oLzYFTCTvHDQLfyWoc8qV4IMYCoLRTiHECf08oT_bFA,5517
|
|
172
|
+
tpu_inference-0.11.1.dev202512030818.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
173
|
+
tpu_inference-0.11.1.dev202512030818.dist-info/top_level.txt,sha256=gb1hRIQ3DOawUfVzvPL2E__2KPIl9I0vb5r0xcRBGYQ,20
|
|
174
|
+
tpu_inference-0.11.1.dev202512030818.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|