tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import numpy as np
|
|
6
|
+
from vllm.logger import init_logger
|
|
7
|
+
from vllm.utils.math_utils import cdiv
|
|
8
|
+
|
|
9
|
+
logger = init_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO(xiang): fix device allocation
|
|
13
|
+
class BlockTable:
    """Table of block ids per request row, kept as a host copy and a device copy.

    The host copy (``block_table_cpu``, NumPy) is the one mutated by the row
    operations; ``commit`` pushes its leading rows into the device copy
    (``block_table``, a JAX array). ``num_blocks_per_row`` tracks how many
    leading entries of each host row are currently valid.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
    ):
        """Allocate zero-filled int32 tables of shape (max_num_reqs, max_num_blocks_per_req).

        Args:
            max_num_reqs: Number of rows in the table.
            max_num_blocks_per_req: Number of block-id slots per row.
            max_num_batched_tokens: Stored on the instance; not otherwise
                used by this class.
            pin_memory: Stored on the instance; not otherwise used by this
                class.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory

        table_shape = (max_num_reqs, max_num_blocks_per_req)
        self.block_table = jnp.zeros(table_shape, dtype=jnp.int32)
        # The host buffer is a NumPy array, so use the NumPy dtype directly.
        self.block_table_cpu = np.zeros(table_shape, dtype=np.int32)
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        """Append ``block_ids`` after the valid entries of host row ``row_idx``.

        An empty ``block_ids`` is a no-op.

        Raises:
            ValueError: If the row does not have room for all of ``block_ids``.
        """
        if not block_ids:
            return
        num_blocks = len(block_ids)
        start = int(self.num_blocks_per_row[row_idx])
        end = start + num_blocks
        # Check capacity explicitly: NumPy would silently broadcast a single
        # id into an empty slice when the row is already full, dropping it.
        if end > self.max_num_blocks_per_req:
            raise ValueError(
                f"Cannot append {num_blocks} block(s) to row {row_idx}: "
                f"{start} already used, capacity is "
                f"{self.max_num_blocks_per_req}.")
        self.block_table_cpu[row_idx, start:end] = block_ids
        # Update the count only after the write succeeded so that the two
        # stay consistent if the assignment raises.
        self.num_blocks_per_row[row_idx] = end

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Reset the valid count of row ``row_idx`` and write ``block_ids`` from slot 0."""
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy the valid entries and the count of host row ``src`` into row ``tgt``.

        Only the first ``num_blocks_per_row[src]`` slots of ``tgt`` are
        overwritten; the source row is left unchanged.
        """
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_cpu[tgt, :num_blocks] = self.block_table_cpu[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange host rows ``src`` and ``tgt`` together with their valid counts."""
        num_blocks_src = self.num_blocks_per_row[src]
        num_blocks_tgt = self.num_blocks_per_row[tgt]
        self.num_blocks_per_row[src] = num_blocks_tgt
        self.num_blocks_per_row[tgt] = num_blocks_src

        # Fancy indexing on the right-hand side yields a copy, so the two
        # full rows are exchanged without an explicit temporary.
        self.block_table_cpu[[src, tgt]] = self.block_table_cpu[[tgt, src]]

    def commit(self, num_reqs: int) -> None:
        """Copy the first ``num_reqs`` host rows into the device table.

        JAX arrays are immutable, so an updated array is built and rebound to
        ``self.block_table``; arrays previously obtained from
        ``get_device_tensor`` are not modified.
        """
        self.block_table = self.block_table.at[:num_reqs].set(
            self.block_table_cpu[:num_reqs])

    def clear(self) -> None:
        """Zero both the device table and the host table.

        ``num_blocks_per_row`` is left untouched.
        """
        # JAX arrays cannot be filled in place; rebind to a fresh zero array.
        self.block_table = jnp.zeros_like(self.block_table)
        self.block_table_cpu.fill(0)

    def get_device_tensor(self) -> jax.Array:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> np.ndarray:
        """Returns the CPU (NumPy) tensor of the block table."""
        return self.block_table_cpu
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class MultiGroupBlockTable:
    """One BlockTable per KV cache group; every call is forwarded to each table."""

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 block_sizes: list[int]) -> None:
        """Build one BlockTable per entry of ``block_sizes``.

        Each table gets ``cdiv(max_model_len, block_size)`` slots per request.
        """
        tables = []
        for size in block_sizes:
            slots_per_req = cdiv(max_model_len, size)
            tables.append(
                BlockTable(max_num_reqs, slots_per_req,
                           max_num_batched_tokens, pin_memory))
        self.block_tables = tables

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Append ``block_ids[g]`` to row ``row_idx`` of the g-th table."""
        for group, table in enumerate(self.block_tables):
            group_ids = block_ids[group]
            table.append_row(group_ids, row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Call ``add_row`` on the g-th table with ``block_ids[g]`` for row ``row_idx``."""
        for group, table in enumerate(self.block_tables):
            group_ids = block_ids[group]
            table.add_row(group_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Forward ``move_row(src, tgt)`` to every table."""
        for table in self.block_tables:
            table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Forward ``swap_row(src, tgt)`` to every table."""
        for table in self.block_tables:
            table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        """Forward ``commit(num_reqs)`` to every table."""
        for table in self.block_tables:
            table.commit(num_reqs)

    def clear(self) -> None:
        """Forward ``clear()`` to every table."""
        for table in self.block_tables:
            table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        tables = self.block_tables
        return tables[idx]
|