tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Defines the abstract contracts for cache managers.
|
|
3
|
+
"""
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class IKVCacheManager(Protocol):
    """Structural type standing in for a KV-cache manager.

    No members are declared yet, so this protocol currently places no
    requirements on the objects that are annotated with it.
    """

    # TODO: declare the members of
    # vllm.v1.core.kv_cache_manager.KVCacheManager that tpu_inference
    # actually uses.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class IEncoderCacheManager(Protocol):
    """Structural type standing in for an encoder-cache manager.

    No members are declared yet, so this protocol currently places no
    requirements on the objects that are annotated with it.
    """

    # TODO: declare the members of
    # vllm.v1.core.encoder_cache_manager.EncoderCacheManager that
    # tpu_inference actually uses.
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class IMirroredProcessingCache(Protocol):
    """Structural type standing in for a mirrored processing cache.

    No members are declared yet, so this protocol currently places no
    requirements on the objects that are annotated with it.
    """

    # TODO: declare the members of
    # vllm.v1.engine.mm_input_cache.MirroredProcessingCache that
    # tpu_inference actually uses.
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Defines the abstract contract for a configuration object.
|
|
3
|
+
"""
|
|
4
|
+
from typing import Any, Optional, Protocol
|
|
5
|
+
|
|
6
|
+
from .config_parts import (ICacheConfig, ICompilationConfig, IModelConfig,
|
|
7
|
+
IParallelConfig, ISchedulerConfig,
|
|
8
|
+
ISpeculativeConfig)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class IConfig(Protocol):
    """Minimal structural interface for a configuration object.

    Only the properties that tpu_inference needs are declared. Client
    libraries (such as vLLM) are expected to supply concrete objects that
    satisfy this contract. Every member is a read-only property; no setters
    are declared at this level.
    """

    # --- Sub-configurations that are always present -------------------

    @property
    def cache_config(self) -> ICacheConfig: ...

    @property
    def compilation_config(self) -> ICompilationConfig: ...

    @property
    def parallel_config(self) -> IParallelConfig: ...

    @property
    def scheduler_config(self) -> ISchedulerConfig: ...

    # --- Sub-configurations whose declared type allows None -----------

    @property
    def model_config(self) -> Optional[IModelConfig]: ...

    @property
    def speculative_config(self) -> Optional[ISpeculativeConfig]: ...

    # --- Escape hatch --------------------------------------------------

    # Gives the adapter direct access to the underlying object when the
    # typed properties above are not enough; deliberately typed as Any.
    @property
    def vllm_config(self) -> Any: ...
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Defines the abstract contracts for the component parts of an IConfig.
|
|
3
|
+
"""
|
|
4
|
+
from typing import Any, Optional, Protocol
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class IModelConfig(Protocol):
    """Structural interface for the model section of a configuration."""

    # Read/write: both a getter and a setter are declared.
    @property
    def dtype(self) -> torch.dtype: ...

    @dtype.setter
    def dtype(self, value: torch.dtype) -> None: ...

    # Read-only: no setter is declared.
    @property
    def use_mla(self) -> bool: ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ICacheConfig(Protocol):
    """Structural interface for the cache section of a configuration."""

    # Read/write; the declared type allows None as well as an int.
    @property
    def block_size(self) -> Optional[int]: ...

    @block_size.setter
    def block_size(self, value: Optional[int]) -> None: ...
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class IParallelConfig(Protocol):
    """Structural interface for the parallel section of a configuration."""

    # Read/write string property.
    @property
    def worker_cls(self) -> str: ...

    @worker_cls.setter
    def worker_cls(self, value: str) -> None: ...
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ISchedulerConfig(Protocol):
    """Structural interface for the scheduler section of a configuration.

    Properties declared with a matching ``.setter`` are read/write; the
    remaining ones are read-only as far as this contract is concerned.
    """

    @property
    def max_num_seqs(self) -> int: ...

    @property
    def is_multi_step(self) -> bool: ...

    @property
    def is_multimodal_model(self) -> bool: ...

    @property
    def disable_chunked_mm_input(self) -> bool: ...

    @disable_chunked_mm_input.setter
    def disable_chunked_mm_input(self, value: bool) -> None: ...

    @property
    def enable_chunked_prefill(self) -> bool: ...

    @enable_chunked_prefill.setter
    def enable_chunked_prefill(self, value: bool) -> None: ...

    @property
    def chunked_prefill_enabled(self) -> bool: ...

    @chunked_prefill_enabled.setter
    def chunked_prefill_enabled(self, value: bool) -> None: ...

    @property
    def max_model_len(self) -> int: ...

    @property
    def max_num_batched_tokens(self) -> int: ...

    @max_num_batched_tokens.setter
    def max_num_batched_tokens(self, value: int) -> None: ...
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class ICompilationConfig(Protocol):
    """Structural interface for the compilation section of a configuration.

    Both properties are read/write.
    """

    # Typed as Any: this contract does not constrain the level's type.
    @property
    def level(self) -> Any: ...

    @level.setter
    def level(self, value: Any) -> None: ...

    @property
    def backend(self) -> str: ...

    @backend.setter
    def backend(self, value: str) -> None: ...
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class ISpeculativeConfig(Protocol):
    """Structural interface for the speculative section of a configuration.

    No members are declared, so this protocol currently places no
    requirements on the objects that are annotated with it.
    """
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines the engine interface contracts required by tpu_inference.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Protocol
|
|
6
|
+
|
|
7
|
+
# tpu_inference now depends on its own, locally defined interfaces.
|
|
8
|
+
from .cache import IMirroredProcessingCache
|
|
9
|
+
from .outputs import IStructuredOutputManager
|
|
10
|
+
from .scheduler import IScheduler
|
|
11
|
+
|
|
12
|
+
# This block is only processed by type checkers, not at runtime.
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .outputs import IModelRunnerOutput
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class IEngineProc(Protocol):
    """High-level interface for any process that a client can launch.

    It declares a single entry point for starting the process's main loop.
    """

    def run_busy_loop(self) -> None:
        ...


class IDisaggEngineCoreProc(IEngineProc, Protocol):
    """Interface for the disaggregated engine process.

    Inherits the IEngineProc contract and adds no members of its own.

    ``Protocol`` is listed explicitly as a base: a class that only
    subclasses an existing protocol is an ordinary nominal class, and it is
    itself a protocol only when ``Protocol`` appears directly in its bases.
    Without it, this interface could be instantiated and would not be
    matched structurally by type checkers.
    """
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class IEngineCore(Protocol):
    """Contract for an Engine Core building block.

    Mirrors the part of a vLLM Engine Core's public API that the
    DisaggEngineCoreProc uses: four attributes plus two methods.
    """

    # Attributes an implementation must expose. Three are typed with the
    # locally defined interfaces; the executor is left as Any.
    scheduler: IScheduler
    mm_input_cache_server: IMirroredProcessingCache
    structured_output_manager: IStructuredOutputManager
    model_executor: Any

    # The return annotation is a string because IModelRunnerOutput is
    # imported only under TYPE_CHECKING in this module.
    def execute_model_with_error_logging(
        self,
        *args,
        **kwargs,
    ) -> "IModelRunnerOutput": ...

    def shutdown(self) -> None: ...
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Defines the abstract contracts for model and structured outputs.
|
|
3
|
+
"""
|
|
4
|
+
from typing import Protocol
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class IModelRunnerOutput(Protocol):
    """Structural type standing in for a model runner's output.

    No members are declared yet, so this protocol currently places no
    requirements on the objects that are annotated with it.
    """

    # TODO: declare the members of vllm.v1.outputs.ModelRunnerOutput that
    # tpu_inference actually uses.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class IStructuredOutputManager(Protocol):
    """Structural type standing in for a structured-output manager.

    No members are declared yet, so this protocol currently places no
    requirements on the objects that are annotated with it.
    """

    # TODO: declare the members of
    # vllm.v1.structured_output.StructuredOutputManager that tpu_inference
    # actually uses.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Defines the abstract contracts for sampling and pooling parameters.
