tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/core/disagg_executor.py
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+from concurrent.futures import Future
+from multiprocessing import Lock
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.cache import worker_receiver_cache_from_config
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        run_method)
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.outputs import AsyncModelRunnerOutput
+from vllm.v1.worker.worker_base import WorkerWrapperBase
+
+logger = init_logger(__name__)
+
+
+class DisaggExecutor(Executor):
+
+    def _init_executor(self) -> None:
+        """Initialize the worker and load the model.
+        """
+        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
+                                               rpc_rank=0)
+        slice_config = getattr(self.vllm_config.device_config, "slice")
+        idx = slice_config[0]
+        jax_devices = slice_config[-1]
+        devices = []
+        if isinstance(idx, int):
+            sizes = slice_config[1]
+            start = sum(sizes[0:idx])
+            end = start + sizes[idx]
+
+            devices = jax_devices[start:end]
+            setattr(self.vllm_config.device_config, "slice",
+                    (idx + 1, sizes, jax_devices))
+            logger.debug(
+                f"Creating DisaggExecutor with {devices}, index: {start} -> {end}"
+            )
+        elif isinstance(idx, tuple):
+            slice_idx = slice_config[1]
+            sizes = slice_config[2][slice_idx]
+            start_row, start_col = idx
+            selected_devices = []
+            max_row, max_col = 0, 0
+            for device in jax_devices:
+                coords = device.coords
+                max_row = max(max_row, coords[0])
+                max_col = max(max_col, coords[1])
+                if coords[0] >= start_row and coords[0] < start_row + sizes[0]:
+                    if coords[1] >= start_col and coords[
+                            1] < start_col + sizes[1]:
+                        selected_devices.append(device)
+            max_row, max_col = max_row + 1, max_col + 1
+
+            devices = selected_devices
+            if start_col + sizes[1] >= max_col:
+                start_row += sizes[0]
+                start_col = 0
+            else:
+                start_col += sizes[1]
+
+            setattr(self.vllm_config.device_config, "slice",
+                    ((start_row, start_col), slice_idx + 1, slice_config[2],
+                     jax_devices))
+            logger.debug(
+                f"Creating DisaggExecutor with {devices}, next start: {((start_row, start_col), slice_idx + 1, slice_config[2])}"
+            )
+
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        local_rank = 0
+        rank = 0
+        is_driver_worker = True
+        kwargs = dict(
+            vllm_config=self.vllm_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=is_driver_worker,
+            devices=devices,
+        )
+        self.mm_receiver_cache = worker_receiver_cache_from_config(
+            self.vllm_config, MULTIMODAL_REGISTRY, Lock())
+        self.collective_rpc("init_worker", args=([kwargs], ))
+        self.collective_rpc("init_device")
+        self.collective_rpc("load_model")
+
+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None,
+                       non_block: bool = False) -> List[Any]:
+        if kwargs is None:
+            kwargs = {}
+
+        if not non_block:
+            return [run_method(self.driver_worker, method, args, kwargs)]
+
+        try:
+            result = run_method(self.driver_worker, method, args, kwargs)
+            if isinstance(result, AsyncModelRunnerOutput):
+                if (async_thread := self.async_output_thread) is not None:
+                    return [async_thread.submit(result.get_output)]
+                result = result.get_output()
+            future = Future[Any]()
+            future.set_result(result)
+        except Exception as e:
+            future = Future[Any]()
+            future.set_exception(e)
+        return [future]
+
+    def check_health(self) -> None:
+        # DisaggExecutor will always be healthy as long as
+        # it's running.
+        return
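For orientation, the integer branch of `_init_executor` above carves a flat device list into consecutive windows and advances the index so the next executor picks up where this one left off. Below is a minimal, self-contained sketch of that bookkeeping; `take_slice` and the `tpu-N` names are illustrative, not part of the package:

# Illustrative sketch of the 1-D slice bookkeeping in _init_executor.
# 'devices' stands in for jax.devices(); 'sizes' mirrors parsed PREFILL_SLICES.
def take_slice(idx, sizes, devices):
    start = sum(sizes[:idx])      # devices consumed by earlier slices
    end = start + sizes[idx]      # this executor's window
    return devices[start:end], (idx + 1, sizes, devices)

devices = [f"tpu-{i}" for i in range(8)]
state = (0, (4, 2, 2), devices)
mine, state = take_slice(*state)  # ['tpu-0', 'tpu-1', 'tpu-2', 'tpu-3']
print(mine)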
tpu_inference/core/disagg_utils.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from typing import Tuple, Union
+
+PREFILL_SLICES = 'PREFILL_SLICES'
+DECODE_SLICES = 'DECODE_SLICES'
+
+
+def is_disagg_enabled() -> bool:
+    # We trigger our code path as long as prefill slices are set. This
+    # allows us to test interleave mode effectively with the same code
+    # path for comparison purposes.
+    return PREFILL_SLICES in os.environ
+
+
+def _parse_slices(
+        slices_str: str) -> Tuple[Union[int, Tuple[int, int]], ...]:
+    """Parse a slices environment variable into a tuple of slice sizes.
+
+    Each comma-separated entry is either 'N' (parsed as the int N) or
+    'NxM' (parsed as the tuple (N, M)). For example, '2x2,2x1,2x4'
+    yields ((2, 2), (2, 1), (2, 4)).
+
+    Raises ValueError if the slice string is malformed.
+    """
+    if not slices_str:
+        return ()
+
+    try:
+        slice_sizes = []
+        for s in slices_str.split(','):
+            dims = s.split('x')
+            if len(dims) == 1:
+                slice_sizes.append(int(dims[0]))
+            elif len(dims) == 2:
+                slice_sizes.append((int(dims[0]), int(dims[1])))
+            else:
+                raise ValueError("Each slice must be in 'N' or 'NxM' format.")
+        return tuple(slice_sizes)
+    except ValueError as e:
+        raise ValueError(f"Malformed slice string: '{slices_str}'") from e
+
+
+def get_prefill_slices() -> Tuple[Union[int, Tuple[int, int]], ...]:
+    if PREFILL_SLICES not in os.environ:
+        return ()
+    return _parse_slices(os.environ[PREFILL_SLICES])
+
+
+def get_decode_slices() -> Tuple[Union[int, Tuple[int, int]], ...]:
+    if DECODE_SLICES not in os.environ:
+        return ()
+    return _parse_slices(os.environ[DECODE_SLICES])
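Taken together, these helpers make disaggregated serving purely environment-driven. A quick usage sketch, assuming the package is importable as `tpu_inference` (the printed values follow from `_parse_slices` as written):

import os

# Hypothetical values: two prefill slices (a 2x2 mesh and a flat 4)
# and one flat decode slice of 8 chips.
os.environ['PREFILL_SLICES'] = '2x2,4'
os.environ['DECODE_SLICES'] = '8'

from tpu_inference.core import disagg_utils

assert disagg_utils.is_disagg_enabled()
print(disagg_utils.get_prefill_slices())  # ((2, 2), 4)
print(disagg_utils.get_decode_slices())   # (8,)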
tpu_inference/di/__init__.py (file without changes)
tpu_inference/di/abstracts.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC
+
+
+class AbstractModelRunnerOutput(ABC):
+    """Abstract base class for model runner output."""
+    pass
+
+
+class AbstractSchedulerOutput(ABC):
+    """Abstract base class for scheduler output."""
+    pass
+
+
+class AbstractLoRARequest(ABC):
+    """Abstract base class for LoRA request."""
+    pass
+
+
+class AbstractKVCacheConfig(ABC):
+    """Abstract base class for KV cache config."""
+    pass
+
+
+class AbstractKVCacheSpec(ABC):
+    """Abstract base class for KV cache spec."""
+    pass
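These are pure marker types: they carry no behavior and exist so host-specific objects can flow through the DI layer behind a common nominal type. A hedged sketch of how a host adapter might use one; `VllmSchedulerOutput` is hypothetical, not a package class:

# Hypothetical adapter tagging a host-specific scheduler output
# with the marker base class above.
from tpu_inference.di.abstracts import AbstractSchedulerOutput

class VllmSchedulerOutput(AbstractSchedulerOutput):
    def __init__(self, scheduled_request_ids):
        self.scheduled_request_ids = scheduled_request_ids

out = VllmSchedulerOutput(["req-0", "req-1"])
assert isinstance(out, AbstractSchedulerOutput)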
@@ -0,0 +1,76 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+ from typing import Any, Callable, Dict, Type
17
+
18
+
19
+ class DIHost:
20
+ """
21
+ A simple dependency injection host.
22
+
23
+ This host manages a graph of functions, where each function is a provider
24
+ for a specific data type and declares its own dependencies.
25
+ """
26
+
27
+ def __init__(self):
28
+ self._providers: Dict[Type, Callable[..., Any]] = {}
29
+ self._dependencies: Dict[Callable[..., Any], Dict[str, Type]] = {}
30
+
31
+ def register(self,
32
+ provider: Callable[..., Any],
33
+ output_type: Type,
34
+ dependencies: Dict[str, Type] = None):
35
+ """
36
+ Registers a provider function with the host.
37
+
38
+ Args:
39
+ provider: The function that produces the output.
40
+ output_type: The data type that the function produces.
41
+ dependencies: A dictionary mapping argument names of the provider
42
+ to the data types they require.
43
+ """
44
+ self._providers[output_type] = provider
45
+ if dependencies:
46
+ self._dependencies[provider] = dependencies
47
+
48
+ def resolve(self, target_type: Type) -> Any:
49
+ """
50
+ Resolves a dependency by creating an instance of the target type.
51
+
52
+ This method will recursively resolve all dependencies required to call
53
+ the provider for the target type.
54
+
55
+ Args:
56
+ target_type: The data type to be resolved.
57
+
58
+ Returns:
59
+ An instance of the target type.
60
+ """
61
+ if target_type not in self._providers:
62
+ raise ValueError(
63
+ f"No provider registered for type {target_type.__name__}")
64
+
65
+ provider = self._providers[target_type]
66
+
67
+ if provider not in self._dependencies:
68
+ # Provider has no dependencies, so just call it.
69
+ return provider()
70
+
71
+ # Resolve dependencies for the provider.
72
+ kwargs = {}
73
+ for arg_name, dep_type in self._dependencies[provider].items():
74
+ kwargs[arg_name] = self.resolve(dep_type)
75
+
76
+ return provider(**kwargs)
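A short usage sketch of `DIHost`; the `Config` and `Engine` types here are hypothetical stand-ins, not package classes. Note that a provider with no registered dependencies is called with no arguments, so a class itself can serve as its own provider:

# Hypothetical wiring: Config has no dependencies, Engine depends on Config.
from tpu_inference.di.host import DIHost

class Config:
    def __init__(self):
        self.model = "llama3"

class Engine:
    def __init__(self, config: Config):
        self.config = config

host = DIHost()
host.register(Config, output_type=Config)
host.register(lambda config: Engine(config),
              output_type=Engine,
              dependencies={"config": Config})

engine = host.resolve(Engine)  # resolves Config first, then builds Engine
assert engine.config.model == "llama3"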
@@ -0,0 +1,51 @@
1
+ """
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+ import abc
17
+
18
+
19
+ class HostInterface(abc.ABC):
20
+ """
21
+ An interface that the host system (e.g., SGLang, vLLM) must implement.
22
+ This defines the contract for how the backend can call back into the host.
23
+ """
24
+
25
+ @abc.abstractmethod
26
+ def get_next_batch_to_run(self):
27
+ """
28
+ The backend calls this to get the next batch of requests to process.
29
+ """
30
+ pass
31
+
32
+ @abc.abstractmethod
33
+ def process_batch_result(self, batch_result):
34
+ """
35
+ The backend calls this to return the results of a processed batch.
36
+ """
37
+ pass
38
+
39
+
40
+ class BackendInterface(abc.ABC):
41
+ """
42
+ An interface that the backend system (e.g., tpu_inference) must implement.
43
+ This defines the contract for how the host can call into the backend.
44
+ """
45
+
46
+ @abc.abstractmethod
47
+ def launch_tpu_batch(self, batch_to_launch):
48
+ """
49
+ The host calls this to launch a batch of requests on the backend.
50
+ """
51
+ pass
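The two interfaces form a simple inversion-of-control loop: the backend pulls batches from the host and pushes results back. A toy round-trip, with `ToyHost` and `ToyBackend` as hypothetical implementations:

# Hypothetical host/backend pair wired through the interfaces above.
from tpu_inference.di.interfaces import BackendInterface, HostInterface

class ToyHost(HostInterface):
    def __init__(self, batches):
        self.batches = list(batches)
        self.results = []

    def get_next_batch_to_run(self):
        return self.batches.pop(0) if self.batches else None

    def process_batch_result(self, batch_result):
        self.results.append(batch_result)

class ToyBackend(BackendInterface):
    def __init__(self, host: HostInterface):
        self.host = host

    def launch_tpu_batch(self, batch_to_launch):
        # A real backend would run the batch on TPU; here we just echo it.
        self.host.process_batch_result(f"done:{batch_to_launch}")

host = ToyHost(["batch-0", "batch-1"])
backend = ToyBackend(host)
while (batch := host.get_next_batch_to_run()) is not None:
    backend.launch_tpu_batch(batch)
print(host.results)  # ['done:batch-0', 'done:batch-1']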
tpu_inference/distributed/__init__.py (file without changes)