tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import numpy as np
|
|
6
|
+
from vllm.logger import init_logger
|
|
7
|
+
from vllm.utils import cdiv
|
|
8
|
+
|
|
9
|
+
logger = init_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO(xiang): fix device allocation
|
|
13
|
+
class BlockTable:
    """Per-request table of KV-cache block ids, kept on host and on device.

    Edits (append/add/move/swap) are applied to a NumPy buffer
    (``block_table_cpu``). ``commit`` copies the leading rows of that buffer
    into the JAX array (``block_table``).

    Args:
        max_num_reqs: Number of rows (one per request slot).
        max_num_blocks_per_req: Number of columns (block ids per request).
        max_num_batched_tokens: Stored on the instance; not otherwise used
            by this class.
        pin_memory: Stored on the instance; not otherwise used by this class.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory

        self.block_table = jnp.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            dtype=jnp.int32,
        )
        self.block_table_cpu = np.zeros(
            (max_num_reqs, max_num_blocks_per_req),
            dtype=np.int32,
        )
        # Number of valid block ids currently stored in each row.
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        """Append ``block_ids`` after the valid entries of row ``row_idx``.

        An empty ``block_ids`` is a no-op.
        """
        if not block_ids:
            return
        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        # Write first and bump the count afterwards, so a failed write
        # (e.g. more ids than the row can hold) leaves the count untouched.
        self.block_table_cpu[row_idx, start:start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] = start + num_blocks

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Reset row ``row_idx`` to empty, then store ``block_ids`` in it."""
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy the valid entries and count of row ``src`` into row ``tgt``."""
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_cpu[tgt, :num_blocks] = self.block_table_cpu[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange rows ``src`` and ``tgt`` together with their counts."""
        num_blocks_src = self.num_blocks_per_row[src]
        num_blocks_tgt = self.num_blocks_per_row[tgt]
        self.num_blocks_per_row[src] = num_blocks_tgt
        self.num_blocks_per_row[tgt] = num_blocks_src

        # Fancy indexing on the right-hand side yields a copy, so the two
        # rows are exchanged without an explicit temporary.
        self.block_table_cpu[[src, tgt]] = self.block_table_cpu[[tgt, src]]

    def commit(self, num_reqs: int) -> None:
        """Copy the first ``num_reqs`` host rows into the device array.

        JAX arrays are immutable, so the device array is replaced by an
        updated copy instead of being written in place.
        """
        self.block_table = self.block_table.at[:num_reqs].set(
            self.block_table_cpu[:num_reqs])

    def clear(self) -> None:
        """Zero both the device array and the host buffer."""
        # A JAX array cannot be filled in place; rebind to a fresh zero array.
        self.block_table = jnp.zeros_like(self.block_table)
        self.block_table_cpu.fill(0)

    def get_device_tensor(self) -> jax.Array:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> np.ndarray:
        """Returns the CPU tensor of the block table."""
        return self.block_table_cpu
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class MultiGroupBlockTable:
    """One BlockTable per KV cache group; every call fans out to all of them."""

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 block_sizes: list[int]) -> None:
        """Create one BlockTable for each entry of ``block_sizes``.

        Each table's per-request block capacity is
        ``cdiv(max_model_len, block_size)`` (``cdiv`` comes from
        ``vllm.utils``; presumably ceiling division -- confirm there).
        """
        tables = []
        for group_block_size in block_sizes:
            blocks_per_req = cdiv(max_model_len, group_block_size)
            tables.append(
                BlockTable(max_num_reqs, blocks_per_req,
                           max_num_batched_tokens, pin_memory))
        self.block_tables = tables

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Append ``block_ids[i]`` to row ``row_idx`` of the i-th table."""
        for group_idx, table in enumerate(self.block_tables):
            group_ids = block_ids[group_idx]
            table.append_row(group_ids, row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Call ``add_row(block_ids[i], row_idx)`` on the i-th table."""
        for group_idx, table in enumerate(self.block_tables):
            group_ids = block_ids[group_idx]
            table.add_row(group_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Call ``move_row(src, tgt)`` on every table."""
        for table in self.block_tables:
            table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Call ``swap_row(src, tgt)`` on every table."""
        for table in self.block_tables:
            table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        """Call ``commit(num_reqs)`` on every table."""
        for table in self.block_tables:
            table.commit(num_reqs)

    def clear(self) -> None:
        """Call ``clear()`` on every table."""
        for table in self.block_tables:
            table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        tables = self.block_tables
        return tables[idx]
|