|
|
3
|
+
"""
|
|
4
|
+
from typing import Any, Protocol
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class IPoolingParams(Protocol):
    """Structural type standing in for PoolingParams.

    No members are declared, so this protocol currently places no
    requirements on the objects that are annotated with it.
    """
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ISamplingParams(Protocol):
    """Structural type standing in for SamplingParams."""

    # Read-only; typed as Any, so this contract does not constrain the
    # concrete sampling-type value.
    @property
    def sampling_type(self) -> Any: ...
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Defines the abstract contract for a hardware platform.
|
|
3
|
+
"""
|
|
4
|
+
from typing import Any, Optional, Protocol, Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from .config import IConfig
|
|
9
|
+
from .params import IPoolingParams, ISamplingParams
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class IPlatform(Protocol):
    """Minimal structural interface for a hardware platform.

    Every member is a method stub; behaviour is supplied entirely by the
    implementing platform class.
    """

    def can_update_inplace(self) -> bool: ...

    def check_and_update_config(self, vllm_config: IConfig) -> None: ...

    def get_attn_backend_cls(
        self,
        selected_backend: Any,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        use_v1: bool,
        use_mla: bool,
        has_sink: bool,
        # NOTE(review): `use_spare` may be a typo for `use_sparse`; confirm
        # against the implementing platform before renaming.
        use_spare: bool,
    ) -> str: ...

    def get_device_communicator_cls(self) -> str: ...

    def get_device_name(self, device_id: int = 0) -> str: ...

    def get_device_total_memory(self, device_id: int = 0) -> int: ...

    def get_infinity_values(
        self,
        dtype: torch.dtype,
    ) -> tuple[float, float]: ...

    def get_lora_vocab_padding_size(self) -> int: ...

    def get_punica_wrapper(self) -> str: ...

    def inference_mode(self) -> Any: ...

    def is_async_output_supported(
        self,
        enforce_eager: Optional[bool],
    ) -> bool: ...

    def is_kv_cache_dtype_supported(self, kv_cache_dtype: str) -> bool: ...

    def is_pin_memory_available(self) -> bool: ...

    def set_device(self, device: torch.device) -> None: ...

    def supports_v1(self, model_config: Any) -> bool: ...

    def use_all_gather(self) -> bool: ...

    def validate_request(
        self,
        prompt: Any,
        params: Union[ISamplingParams, IPoolingParams],
        processed_inputs: Any,
    ) -> None: ...
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Defines the abstract contract for a Request.
|
|
3
|
+
"""
|
|
4
|
+
from typing import Any, Protocol
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class IRequest(Protocol):
    """Minimal structural interface for a request.

    Only the members that tpu_inference needs are declared. Client
    libraries (such as vLLM) are expected to supply concrete objects that
    satisfy this contract.
    """

    # Read-only and typed as Any, so the wrapped object is unconstrained.
    @property
    def vllm_request(self) -> Any: ...

    def is_finished(self) -> bool: ...

    def get_request_id(self) -> str: ...

    # TODO: add `mm_hashes`; it is used by
    # `if request.mm_hashes is not None:`.
    #
    # TODO: add the other members of vllm.v1.request.Request that the
    # orchestration logic actually uses, for example read-only properties
    # `prompt` (str) and `prompt_token_ids` (list[int]).
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module defines the scheduler interface contract required by tpu_inference.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Dict, Protocol
|
|
6
|
+
|
|
7
|
+
# tpu_inference now depends on its own, locally defined interfaces.
|
|
8
|
+
from .cache import IEncoderCacheManager, IKVCacheManager
|
|
9
|
+
from .request import IRequest
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class IScheduler(Protocol):
    """Extended scheduler interface for advanced orchestration engines.

    The contract is defined by tpu_inference; any client library (such as
    vLLM) that wants to use this orchestrator must implement it. All three
    members are read-only properties.
    """

    # Mapping from a string key to the request object.
    @property
    def requests(self) -> Dict[str, IRequest]: ...

    @property
    def kv_cache_manager(self) -> IKVCacheManager: ...

    @property
    def encoder_cache_manager(self) -> IEncoderCacheManager: ...
|
|
File without changes
|
|
File without changes
|