tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/core/disagg_executor.py
ADDED
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
from concurrent.futures import Future
from multiprocessing import Lock
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        run_method)
from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class DisaggExecutor(Executor):

    def _init_executor(self) -> None:
        """Initialize the worker and load the model."""
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rpc_rank=0)
        slice_config = getattr(self.vllm_config.device_config, "slice")
        idx = slice_config[0]
        jax_devices = slice_config[-1]
        devices = []
        if isinstance(idx, int):
            sizes = slice_config[1]
            start = sum(sizes[0:idx])
            end = start + sizes[idx]

            devices = jax_devices[start:end]
            setattr(self.vllm_config.device_config, "slice",
                    (idx + 1, sizes, jax_devices))
            logger.debug(
                f"Creating DisaggExecutor with {devices}, index: {start} -> {end}"
            )
        elif isinstance(idx, tuple):
            slice_idx = slice_config[1]
            sizes = slice_config[2][slice_idx]
            start_row, start_col = idx
            selected_devices = []
            max_row, max_col = 0, 0
            for device in jax_devices:
                coords = device.coords
                max_row = max(max_row, coords[0])
                max_col = max(max_col, coords[1])
                if (start_row <= coords[0] < start_row + sizes[0]
                        and start_col <= coords[1] < start_col + sizes[1]):
                    selected_devices.append(device)
            max_row, max_col = max_row + 1, max_col + 1

            devices = selected_devices
            if start_col + sizes[1] >= max_col:
                start_row += sizes[0]
                start_col = 0
            else:
                start_col += sizes[1]

            setattr(self.vllm_config.device_config, "slice",
                    ((start_row, start_col), slice_idx + 1, slice_config[2],
                     jax_devices))
            logger.debug(
                f"Creating DisaggExecutor with {devices}, next start: "
                f"{((start_row, start_col), slice_idx + 1, slice_config[2])}")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        local_rank = 0
        rank = 0
        is_driver_worker = True
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
            devices=devices,
        )
        self.mm_receiver_cache = worker_receiver_cache_from_config(
            self.vllm_config, MULTIMODAL_REGISTRY, Lock())
        self.collective_rpc("init_worker", args=([kwargs], ))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None,
                       non_block: bool = False) -> List[Any]:
        if kwargs is None:
            kwargs = {}

        if not non_block:
            return [run_method(self.driver_worker, method, args, kwargs)]

        try:
            result = run_method(self.driver_worker, method, args, kwargs)
            if isinstance(result, AsyncModelRunnerOutput):
                if (async_thread := self.async_output_thread) is not None:
                    return [async_thread.submit(result.get_output)]
                result = result.get_output()
            future = Future[Any]()
            future.set_result(result)
        except Exception as e:
            future = Future[Any]()
            future.set_exception(e)
        return [future]

    def check_health(self) -> None:
        # DisaggExecutor will always be healthy as long as it's running.
        return
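For context, the integer branch above implements a rotating cursor over the flat device list: the k-th executor claims sizes[k] devices starting at sum(sizes[:k]), then writes the advanced index back into device_config.slice so the next DisaggExecutor claims the following window. A minimal standalone sketch of that bookkeeping (the helper name is hypothetical, and plain ints stand in for JAX devices):

def take_next_slice(slice_config):
    # Mirrors the (idx, sizes, devices) tuple consumed above.
    idx, sizes, devices = slice_config
    start = sum(sizes[:idx])      # devices consumed by earlier slices
    end = start + sizes[idx]      # this executor's window
    return devices[start:end], (idx + 1, sizes, devices)

devices, state = take_next_slice((0, (4, 2, 8), list(range(14))))
assert devices == [0, 1, 2, 3]
devices, state = take_next_slice(state)
assert devices == [4, 5]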
tpu_inference/core/disagg_utils.py
ADDED
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Tuple, Union

PREFILL_SLICES = 'PREFILL_SLICES'
DECODE_SLICES = 'DECODE_SLICES'


def is_disagg_enabled() -> bool:
    # We trigger our code path as long as prefill slices are set. This
    # allows us to test interleave mode effectively with the code path
    # for comparison purposes.
    return PREFILL_SLICES in os.environ


def _parse_slices(slices_str: str) -> Tuple[Union[int, Tuple[int, int]], ...]:
    """Parse a slices environment variable into a tuple of slice sizes.

    Each comma-separated entry is either `N`, parsed to the int N, or
    `NxM`, parsed to the tuple (N, M). For example, `4,2x1` parses to
    `(4, (2, 1))`.

    Raises ValueError if the slice string is malformed.
    """
    if not slices_str:
        return ()

    try:
        slice_sizes = []
        for s in slices_str.split(','):
            dims = s.split('x')
            if len(dims) == 1:
                slice_sizes.append(int(dims[0]))
            elif len(dims) == 2:
                slice_sizes.append((int(dims[0]), int(dims[1])))
            else:
                raise ValueError("Each slice must be in 'N' or 'NxM' format.")
        return tuple(slice_sizes)
    except ValueError as e:
        raise ValueError(f"Malformed slice string: '{slices_str}'") from e


def get_prefill_slices() -> Tuple[Union[int, Tuple[int, int]], ...]:
    if PREFILL_SLICES not in os.environ:
        return ()
    return _parse_slices(os.environ[PREFILL_SLICES])


def get_decode_slices() -> Tuple[Union[int, Tuple[int, int]], ...]:
    if DECODE_SLICES not in os.environ:
        return ()
    return _parse_slices(os.environ[DECODE_SLICES])
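A short usage sketch for the helpers above, assuming they are in scope (the environment values are illustrative):

import os

os.environ['PREFILL_SLICES'] = '4,2x1'   # one 4-chip slice, one 2x1 slice
os.environ['DECODE_SLICES'] = '8'

assert is_disagg_enabled()
assert get_prefill_slices() == (4, (2, 1))
assert get_decode_slices() == (8,)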
tpu_inference/di/abstracts.py
ADDED
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC


class AbstractModelRunnerOutput(ABC):
    """Abstract base class for model runner output."""
    pass


class AbstractSchedulerOutput(ABC):
    """Abstract base class for scheduler output."""
    pass


class AbstractLoRARequest(ABC):
    """Abstract base class for LoRA request."""
    pass


class AbstractKVCacheConfig(ABC):
    """Abstract base class for KV cache config."""
    pass


class AbstractKVCacheSpec(ABC):
    """Abstract base class for KV cache spec."""
    pass
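These are behavior-free marker classes; concrete host types can subclass them so backend code type-checks against the abstraction rather than a specific engine. A hypothetical adapter sketch (the wrapper type is illustrative, not from the package):

class MySchedulerOutput(AbstractSchedulerOutput):
    # Wraps a host engine's scheduler output behind the abstract marker.
    def __init__(self, scheduler_output):
        self.scheduler_output = scheduler_output

assert isinstance(MySchedulerOutput(object()), AbstractSchedulerOutput)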
tpu_inference/di/host.py
ADDED
@@ -0,0 +1,76 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Callable, Dict, Optional, Type


class DIHost:
    """
    A simple dependency injection host.

    This host manages a graph of functions, where each function is a provider
    for a specific data type and declares its own dependencies.
    """

    def __init__(self):
        self._providers: Dict[Type, Callable[..., Any]] = {}
        self._dependencies: Dict[Callable[..., Any], Dict[str, Type]] = {}

    def register(self,
                 provider: Callable[..., Any],
                 output_type: Type,
                 dependencies: Optional[Dict[str, Type]] = None):
        """
        Registers a provider function with the host.

        Args:
            provider: The function that produces the output.
            output_type: The data type that the function produces.
            dependencies: A dictionary mapping argument names of the provider
                to the data types they require.
        """
        self._providers[output_type] = provider
        if dependencies:
            self._dependencies[provider] = dependencies

    def resolve(self, target_type: Type) -> Any:
        """
        Resolves a dependency by creating an instance of the target type.

        This method will recursively resolve all dependencies required to call
        the provider for the target type.

        Args:
            target_type: The data type to be resolved.

        Returns:
            An instance of the target type.
        """
        if target_type not in self._providers:
            raise ValueError(
                f"No provider registered for type {target_type.__name__}")

        provider = self._providers[target_type]

        if provider not in self._dependencies:
            # Provider has no dependencies, so just call it.
            return provider()

        # Resolve dependencies for the provider.
        kwargs = {}
        for arg_name, dep_type in self._dependencies[provider].items():
            kwargs[arg_name] = self.resolve(dep_type)

        return provider(**kwargs)
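A usage sketch for DIHost (the Config and Runner types below are illustrative, not from the package): providers are registered per output type, and resolve() walks the declared dependency graph depth-first, constructing each argument before calling the provider.

class Config:
    block_size = 16

class Runner:
    def __init__(self, config: Config):
        self.config = config

host = DIHost()
host.register(Config, output_type=Config)            # no dependencies
host.register(lambda config: Runner(config),
              output_type=Runner,
              dependencies={'config': Config})

runner = host.resolve(Runner)   # resolves Config first, then builds Runner
assert runner.config.block_size == 16

Note that resolve() constructs a fresh instance on every call; there is no singleton caching in this host.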
tpu_inference/di/interfaces.py
ADDED
@@ -0,0 +1,51 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import abc


class HostInterface(abc.ABC):
    """
    An interface that the host system (e.g., SGLang, vLLM) must implement.
    This defines the contract for how the backend can call back into the host.
    """

    @abc.abstractmethod
    def get_next_batch_to_run(self):
        """
        The backend calls this to get the next batch of requests to process.
        """
        pass

    @abc.abstractmethod
    def process_batch_result(self, batch_result):
        """
        The backend calls this to return the results of a processed batch.
        """
        pass


class BackendInterface(abc.ABC):
    """
    An interface that the backend system (e.g., tpu_inference) must implement.
    This defines the contract for how the host can call into the backend.
    """

    @abc.abstractmethod
    def launch_tpu_batch(self, batch_to_launch):
        """
        The host calls this to launch a batch of requests on the backend.
        """
        pass
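To make the contract concrete, here is a minimal conforming pair (hypothetical, not part of the package): a host that feeds batches from a queue and a backend that simply echoes each batch back as its result.

class QueueHost(HostInterface):
    def __init__(self, batches):
        self.batches = list(batches)
        self.results = []

    def get_next_batch_to_run(self):
        # Hand the backend the oldest pending batch, or None when drained.
        return self.batches.pop(0) if self.batches else None

    def process_batch_result(self, batch_result):
        self.results.append(batch_result)


class EchoBackend(BackendInterface):
    def launch_tpu_batch(self, batch_to_launch):
        # A real backend would run the batch on TPU; here we just echo it.
        return batch_to_launch


host = QueueHost([['req-1', 'req-2']])
backend = EchoBackend()
while (batch := host.get_next_batch_to_run()) is not None:
    host.process_batch_result(backend.launch_tpu_batch(batch))
assert host.results == [['req-1', 'req-2']]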