tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from vllm.config import VllmConfig
|
|
5
|
+
from vllm.v1.engine.core import EngineCore as vLLMEngineCore
|
|
6
|
+
from vllm.v1.request import Request as VllmRequest
|
|
7
|
+
from vllm.v1.request import RequestStatus
|
|
8
|
+
|
|
9
|
+
from tpu_inference.interfaces.config import IConfig
|
|
10
|
+
from tpu_inference.interfaces.engine import IEngineCore, IScheduler
|
|
11
|
+
from tpu_inference.interfaces.request import IRequest
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class VllmConfigAdapter(IConfig):
    """Adapter exposing a concrete vLLM ``VllmConfig`` as an ``IConfig``.

    Holds a single reference to the wrapped config object; every property
    below is a plain delegation to that reference, so the adapter carries
    no state of its own.
    """

    def __init__(self, vllm_config: VllmConfig):
        # Stash the concrete config; all accessors forward to it.
        self._vllm_config = vllm_config

    @property
    def vllm_config(self) -> VllmConfig:
        """The underlying concrete ``VllmConfig`` (escape hatch for unwrapping)."""
        return self._vllm_config

    @property
    def scheduler_config(self) -> Any:
        """Scheduler section of the wrapped config."""
        return self._vllm_config.scheduler_config

    @property
    def cache_config(self) -> Any:
        """Cache section of the wrapped config."""
        return self._vllm_config.cache_config
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class VllmSchedulerAdapter(IScheduler):
    """Adapter exposing a concrete vLLM scheduler as an ``IScheduler``.

    Known interface members are delegated explicitly; anything else
    (``schedule``, ``has_requests``, ...) falls through via
    ``__getattr__`` to the wrapped scheduler unchanged.
    """

    def __init__(self, scheduler: Any):
        # The concrete scheduler every access is forwarded to.
        self._scheduler = scheduler

    @property
    def requests(self) -> dict[str, IRequest]:
        """Mapping of request id to in-flight request, as kept by the scheduler."""
        return self._scheduler.requests

    @property
    def kv_cache_manager(self) -> Any:
        """The scheduler's KV-cache manager, passed through unmodified."""
        return self._scheduler.kv_cache_manager

    @property
    def encoder_cache_manager(self) -> Any:
        """The scheduler's encoder-cache manager, passed through unmodified."""
        return self._scheduler.encoder_cache_manager

    def add_request(self, request: IRequest) -> None:
        """Enqueue *request* with the wrapped scheduler.

        The concrete scheduler expects a vllm ``Request``, so the
        ``IRequest`` wrapper is unwrapped before handing it over.
        """
        concrete_request = request.vllm_request
        self._scheduler.add_request(concrete_request)

    def __getattr__(self, name: str) -> Any:
        # Fallback delegation: only reached when normal attribute lookup
        # on the adapter fails, so explicit members above take precedence.
        return getattr(self._scheduler, name)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class VllmEngineAdapter(IEngineCore):
    """Adapter exposing a concrete vLLM ``EngineCore`` as an ``IEngineCore``.

    On construction the engine's concrete scheduler is immediately wrapped
    in a ``VllmSchedulerAdapter``, so consumers of this adapter only ever
    see interface types.
    """

    def __init__(self, engine_core: vLLMEngineCore):
        self._engine_core = engine_core
        # Adapt the engine's scheduler up-front so the `scheduler`
        # property can return it without re-wrapping on every access.
        self._scheduler = VllmSchedulerAdapter(engine_core.scheduler)

    @property
    def scheduler(self) -> IScheduler:
        """The engine's scheduler, already wrapped as an ``IScheduler``."""
        return self._scheduler

    @property
    def model_executor(self) -> Any:
        """The engine's model executor, passed through unmodified."""
        return self._engine_core.model_executor

    def execute_model_with_error_logging(self, *args, **kwargs) -> Any:
        """Forward a model-execution call verbatim to the wrapped engine."""
        return self._engine_core.execute_model_with_error_logging(*args,
                                                                  **kwargs)

    def shutdown(self) -> None:
        """Shut down the wrapped engine core."""
        self._engine_core.shutdown()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class VllmRequestAdapter(IRequest):
    """Adapter exposing a concrete vLLM ``Request`` as an ``IRequest``.

    Every member is a straight delegation to the wrapped request. Writable
    members (``num_computed_tokens``, ``status``) define their setter
    immediately after the getter for readability; both write straight
    through to the concrete request.
    """

    def __init__(self, vllm_request: VllmRequest):
        # The concrete request all accessors forward to.
        self._vllm_request = vllm_request

    @property
    def vllm_request(self) -> VllmRequest:
        """The underlying concrete request (escape hatch for unwrapping)."""
        return self._vllm_request

    @property
    def request_id(self) -> str:
        """Unique identifier of the wrapped request."""
        return self._vllm_request.request_id

    @property
    def num_computed_tokens(self) -> int:
        """Tokens already computed for this request (read/write)."""
        return self._vllm_request.num_computed_tokens

    @num_computed_tokens.setter
    def num_computed_tokens(self, value: int) -> None:
        self._vllm_request.num_computed_tokens = value

    @property
    def num_output_placeholders(self) -> int:
        """Output placeholder count, passed through unmodified."""
        return self._vllm_request.num_output_placeholders

    @property
    def num_tokens(self) -> int:
        """Total token count of the wrapped request."""
        return self._vllm_request.num_tokens

    @property
    def num_tokens_with_spec(self) -> int:
        """Token count including speculative tokens."""
        return self._vllm_request.num_tokens_with_spec

    @property
    def status(self) -> RequestStatus:
        """Current lifecycle status of the request (read/write)."""
        return self._vllm_request.status

    @status.setter
    def status(self, value: RequestStatus) -> None:
        self._vllm_request.status = value

    @property
    def prompt_token_ids(self):
        """Prompt token ids of the wrapped request."""
        return self._vllm_request.prompt_token_ids

    @property
    def all_token_ids(self):
        """All token ids (prompt plus output) of the wrapped request."""
        return self._vllm_request.all_token_ids

    @property
    def sampling_params(self):
        """Sampling parameters of the wrapped request."""
        return self._vllm_request.sampling_params

    @property
    def lora_request(self):
        """LoRA request attached to the wrapped request, if any."""
        return self._vllm_request.lora_request

    @property
    def block_hashes(self):
        """Block hashes of the wrapped request."""
        return self._vllm_request.block_hashes

    def is_finished(self) -> bool:
        """Whether the wrapped request has reached a terminal state."""
        return self._vllm_request.is_finished()

    def get_request_id(self) -> str:
        """Method-form accessor for the request id (same as ``request_id``)."""
        return self._vllm_request.request_id
|