tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/runner/kv_cache_manager.py

```diff
@@ -1,16 +1,15 @@
 import functools
+import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
 import jax.numpy as jnp
-import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
+from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
-from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
```
```diff
@@ -176,11 +175,6 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
-        self.runner.block_tables_cpu = [
-            np.zeros((self.runner.max_num_reqs,
-                      cdiv(self.runner.max_model_len, block_size)),
-                     dtype=np.int32) for block_size in block_sizes
-        ]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
```
```diff
@@ -196,7 +190,7 @@ class KVCacheManager:
         num_blocks = kv_cache_tensor.size // page_size_bytes
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         # num_blocks must be a multiple of dp_size
-        num_blocks = (num_blocks
+        num_blocks = math.ceil(num_blocks / dp_size) * dp_size
         # NOTE: we'll multiply the num_kv_heads by 2 in the function
         kv_cache = create_kv_caches(
             num_blocks=num_blocks,
```
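The rounding in the new line is easiest to see with concrete numbers. A minimal sketch (values invented for illustration): `math.ceil(n / k) * k` rounds a count up to the next multiple of `k`, whereas a floor-style `(n // k) * k` would round it down.

```python
import math

num_blocks = 1021   # hypothetical count from kv_cache_tensor.size // page_size_bytes
dp_size = 4

# math.ceil(n / k) * k rounds n *up* to the next multiple of k.
rounded_up = math.ceil(num_blocks / dp_size) * dp_size
assert rounded_up == 1024

# A floor-based variant, (n // k) * k, would round *down* instead.
rounded_down = (num_blocks // dp_size) * dp_size
assert rounded_down == 1020
```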
```diff
@@ -289,8 +283,13 @@ class KVCacheManager:
 
         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, block_size,
-                                             *slices.shape[1:])
+            reshaped_slices = slices.reshape(-1, 1, block_size,
+                                             *slices.shape[1:])
+            for (i, block_idx) in enumerate(block_numbers):
+                cache = jax.lax.dynamic_update_slice_in_dim(cache,
+                                                            reshaped_slices[i],
+                                                            block_idx,
+                                                            axis=0)
             return cache
 
         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
```
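The new `_update_layer` writes each block's worth of slices into the cache with `jax.lax.dynamic_update_slice_in_dim`. A minimal, self-contained sketch of that pattern follows; the shapes are invented for illustration and the real cache layout may differ.

```python
import jax
import jax.numpy as jnp

block_size, num_heads, head_dim = 16, 2, 4
cache = jnp.zeros((8, block_size, num_heads, head_dim))    # 8 blocks
slices = jnp.ones((2 * block_size, num_heads, head_dim))   # 2 blocks' worth
block_numbers = [3, 5]

# Regroup the flat token dimension into per-block updates of shape
# (1, block_size, num_heads, head_dim), matching the cache's leading axis.
reshaped = slices.reshape(-1, 1, block_size, *slices.shape[1:])
for i, block_idx in enumerate(block_numbers):
    # Write one block of slices into the cache at its destination index.
    cache = jax.lax.dynamic_update_slice_in_dim(cache, reshaped[i],
                                                block_idx, axis=0)
```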
```diff
@@ -343,12 +342,16 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                self.runner.kv_caches, block_ids[0], len(block_ids))
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                    self.runner.kv_caches, block_ids[0], len(block_ids))
 
         else:
-            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                self.runner.kv_caches, jnp.array(block_ids))
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                    self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer
 
     def transfer_kv_cache(self,
```
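This hunk keeps the existing two-path dispatch: a contiguous run of block ids can be fetched with a single dynamic slice, while scattered ids need a gather. A hedged sketch of the idea, with a toy `gather_blocks` helper standing in for the two jitted methods:

```python
import jax
import jax.numpy as jnp

def gather_blocks(cache: jax.Array, block_ids: list[int]) -> jax.Array:
    if block_ids == list(range(block_ids[0], block_ids[0] + len(block_ids))):
        # Contiguous run: one slice along the block axis suffices.
        return jax.lax.dynamic_slice_in_dim(cache, block_ids[0],
                                            len(block_ids), axis=0)
    # Scattered blocks: gather the requested indices.
    return jnp.take(cache, jnp.array(block_ids), axis=0)

cache = jnp.arange(8 * 4).reshape(8, 4)
assert (gather_blocks(cache, [2, 3, 4]) == cache[2:5]).all()
assert (gather_blocks(cache, [0, 5]) == cache[jnp.array([0, 5])]).all()
```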
```diff
@@ -437,7 +440,6 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
-            jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
```
```diff
@@ -449,7 +451,6 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
-            jax.block_until_ready(self.runner.kv_caches)
 
         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
```
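Both of the last two hunks drop a `jax.block_until_ready` call after the jitted insert. JAX dispatches device work asynchronously, so removing the barrier lets the host thread continue instead of stalling until the updated caches are materialized. A minimal illustration of what the removed call does:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))
y = x @ x                  # returns immediately; the result is still in flight
jax.block_until_ready(y)   # only here does the host wait for the device
```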
tpu_inference/runner/persistent_batch_manager.py

```diff
@@ -14,13 +14,12 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config
+                 uses_mrope: bool, model_config):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
-        self.is_last_rank = is_last_rank
 
     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
```
```diff
@@ -180,35 +179,9 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
-            num_output_tokens = req_data.num_output_tokens[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
-            req_index = self.input_batch.req_id_to_index.get(req_id)
-
-            if not self.is_last_rank:
-                # When using PP, the scheduler sends the sampled tokens back,
-                # because there's no direct communication between the first-
-                # stage worker and the last-stage worker.
-                new_token_ids = req_data.new_token_ids[i]
-                # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec tokens.
-                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                                  req_state.num_tokens)
-                if num_new_tokens == 1:
-                    req_state.output_token_ids.append(new_token_ids[-1])
-                elif num_new_tokens > 0:
-                    req_state.output_token_ids.extend(
-                        new_token_ids[-num_new_tokens:])
-                elif num_output_tokens < len(req_state.output_token_ids):
-                    del req_state.output_token_ids[num_output_tokens:]
-                if req_index is not None:
-                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
-                               num_output_tokens)
-                    self.input_batch.num_tokens[req_index] = end_idx
-                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
-
-            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
```
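The removed pipeline-parallel branch inferred how many freshly sampled tokens the scheduler echoed back from the request's counters. The arithmetic, isolated with invented values:

```python
# All values below are made up for the example.
num_computed_tokens = 10     # tokens the engine has processed so far
new_token_ids = [42, 43]     # sampled tokens echoed back by the scheduler
num_tokens_in_state = 11     # req_state.num_tokens before this update

# Tokens already counted in req_state are subtracted out, leaving only the
# genuinely new suffix of new_token_ids to append.
num_new_tokens = num_computed_tokens + len(new_token_ids) - num_tokens_in_state
assert num_new_tokens == 1
```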
```diff
@@ -221,6 +194,7 @@ class PersistentBatchManager:
                     # Replace the existing block IDs with the new ones.
                     req_state.block_ids = new_block_ids
 
+            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
```
```diff
@@ -235,18 +209,6 @@ class PersistentBatchManager:
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)
 
-            # For the last rank, we don't need to update the token_ids_cpu
-            # because the sampled tokens are already cached.
-            if not self.is_last_rank:
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index,
-                    start_token_index:end_token_index] = new_token_ids
-                self.input_batch.num_tokens_no_spec[
-                    req_index] = end_token_index
-                self.input_batch.num_tokens[req_index] = end_token_index
-
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
```
tpu_inference/runner/structured_decoding_manager.py

```diff
@@ -61,10 +61,11 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)
 
         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids)
+            grammar_output.structured_output_request_ids.items(),
+            key=lambda item: item[1])
 
         cumulative_mask_idx = 0
-        for req_id in sorted_struct_requests:
+        for req_id, _ in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]
```
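The change in this hunk suggests `structured_output_request_ids` is now a mapping from request id to a position rather than a plain collection of ids, so requests are iterated in order of that position. A toy illustration of the new sorting (data invented):

```python
# Hypothetical mapping of request id -> position in the batch.
structured_output_request_ids = {"req-b": 2, "req-a": 0, "req-c": 1}

# Sort the (id, position) pairs by position, then iterate ids in that order.
sorted_struct_requests = sorted(structured_output_request_ids.items(),
                                key=lambda item: item[1])
assert [req_id for req_id, _ in sorted_struct_requests] == [
    "req-a", "req-c", "req-b"
]
```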