tpu-inference 0.11.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1rc1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1rc1.dist-info/RECORD +123 -0
- tpu_inference-0.11.1rc1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1rc1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1rc1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import numpy as np
|
|
6
|
+
from vllm.logger import init_logger
|
|
7
|
+
from vllm.utils import cdiv
|
|
8
|
+
|
|
9
|
+
# Module-level logger, created through vLLM's `init_logger` helper and keyed
# by this module's name.
logger = init_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO(xiang): fix device allocation
|
|
13
|
+
class BlockTable:
    """Table of block ids per request, kept as a host copy and a device copy.

    All row edits (`append_row`, `add_row`, `move_row`, `swap_row`) touch only
    the host-side NumPy array `block_table_cpu`. `commit` pushes the leading
    rows of the host array into the device-side JAX array `block_table`.
    `num_blocks_per_row[i]` records how many leading entries of row `i` are
    currently in use.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
    ):
        """Allocate zero-filled host and device tables.

        Args:
            max_num_reqs: Number of rows in the table.
            max_num_blocks_per_req: Number of columns in the table.
            max_num_batched_tokens: Stored on the instance; not used by this
                class itself.
            pin_memory: Stored on the instance; not used by this class itself.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory

        table_shape = (max_num_reqs, max_num_blocks_per_req)
        # Device-side copy (JAX array); only replaced by commit() / clear().
        self.block_table = jnp.zeros(table_shape, dtype=jnp.int32)
        # Host-side copy (NumPy array); the one that row edits mutate in place.
        self.block_table_cpu = np.zeros(table_shape, dtype=np.int32)
        # Number of valid leading entries in each row of the host table.
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        """Append `block_ids` after the entries already stored in row `row_idx`.

        An empty `block_ids` is a no-op.

        Raises:
            ValueError: Propagated from NumPy when `block_ids` cannot be
                assigned into the remaining space of the row.
        """
        if not block_ids:
            return
        start = int(self.num_blocks_per_row[row_idx])
        end = start + len(block_ids)
        # Write before bumping the count so that a failed write does not leave
        # the row count claiming entries that were never stored.
        self.block_table_cpu[row_idx, start:end] = block_ids
        self.num_blocks_per_row[row_idx] = end

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Restart row `row_idx` from empty and store `block_ids` in it.

        Only the row count is reset; stale entries beyond the new count are
        left in place in the host table.
        """
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy the in-use entries and the count of row `src` onto row `tgt`."""
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_cpu[tgt, :num_blocks] = self.block_table_cpu[
            src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange rows `src` and `tgt` (full rows and their counts)."""
        counts = self.num_blocks_per_row
        # Integer indexing yields scalar copies, so the tuple swap is safe.
        counts[src], counts[tgt] = counts[tgt], counts[src]

        # Fancy indexing on the right-hand side produces a copy of both rows
        # before anything is written back.
        self.block_table_cpu[[src, tgt]] = self.block_table_cpu[[tgt, src]]

    def commit(self, num_reqs: int) -> None:
        """Copy the first `num_reqs` host rows into the device table.

        JAX arrays are immutable, so this builds an updated array with
        `.at[...].set(...)` and rebinds `self.block_table`. References obtained
        earlier from `get_device_tensor` keep pointing at the old array.
        """
        # NOTE(review): every distinct `num_reqs` produces a differently
        # shaped update; confirm that is acceptable for compile caching.
        self.block_table = self.block_table.at[:num_reqs].set(
            self.block_table_cpu[:num_reqs])

    def clear(self) -> None:
        """Zero both tables; `num_blocks_per_row` is left unchanged."""
        # The device array cannot be filled in place; replace it instead.
        self.block_table = jnp.zeros_like(self.block_table)
        self.block_table_cpu.fill(0)

    def get_device_tensor(self) -> jax.Array:
        """Returns the device tensor of the block table."""
        return self.block_table

    def get_cpu_tensor(self) -> np.ndarray:
        """Returns the CPU (NumPy) tensor of the block table."""
        return self.block_table_cpu
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class MultiGroupBlockTable:
    """One BlockTable per KV cache group; every call is fanned out to all."""

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 block_sizes: list[int]) -> None:
        """Create a BlockTable for each entry of `block_sizes`.

        Each table gets `cdiv(max_model_len, block_size)` columns (`cdiv`
        comes from `vllm.utils`; presumably ceiling division - confirm).
        """
        tables = []
        for size in block_sizes:
            blocks_per_req = cdiv(max_model_len, size)
            tables.append(
                BlockTable(max_num_reqs, blocks_per_req,
                           max_num_batched_tokens, pin_memory))
        self.block_tables = tables

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Append `block_ids[g]` to row `row_idx` of the g-th table."""
        for group, table in enumerate(self.block_tables):
            group_block_ids = block_ids[group]
            table.append_row(group_block_ids, row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Reset row `row_idx` of the g-th table to `block_ids[g]`."""
        for group, table in enumerate(self.block_tables):
            group_block_ids = block_ids[group]
            table.add_row(group_block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Apply `move_row(src, tgt)` to every table."""
        for table in self.block_tables:
            table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Apply `swap_row(src, tgt)` to every table."""
        for table in self.block_tables:
            table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        """Apply `commit(num_reqs)` to every table."""
        for table in self.block_tables:
            table.commit(num_reqs)

    def clear(self) -> None:
        """Apply `clear()` to every table."""
        for table in self.block_tables:
            table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        tables = self.block_tables
        return tables[idx]
|