tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/utils.py ADDED
@@ -0,0 +1,294 @@
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ from collections import defaultdict
+ from collections.abc import Sequence
+ from typing import Any, Callable, List, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from jax._src import dtypes
+ from jax._src import mesh as mesh_lib
+ from jax._src import xla_bridge as xb
+ from jax._src.lib import xla_client as xc
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from vllm import envs, utils
+
+ from tpu_inference.logger import init_logger
+
+ GBYTES = 1024 * 1024 * 1024
+ TPU_HEAD_SIZE_ALIGNMENT = 128
+ TPU_SECOND_LAST_MINOR = 8
+
+ # This is used to translate from a string name for a dtype
+ # to formal jax.numpy DType. One use case for this is
+ # converting the `--kv_cache_dtype` flag to a dtype.
+ TPU_STR_DTYPE_TO_JAX_DTYPE = {
+     "bfloat16": jnp.bfloat16,
+     "fp8": jnp.float8_e4m3fn,
+     "fp8_e4m3": jnp.float8_e4m3,
+     "fp8_e5m2": jnp.float8_e5m2,
+     "int8": jnp.int8,
+ }
+
+ _megacore = False
+ logger = init_logger(__name__)
+
+
+ def enable_megacore() -> None:
+     global _megacore
+     _megacore = True
+
+
+ def get_megacore() -> bool:
+     return _megacore
+
+
+ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
+     if tp_size <= num_kv_heads:
+         assert num_kv_heads % tp_size == 0
+         return num_kv_heads
+     else:
+         assert tp_size % num_kv_heads == 0
+         return tp_size
+
+
+ def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
+     usage = []
+     if envs.VLLM_TPU_USING_PATHWAYS:
+         return pathways_hbm_usage_gb(devices)
+
+     multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+     if multihost_backend == "ray":
+         # MemoryStats is only supported for addressable PjRt devices.
+         # Assume all the devices have similar memory usage for now.
+         # TODO(ranlihao): find a proper way to get the memory usage of each device.
+         for device in devices:
+             try:
+                 hbm_used = device.memory_stats()["bytes_in_use"]
+                 hbm_limit = device.memory_stats()["bytes_limit"]
+                 logger.info(
+                     "Get memory stats for device %s. Assuming all devices have the same usage.",
+                     device)
+                 usage.extend([(hbm_used, hbm_limit)] * len(devices))
+                 break
+             except Exception as e:
+                 logger.warning(
+                     "Failed to get memory stats for device %s: %s. ", device,
+                     e)
+     else:
+         for device in devices:
+             hbm_used = device.memory_stats()["bytes_in_use"]
+             hbm_limit = device.memory_stats()["bytes_limit"]
+             usage.append((hbm_used, hbm_limit))
+
+     return usage
+
+
+ def get_device_name(num_devices: int | None = None):
+     kind = jax.devices()[0].device_kind
+     if 'TPU' not in kind:
+         raise RuntimeError('Expected TPU devices')
+     suffix = ''
+     if kind.endswith(' lite'):
+         kind = kind[:-len(' lite')]
+         suffix = 'e'
+     elif kind.endswith('e'):
+         kind = kind[:-1]
+         suffix = 'e'
+     elif kind.endswith('p'):
+         kind = kind[:-1]
+         suffix = 'p'
+     elif kind == 'TPU7x':
+         kind = 'TPU v7'
+     assert kind[:-1] == 'TPU v', kind
+     kind += suffix
+     if num_devices is not None:
+         kind += f'-{num_devices}'
+     return kind
+
+
+ def get_device_hbm_limit() -> int:
+
+     device_kind = get_device_name()
+     if device_kind == "TPU v5p" or device_kind == "TPU v5":
+         return 95 * GBYTES
+     elif device_kind == "TPU v5e":
+         return 16 * GBYTES
+     elif device_kind == "TPU v6e" or device_kind == "TPU v4":
+         return 32 * GBYTES
+     elif device_kind == "TPU v7":
+         return 192 * GBYTES
+     else:
+         raise ValueError(f"Unknown device kind: {device_kind}")
+
+
+ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
+     live_arrays = jax.live_arrays()
+     hbm_used = defaultdict(int)
+     hbm_limit = get_device_hbm_limit()
+     for array in live_arrays:
+         assert hasattr(array, 'sharding') and hasattr(
+             array.sharding, 'device_set'
+         ), "This function must not be called within jax tracer (e.g. jit, vmap, grad)"
+         for device in array.sharding.device_set:
+             hbm_used[device] += array.dtype.itemsize * array.size // len(
+                 array.sharding.device_set)
+     return [(hbm_used[device], hbm_limit) for device in devices]
+
+
+ def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
+     usage = hbm_usage_bytes(devices)
+     usage = [(round(used / GBYTES, 2), round(limit / GBYTES, 2))
+              for used, limit in usage]
+     return usage
+
+
+ def get_padded_head_dim(head_dim: int) -> int:
+     """Pads head_dim up to the nearest multiple of 128 for kernel performance."""
+     return (head_dim + 127) // 128 * 128
+
+
+ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
+     if num_heads >= sharding_size:
+         assert num_heads % sharding_size == 0
+     else:
+         assert sharding_size % num_heads == 0
+         num_heads = sharding_size
+     return num_heads
+
+
+ def get_dtype_packing(dtype):
+     bits = dtypes.bit_width(dtype)
+     return 32 // bits
+
+
+ def make_optimized_mesh(axis_shapes: Sequence[int],
+                         axis_names: Sequence[str],
+                         *,
+                         devices: Sequence[xc.Device] | None = None):
+     if devices is None:
+         devices = xb.devices()
+     # Sort the devices in case it's passed in an arbitary order
+     devices = sorted(devices, key=lambda x: x.coords)
+
+     def _is_1D(axis_shapes):
+         return sum(x > 1 for x in axis_shapes) == 1
+
+     if _is_1D(axis_shapes):
+         dev_kind = devices[0].device_kind
+         device_num = len(devices)
+         if dev_kind == "TPU v6 lite":
+             ordered_devices = None
+             # NOTE(chengjiyao):
+             # The coords of v6e-8 are
+             # (0,0,0)
+             # (1,0,0)
+             # (0,1,0)
+             # (1,1,0)
+             # (0,2,0)
+             # (1,2,0)
+             # (0,3,0)
+             # (1,3,0)
+             if device_num == 8:
+                 ordered_devices = np.array([
+                     devices[0],
+                     devices[1],
+                     devices[2],
+                     devices[3],
+                     devices[7],
+                     devices[6],
+                     devices[5],
+                     devices[4],
+                 ])
+             # NOTE(chengjiyao):
+             # The coords of v6e-4 are
+             # (0,0,0)
+             # (1,0,0)
+             # (0,1,0)
+             # (1,1,0)
+             elif device_num == 4:
+                 ordered_devices = np.array([
+                     devices[0],
+                     devices[1],
+                     devices[3],
+                     devices[2],
+                 ])
+             if ordered_devices is not None:
+                 ordered_devices = np.array(ordered_devices)
+                 ordered_devices = ordered_devices.reshape(axis_shapes)
+                 mesh = mesh_lib.Mesh(ordered_devices, axis_names)
+                 logger.info("Use customized mesh: %s", mesh)
+                 return mesh
+
+     return jax.make_mesh(axis_shapes, axis_names, devices=devices)
+
+
+ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
+     """
+     Create a device array with the specified mesh and sharding.
+
+     Args:
+         mesh: The JAX mesh to use for device placement
+         *args: Positional arguments to pass to jax.device_put
+         sharding: Optional sharding specification. If None, uses PartitionSpec(None)
+         **kwargs: Keyword arguments to pass to jax.device_put
+
+     Returns:
+         A JAX array placed on the specified devices
+     """
+     if sharding is None:
+         sharding = NamedSharding(mesh, PartitionSpec(None))
+     return jax.device_put(*args, device=sharding, **kwargs)
+
+
+ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
+     """
+     A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+     """
+     if hash_fn_name == "builtin":
+         return hash
+     return utils.get_hash_fn_by_name(hash_fn_name)
+
+
+ def quantize_kv(key: jax.Array, value: jax.Array,
+                 kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
+                 v_scale: float) -> Tuple[jax.Array, jax.Array]:
+     """
+     Quantize the key and value tensors.
+
+     Args:
+         key: The key tensor to quantize.
+         value: The value tensor to quantize.
+         kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
+         q_scale: The scale to quantize the key and value tensors by.
+         k_scale: The scale to quantize the key tensor by.
+         v_scale: The scale to quantize the value tensor by.
+
+     Returns:
+         Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
+     """
+     dtype_info = jnp.finfo(kv_cache_quantized_dtype)
+     minval, maxval = float(dtype_info.min), float(dtype_info.max)
+     key = key.astype(jnp.float32) / k_scale
+     key = jnp.clip(key, minval, maxval)
+     key = key.astype(kv_cache_quantized_dtype)
+     value = value.astype(jnp.float32) / v_scale
+     value = jnp.clip(value, minval, maxval)
+     value = value.astype(kv_cache_quantized_dtype)
+
+     return key, value
+
+
+ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
+     """
+     Get the JAX dtype from a string dtype.
+
+     Args:
+         str_dtype: The string dtype to get the JAX dtype from.
+
+     Returns:
+         jnp.dtype: The JAX dtype.
+     """
+     str_dtype = str_dtype.lower().strip()
+     return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
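
For orientation, the sketch below exercises a few of the helpers added above (get_jax_dtype_from_str_dtype, get_padded_head_dim, quantize_kv). It is an illustrative example, not part of the package: it assumes an environment where tpu_inference and its JAX/vLLM dependencies are installed, a JAX build with float8 dtypes, and that the signatures match this diff.

    import jax.numpy as jnp

    from tpu_inference.utils import (get_jax_dtype_from_str_dtype,
                                     get_padded_head_dim, quantize_kv)

    # Map a CLI-style string (e.g. from --kv_cache_dtype) to a JAX dtype.
    kv_dtype = get_jax_dtype_from_str_dtype("fp8")  # -> jnp.float8_e4m3fn

    # Head dims are padded up to a multiple of 128 for the TPU kernels.
    assert get_padded_head_dim(96) == 128

    # Quantize bfloat16 K/V tensors into the fp8 KV-cache dtype.
    key = jnp.ones((4, 8, 128), dtype=jnp.bfloat16)
    value = jnp.ones((4, 8, 128), dtype=jnp.bfloat16)
    q_key, q_value = quantize_kv(key, value, kv_dtype, k_scale=1.0, v_scale=1.0)
    assert q_key.dtype == kv_dtype and q_value.dtype == kv_dtype
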
tpu_inference/worker/_temporary_vllm_compat.py ADDED
@@ -0,0 +1,129 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ #
+ # WARNING: This is a temporary compatibility module.
+ #
+ #
+ # THE PROBLEM:
+ # The ideal dependency injection pattern dictates that the "producer" of data
+ # (in this case, the vLLM engine) should be responsible for adapting its data
+ # into the abstract format that the "consumer" (the TPU worker) expects.
+ #
+ # However, this would require a simultaneous code change in both the `vllm` and
+ # `tpu_inference` repositories. Such cross-repository changes are difficult to
+ # coordinate, slow to land, and can easily cause breakages if the releases
+ # are not perfectly synchronized.
+ #
+ #
+ # THE TEMPORARY SOLUTION:
+ # To enable independent development and deployment, we are temporarily violating
+ # this pattern. We are making the consumer (`tpu_inference`) responsible for
+ # detecting and adapting the producer's raw data.
+ #
+ # This function checks if it has received a raw `vllm.SchedulerOutput` and,
+ # if so, wraps it in the appropriate adapter. This allows `vllm` to continue
+ # sending its raw data type without modification, decoupling the release cycles.
+ #
+ #
+ # THE FUTURE (HOW TO REMOVE THIS):
+ # This entire file should be deleted once the `vllm` repository has been updated.
+ # The required change in `vllm` is small and looks like this:
+ #
+ # --- SKELETON CODE FOR FUTURE vLLM CHANGE ---
+ # In the vLLM engine, where `execute_model` is called:
+ #
+ # from tpu_inference.adapters.vllm_adapters import VllmSchedulerOutputAdapter
+ # from vllm.v1.core.sched.output import SchedulerOutput
+ #
+ # # ... inside some method ...
+ #
+ # # OLD CODE:
+ # # concrete_work = SchedulerOutput(...)
+ # # self.tpu_backend.execute_model(concrete_work)
+ #
+ # # NEW CODE:
+ # concrete_work = SchedulerOutput(...)
+ # adapted_work = VllmSchedulerOutputAdapter(concrete_work) # This line is added
+ # self.tpu_backend.execute_model(adapted_work) # Pass the adapter
+ #
+ # --- END SKELETON CODE ---
+ #
+
+ import logging
+ from typing import Union
+
+ from vllm.lora.request import LoRARequest as VllmLoRARequest
+ # Import the concrete vLLM type for the check
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+ from vllm.v1.kv_cache_interface import KVCacheConfig as VllmKVCacheConfig
+
+ from tpu_inference.adapters.vllm_adapters import (VllmKVCacheConfigAdapter,
+                                                   VllmLoRARequestAdapter,
+                                                   VllmSchedulerOutputAdapter)
+ from tpu_inference.di.abstracts import (AbstractKVCacheConfig,
+                                         AbstractLoRARequest,
+                                         AbstractSchedulerOutput)
+
+ logger = logging.getLogger(__name__)
+
+
+ def adapt_scheduler_output_if_needed(
+     scheduler_output: Union[AbstractSchedulerOutput, VllmSchedulerOutput]
+ ) -> AbstractSchedulerOutput:
+     """
+     Checks if the input is a raw VllmSchedulerOutput and wraps it.
+     If it's already an AbstractSchedulerOutput, it's passed through.
+     """
+     if isinstance(scheduler_output, VllmSchedulerOutput):
+         # logger.warning(
+         #     "Received raw VllmSchedulerOutput. Performing temporary, on-the-fly "
+         #     "adaptation. This is a compatibility feature and should be removed "
+         #     "once the vLLM engine is updated to provide an adapted object.")
+         return VllmSchedulerOutputAdapter(scheduler_output)
+
+     if isinstance(scheduler_output, AbstractSchedulerOutput):
+         return scheduler_output
+
+     raise TypeError(
+         f"Unsupported type for scheduler_output: {type(scheduler_output)}")
+
+
+ def adapt_kv_cache_config_if_needed(
+     kv_cache_config: Union[AbstractKVCacheConfig, VllmKVCacheConfig]
+ ) -> AbstractKVCacheConfig:
+     """
+     Checks if the input is a raw VllmKVCacheConfig and wraps it.
+     If it's already an AbstractKVCacheConfig, it's passed through.
+     """
+     if isinstance(kv_cache_config, VllmKVCacheConfig):
+         # logger.warning(
+         #     "Received raw VllmKVCacheConfig. Performing temporary, on-the-fly "
+         #     "adaptation. This is a compatibility feature and should be removed "
+         #     "once the vLLM engine is updated to provide an adapted object.")
+         return VllmKVCacheConfigAdapter(kv_cache_config)
+
+     if isinstance(kv_cache_config, AbstractKVCacheConfig):
+         return kv_cache_config
+
+     raise TypeError(
+         f"Unsupported type for kv_cache_config: {type(kv_cache_config)}")
+
+
+ def adapt_lora_request_if_needed(
+     lora_request: Union[AbstractLoRARequest, VllmLoRARequest]
+ ) -> AbstractLoRARequest:
+     """
+     Checks if the input is a raw VllmLoRARequest and wraps it.
+     If it's already an AbstractLoRARequest, it's passed through.
+     """
+     if isinstance(lora_request, VllmLoRARequest):
+         # logger.warning(
+         #     "Received raw VllmLoRARequest. Performing temporary, on-the-fly "
+         #     "adaptation. This is a compatibility feature and should be removed "
+         #     "once the vLLM engine is updated to provide an adapted object.")
+         return VllmLoRARequestAdapter(lora_request)
+
+     if isinstance(lora_request, AbstractLoRARequest):
+         return lora_request
+
+     raise TypeError(f"Unsupported type for lora_request: {type(lora_request)}")
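
The comment block at the top of this file already sketches the eventual vLLM-side change. On the tpu_inference side, the sketch below shows how a consumer might normalize incoming objects with these helpers before using them. The ExampleConsumer class and its method bodies are hypothetical, and the import path assumes this diff corresponds to the tpu_inference/worker/_temporary_vllm_compat.py entry in the file list above.

    from tpu_inference.worker._temporary_vllm_compat import (
        adapt_kv_cache_config_if_needed, adapt_lora_request_if_needed,
        adapt_scheduler_output_if_needed)


    class ExampleConsumer:
        """Hypothetical consumer that normalizes raw vLLM objects on entry."""

        def execute_model(self, scheduler_output):
            # Works whether vLLM sends its raw SchedulerOutput or an object
            # that is already an AbstractSchedulerOutput.
            work = adapt_scheduler_output_if_needed(scheduler_output)
            return work  # a real worker would hand this to the model runner

        def initialize_from_config(self, kv_cache_config):
            config = adapt_kv_cache_config_if_needed(kv_cache_config)
            return config  # a real worker would allocate the KV cache from this

        def add_lora(self, lora_request):
            request = adapt_lora_request_if_needed(lora_request)
            return request is not None
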
tpu_inference/worker/base.py ADDED
@@ -0,0 +1,100 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from abc import ABC, abstractmethod
+ from typing import Optional, Union
+
+ import torch.nn as nn
+ from vllm.lora.request import LoRARequest
+ from vllm.v1.core.sched.output import SchedulerOutput
+ from vllm.v1.kv_cache_interface import KVCacheConfig
+ from vllm.v1.outputs import ModelRunnerOutput
+
+ from tpu_inference.di.abstracts import (AbstractKVCacheConfig,
+                                         AbstractKVCacheSpec,
+                                         AbstractLoRARequest,
+                                         AbstractSchedulerOutput)
+ from tpu_inference.di.interfaces import HostInterface
+
+
+ class AbstractTpuWorker(ABC):
+     """Base class for TPU workers.
+
+     This class defines a pure, host-agnostic contract for what a TPU worker
+     must be able to do. It is intentionally decoupled from any specific host
+     system like vLLM or SGLang.
+
+     Architectural Note on Dependencies:
+     This abstract class only depends on other abstractions (e.g., HostInterface).
+     It does NOT hold configuration objects from any specific host (e.g.,
+     VllmConfig). Doing so would create a "leaky abstraction," forcing all
+     future implementations to depend on a concrete detail from a single host.
+
+     The responsibility for managing concrete configuration is pushed down to the
+     concrete subclasses (e.g., TPUWorkerJax), which keeps this base class
+     pure and truly reusable across different host systems.
+     """
+
+     def __init__(self, host_interface: Optional[HostInterface] = None):
+         self.host_interface = host_interface
+
+     @abstractmethod
+     def initialize_cache(self, num_gpu_blocks: int,
+                          num_cpu_blocks: int) -> None:
+         """Initialize the cache with the given number of blocks."""
+         pass
+
+     @abstractmethod
+     def init_device(self):
+         """Initialize the TPU device and distributed environment."""
+         pass
+
+     @abstractmethod
+     def determine_available_memory(self) -> int:
+         """Determine available memory for the TPU worker."""
+         pass
+
+     @abstractmethod
+     def execute_model(
+         self,
+         scheduler_output: Union[AbstractSchedulerOutput, SchedulerOutput],
+     ) -> Optional[ModelRunnerOutput]:
+         pass
+
+     @abstractmethod
+     def profile(self, is_start: bool = True):
+         pass
+
+     @abstractmethod
+     def add_lora(
+         self,
+         lora_request: Union[AbstractLoRARequest, LoRARequest],
+     ) -> bool:
+         pass
+
+     @abstractmethod
+     def load_model(self) -> None:
+         pass
+
+     @abstractmethod
+     def compile_or_warm_up_model(self) -> None:
+         pass
+
+     @abstractmethod
+     def get_model(self) -> nn.Module:
+         pass
+
+     @abstractmethod
+     def get_kv_cache_spec(self) -> dict[str, AbstractKVCacheSpec]:
+         pass
+
+     @abstractmethod
+     def initialize_from_config(
+         self,
+         kv_cache_config: Union[AbstractKVCacheConfig, KVCacheConfig],
+     ) -> None:
+         """Allocate KV cache with the specified kv_cache_config."""
+         pass
+
+     def check_health(self) -> None:
+         # worker will always be healthy as long as it's running.
+         return
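
To make the contract concrete, the stub below satisfies AbstractTpuWorker with do-nothing bodies. It is illustrative only (NullTpuWorker is not part of the package), and the import path assumes this diff corresponds to the tpu_inference/worker/base.py entry in the file list above.

    from typing import Optional

    import torch.nn as nn
    from vllm.v1.outputs import ModelRunnerOutput

    from tpu_inference.di.abstracts import AbstractKVCacheSpec
    from tpu_inference.worker.base import AbstractTpuWorker


    class NullTpuWorker(AbstractTpuWorker):
        """Does nothing; only shows which methods a host must implement."""

        def initialize_cache(self, num_gpu_blocks: int,
                             num_cpu_blocks: int) -> None:
            pass

        def init_device(self):
            pass

        def determine_available_memory(self) -> int:
            return 0

        def execute_model(self, scheduler_output) -> Optional[ModelRunnerOutput]:
            return None

        def profile(self, is_start: bool = True):
            pass

        def add_lora(self, lora_request) -> bool:
            return False

        def load_model(self) -> None:
            pass

        def compile_or_warm_up_model(self) -> None:
            pass

        def get_model(self) -> nn.Module:
            return nn.Identity()

        def get_kv_cache_spec(self) -> dict[str, AbstractKVCacheSpec]:
            return {}

        def initialize_from_config(self, kv_cache_config) -> None:
            pass


    worker = NullTpuWorker()  # instantiable because every abstract method is defined
    worker.check_health()     # inherited no-op from the base class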