tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/utils.py ADDED
@@ -0,0 +1,317 @@
# SPDX-License-Identifier: Apache-2.0
import os
import time
from collections import defaultdict
from collections.abc import Sequence
from functools import wraps
from typing import Any, Callable, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from vllm import envs, utils

from tpu_inference.logger import init_logger

GBYTES = 1024 * 1024 * 1024
TPU_HEAD_SIZE_ALIGNMENT = 128
TPU_SECOND_LAST_MINOR = 8

# This is used to translate from a string name for a dtype
# to a formal jax.numpy dtype. One use case for this is
# converting the `--kv_cache_dtype` flag to a dtype.
TPU_STR_DTYPE_TO_JAX_DTYPE = {
    "bfloat16": jnp.bfloat16,
    "fp8": jnp.float8_e4m3fn,
    "fp8_e4m3": jnp.float8_e4m3,
    "fp8_e5m2": jnp.float8_e5m2,
    "int8": jnp.int8,
}

_megacore = False
logger = init_logger(__name__)


def enable_megacore() -> None:
    global _megacore
    _megacore = True


def get_megacore() -> bool:
    return _megacore


def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
    if tp_size <= num_kv_heads:
        assert num_kv_heads % tp_size == 0
        return num_kv_heads
    else:
        assert tp_size % num_kv_heads == 0
        return tp_size


def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
    usage = []
    if envs.VLLM_TPU_USING_PATHWAYS:
        return pathways_hbm_usage_gb(devices)

    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
    if multihost_backend == "ray":
        # MemoryStats is only supported for addressable PjRt devices.
        # Assume all the devices have similar memory usage for now.
        # TODO(ranlihao): find a proper way to get the memory usage of each device.
        for device in devices:
            try:
                hbm_used = device.memory_stats()["bytes_in_use"]
                hbm_limit = device.memory_stats()["bytes_limit"]
                logger.info(
                    "Got memory stats for device %s. Assuming all devices have the same usage.",
                    device)
                usage.extend([(hbm_used, hbm_limit)] * len(devices))
                break
            except Exception as e:
                logger.warning(
                    "Failed to get memory stats for device %s: %s.", device,
                    e)
    else:
        for device in devices:
            hbm_used = device.memory_stats()["bytes_in_use"]
            hbm_limit = device.memory_stats()["bytes_limit"]
            usage.append((hbm_used, hbm_limit))

    return usage


def get_device_name(num_devices: int | None = None):
    kind = jax.devices()[0].device_kind
    if 'TPU' not in kind:
        raise RuntimeError('Expected TPU devices')
    suffix = ''
    if kind.endswith(' lite'):
        kind = kind[:-len(' lite')]
        suffix = 'e'
    elif kind.endswith('e'):
        kind = kind[:-1]
        suffix = 'e'
    elif kind.endswith('p'):
        kind = kind[:-1]
        suffix = 'p'
    elif kind == 'TPU7x':
        kind = 'TPU v7'
    assert kind[:-1] == 'TPU v', kind
    kind += suffix
    if num_devices is not None:
        kind += f'-{num_devices}'
    return kind


def get_device_hbm_limit() -> int:
    device_kind = get_device_name()
    if device_kind == "TPU v5p" or device_kind == "TPU v5":
        return 95 * GBYTES
    elif device_kind == "TPU v5e":
        return 16 * GBYTES
    elif device_kind == "TPU v6e" or device_kind == "TPU v4":
        return 32 * GBYTES
    elif device_kind == "TPU v7":
        # 192 * GBYTES / 2 because each JAX device (v7x core) has
        # 1/2 of the total chip HBM.
        return 96 * GBYTES
    else:
        raise ValueError(f"Unknown device kind: {device_kind}")


def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
    live_arrays = jax.live_arrays()
    hbm_used = defaultdict(int)
    hbm_limit = get_device_hbm_limit()
    for array in live_arrays:
        for buffer in array.addressable_shards:
            hbm_used[buffer.data.device] += buffer.data.nbytes
    return [(hbm_used[device], hbm_limit) for device in devices]


def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
    usage = hbm_usage_bytes(devices)
    usage = [(round(used / GBYTES, 2), round(limit / GBYTES, 2))
             for used, limit in usage]
    return usage


def get_padded_head_dim(head_dim: int) -> int:
    """Pads head_dim up to the nearest multiple of 128 for kernel performance."""
    # When head_dim == 64, we use a kernel specifically optimized for it,
    # which does not require any padding.
    if head_dim == 64:
        return 64
    return (head_dim + 127) // 128 * 128


def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
    if num_heads >= sharding_size:
        assert num_heads % sharding_size == 0
    else:
        assert sharding_size % num_heads == 0
        num_heads = sharding_size
    return num_heads


def get_dtype_packing(dtype):
    bits = dtypes.bit_width(dtype)
    return 32 // bits


def make_optimized_mesh(axis_shapes: Sequence[int],
                        axis_names: Sequence[str],
                        *,
                        devices: Sequence[xc.Device] | None = None):
    if devices is None:
        devices = xb.devices()
    # Sort the devices in case they are passed in an arbitrary order.
    devices = sorted(devices, key=lambda x: x.coords)

    def _is_1D(axis_shapes):
        return sum(x > 1 for x in axis_shapes) == 1

    if _is_1D(axis_shapes):
        dev_kind = devices[0].device_kind
        device_num = len(devices)
        if dev_kind == "TPU v6 lite":
            ordered_devices = None
            # NOTE(chengjiyao):
            # The coords of v6e-8 are
            # (0,0,0)
            # (1,0,0)
            # (0,1,0)
            # (1,1,0)
            # (0,2,0)
            # (1,2,0)
            # (0,3,0)
            # (1,3,0)
            if device_num == 8:
                ordered_devices = np.array([
                    devices[0],
                    devices[1],
                    devices[2],
                    devices[3],
                    devices[7],
                    devices[6],
                    devices[5],
                    devices[4],
                ])
            # NOTE(chengjiyao):
            # The coords of v6e-4 are
            # (0,0,0)
            # (1,0,0)
            # (0,1,0)
            # (1,1,0)
            elif device_num == 4:
                ordered_devices = np.array([
                    devices[0],
                    devices[1],
                    devices[3],
                    devices[2],
                ])
            if ordered_devices is not None:
                ordered_devices = np.array(ordered_devices)
                ordered_devices = ordered_devices.reshape(axis_shapes)
                mesh = mesh_lib.Mesh(ordered_devices, axis_names)
                logger.info("Use customized mesh: %s", mesh)
                return mesh

    return jax.make_mesh(axis_shapes, axis_names, devices=devices)


def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
    """
    Create a device array with the specified mesh and sharding.

    Args:
        mesh: The JAX mesh to use for device placement.
        *args: Positional arguments to pass to jax.device_put.
        sharding: Optional sharding specification. If None, uses PartitionSpec(None).
        **kwargs: Keyword arguments to pass to jax.device_put.

    Returns:
        A JAX array placed on the specified devices.
    """
    if sharding is None:
        sharding = NamedSharding(mesh, PartitionSpec(None))
    return jax.device_put(*args, device=sharding, **kwargs)


def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
    """
    A wrapper around vllm.utils.get_hash_fn_by_name that also supports the
    builtin hash function.
    """
    if hash_fn_name == "builtin":
        return hash
    return utils.get_hash_fn_by_name(hash_fn_name)


def quantize_kv(key: jax.Array, value: jax.Array,
                kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
                v_scale: float) -> Tuple[jax.Array, jax.Array]:
    """
    Quantize the key and value tensors.

    Args:
        key: The key tensor to quantize.
        value: The value tensor to quantize.
        kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
        k_scale: The scale to quantize the key tensor by.
        v_scale: The scale to quantize the value tensor by.

    Returns:
        Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
    """
    dtype_info = jnp.finfo(kv_cache_quantized_dtype)
    minval, maxval = float(dtype_info.min), float(dtype_info.max)
    key = key.astype(jnp.float32) / k_scale
    key = jnp.clip(key, minval, maxval)
    key = key.astype(kv_cache_quantized_dtype)
    value = value.astype(jnp.float32) / v_scale
    value = jnp.clip(value, minval, maxval)
    value = value.astype(kv_cache_quantized_dtype)

    return key, value


def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
    """
    Get the JAX dtype from a string dtype name.

    Args:
        str_dtype: The string name of the dtype, e.g. "bfloat16" or "fp8".

    Returns:
        jnp.dtype: The corresponding JAX dtype, or None if the name is unknown.
    """
    str_dtype = str_dtype.lower().strip()
    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)


def time_function(func):
    """
    A decorator to measure the execution time of a function.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        execution_time = end_time - start_time
        logger.debug(
            f"Function '{func.__name__}' executed in {execution_time:.4f} seconds."
        )
        return result

    return wrapper
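For orientation (not part of the wheel contents), here is a minimal usage sketch of a few of these helpers. It assumes the wheel and its vllm dependency are installed; the values and the decorated toy function are illustrative assumptions.

# Illustrative only: exercises a few helpers from tpu_inference/utils.py.
import jax.numpy as jnp

from tpu_inference import utils

# Map a `--kv_cache_dtype`-style string to a JAX dtype (None if the name is unknown).
assert utils.get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2

# Head dim 64 has a dedicated kernel; other sizes are padded up to a multiple of 128.
assert utils.get_padded_head_dim(64) == 64
assert utils.get_padded_head_dim(80) == 128

# Packing factor: how many elements of the dtype fit into 32 bits.
assert utils.get_dtype_packing(jnp.bfloat16) == 2
assert utils.get_dtype_packing(jnp.int8) == 4

# Wrap an arbitrary function to log its wall-clock time at DEBUG level.
@utils.time_function
def toy_forward():
    return jnp.ones((8, 128)).sum()

toy_forward()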
tpu_inference/worker/tpu_worker.py ADDED
@@ -0,0 +1,321 @@
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from typing import Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import jaxlib
import jaxtyping
import vllm.envs as vllm_envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                          has_kv_transfer_group)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.v1 import utils as vllm_utils
from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

from tpu_inference import envs, utils
from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                             get_node_id)
from tpu_inference.layers.common.sharding import ShardingConfigManager
from tpu_inference.logger import init_logger
from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
from tpu_inference.runner.tpu_runner import TPUModelRunner

logger = init_logger(__name__)

_DTYPE: dict[str, jnp.dtype] = {
    "bfloat16": jnp.bfloat16,
    "float": jnp.float32,
    "float32": jnp.float32,
}


class TPUWorker:

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False,
                 devices=None):
        # If we use vLLM's PyTorch model implementation, the dtype should stay
        # as the torch version of the dtype.
        impl = envs.MODEL_IMPL_TYPE
        if impl != "vllm":  # The vllm-pytorch implementation does not need this conversion.

            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
            if not isinstance(vllm_config.model_config.dtype, str):
                logger.warning(
                    "The model dtype is not properly set for the JAX backend. "
                    "Overwriting it to jnp.bfloat16")
                vllm_config.model_config.dtype = jnp.bfloat16
            else:
                vllm_config.model_config.dtype = _DTYPE.get(
                    vllm_config.model_config.dtype, jnp.bfloat16)

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        self.devices = devices if devices is not None else []
        self.device_ranks = set(device.id for device in self.devices
                                if isinstance(device, jaxlib._jax.Device))

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules

            init_cached_hf_modules()

        # Delay profiler initialization to the start of the profiling.
        # This is because in vLLM V1, the MP runtime is initialized before the
        # TPU worker is initialized. The profiler server needs to start after
        # the MP runtime is initialized.
        self.profile_dir = None
        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            if not self.devices or 0 in self.device_ranks:
                # For TPU, we can only have 1 active profiler session per
                # profiler server, so we only profile on rank 0.
                self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
                logger.info("Profiling enabled. Traces will be saved to: %s",
                            self.profile_dir)

        use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
        # Only one instance of the profiler is allowed.
        if use_jax_profiler_server and self.rank < 1:
            if not self.devices or 0 in self.device_ranks:
                jax_profiler_server_port = int(
                    os.getenv("JAX_PROFILER_SERVER_PORT", 9999))
                logger.info(
                    f"Starting JAX profiler server on port {jax_profiler_server_port}"
                )
                jax.profiler.start_server(jax_profiler_server_port)

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

    def init_device(self):
        if not self.devices:
            sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
            device_indexes = sharding_config.device_indexes
            if device_indexes is not None and len(device_indexes) > 0:
                # Enforce the devices sequence to be consistent with the
                # specified device indexes.
                all_devices = jax.devices()
                device_dict = {device.id: device for device in all_devices}
                self.devices = []
                for device_index in device_indexes:
                    device = device_dict.get(device_index)
                    if device is None:
                        raise KeyError(
                            f"Device index {device_index} not found in "
                            f"jax.devices() with IDs {list(device_dict.keys())}!"
                        )
                    self.devices.append(device)
                self.devices = self.devices[:sharding_config.total_devices]
            else:
                self.devices = jax.devices()[:sharding_config.total_devices]

        # Initialize the vLLM distribution layer as a single chip environment;
        # we'll swap the model's parallel modules with TPU SPMD equivalents.
        with set_current_vllm_config(self.vllm_config):
            temp_file = tempfile.mkstemp()[1]
            init_distributed_environment(
                world_size=1,
                rank=0,
                local_rank=0,
                distributed_init_method=f"file://{temp_file}",
                backend="gloo",
            )
            ensure_model_parallel_initialized(
                tensor_model_parallel_size=1,
                pipeline_model_parallel_size=1,
            )
            ensure_kv_transfer_initialized(self.vllm_config)
        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
        logger.info(f"Init worker | "
                    f"rank={self.rank} | "
                    f"node_id={get_node_id()} | "
                    f"is_driver_worker={self.is_driver_worker} | "
                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
        vllm_utils.report_usage_stats(self.vllm_config)

    def determine_available_memory(self) -> int:
        gpu_memory_utilization = self.cache_config.gpu_memory_utilization
        hbm_usage = utils.hbm_usage_bytes(self.devices)
        total_hbm_limit = total_hbm_used = 0
        for used, limit in hbm_usage:
            total_hbm_used += used
            total_hbm_limit += limit

        total_hbm_limit_cap = total_hbm_limit * gpu_memory_utilization
        total_hbm_avail = int(total_hbm_limit_cap - total_hbm_used)

        total_hbm_limit_gb = round(total_hbm_limit / utils.GBYTES, 2)
        total_hbm_limit_cap_gb = round(total_hbm_limit_cap / utils.GBYTES, 2)
        total_hbm_used_gb = round(total_hbm_used / utils.GBYTES, 2)
        total_hbm_avail_gb = round(total_hbm_avail / utils.GBYTES, 2)

        logger.info(f"Memory statistics | "
                    f"{total_hbm_limit_gb=}GiB | "
                    f"{total_hbm_limit_cap_gb=}GiB | "
                    f"{total_hbm_used_gb=}GiB | "
                    f"{total_hbm_avail_gb=}GiB")

        if total_hbm_avail <= 0:
            raise ValueError(f"{total_hbm_used_gb=}GiB exceeds "
                             f"{total_hbm_limit_cap_gb=}GiB by "
                             f"{-total_hbm_avail_gb}GiB. Please consider "
                             f"increasing --gpu-memory-utilization from "
                             f"{gpu_memory_utilization} to a larger value.")
        return total_hbm_avail

    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
    ) -> Optional[ModelRunnerOutput]:
        # NOTE: This method intentionally returns a concrete vLLM type, which
        # violates the pure abstract contract of the base class. This is a
        # deliberate, temporary compromise for the same reasons outlined in
        # the `get_kv_cache_spec` method.

        output = self.model_runner.execute_model(scheduler_output)

        # With a connector, the scheduler expects output from all workers.
        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
        if has_kv_transfer_group():
            return output

        return output if self.is_driver_worker else None

    def sample_tokens(self,
                      grammar_output: GrammarOutput) -> ModelRunnerOutput:
        return self.model_runner.sample_tokens(grammar_output)

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        return self.model_runner.take_draft_token_ids()

    def add_lora(
        self,
        lora_request: LoRARequest,
    ) -> bool:
        raise NotImplementedError(
            "LoRA is not supported by the JAX worker yet.")

    def profile(self, is_start: bool = True):
        if is_start:
            options = jax.profiler.ProfileOptions()
            # Defaults: https://docs.jax.dev/en/latest/profiling.html#general-options
            options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
            options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
            jax.profiler.start_trace(self.profile_dir,
                                     profiler_options=options)
        else:
            jax.profiler.stop_trace()

    def load_model(self) -> None:
        self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        self.model_runner.capture_model()
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        self.model_runner._init_random()

    def reset_mm_cache(self) -> None:
        pass

    def get_model(self):
        return self.model_runner.get_model()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.model_runner.get_supported_tasks()

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        # NOTE: This method intentionally returns a concrete vLLM type, which
        # violates the pure abstract contract of the base class. This is a
        # deliberate, temporary compromise.
        #
        # The vLLM executor that calls this method expects the concrete
        # `vllm.KVCacheSpec` object to perform its own internal logic. If we
        # returned an abstract adapter, the vLLM code would break.
        #
        # The ideal long-term solution is for the vLLM DI container to be
        # responsible for this translation. When vLLM can be modified, this
        # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
        # and the vLLM side should be updated to handle the translation.
        kv_cache_specs = self.model_runner.get_kv_cache_spec()

        if len(kv_cache_specs) == 0:
            return kv_cache_specs

        # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
        # a feature that allows overriding page_size_bytes of KVCacheSpec.
        vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
        rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
                                                      kv_cache_specs)

        if vllm_page_size_bytes != rpa_page_size_bytes:
            logger.info(
                f"KV cache page size calculated by vLLM "
                f"({vllm_page_size_bytes} Bytes) does not match the actual "
                f"page size used by the RPA kernel ({rpa_page_size_bytes} Bytes). "
                f"Recalculating the number of KV blocks using the actual page size.")

            available_memory = self.determine_available_memory()
            num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
                                        available_memory, rpa_page_size_bytes)

            cache_config = self.vllm_config.cache_config
            cache_config.num_gpu_blocks_override = num_blocks

        return kv_cache_specs

    def initialize_from_config(
        self,
        kv_cache_config: KVCacheConfig,
    ) -> None:
        """Allocate the KV cache with the specified kv_cache_config."""
        self.model_runner.initialize_kv_cache(kv_cache_config)

    def get_node_kv_ip_port(self) -> tuple[int, str, int]:
        node_id = get_node_id()
        ip = get_host_ip()
        port = get_kv_transfer_port()
        return (int(node_id), ip, int(port))

    def check_health(self) -> None:
        # The worker will always be healthy as long as it's running.
        return

    def sync_weights(
        self,
        updated_weights: jaxtyping.PyTree,
        mappings: Dict[str, Tuple[str, Tuple[str]]],
        transpose_keys: Dict[str, Tuple[int]],
        reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
                             jaxtyping.PyTree] = None
    ) -> None:
        """Sync the updated weights to the model runner."""
        return self.model_runner._sync_weights(updated_weights=updated_weights,
                                               mappings=mappings,
                                               transpose_keys=transpose_keys,
                                               reshard_fn=reshard_fn)

    def shutdown(self) -> None:
        return
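For context, here is a rough sketch of how a vLLM executor would typically drive this worker. It is illustrative only; `vllm_config`, `dist_init_method`, `kv_cache_config`, and `scheduler_output` are assumed to be produced by the vLLM engine rather than constructed here, and the exact call order may differ between vLLM versions.

# Illustrative call order only; none of these objects are built in this snippet.
worker = TPUWorker(vllm_config, local_rank=0, rank=0,
                   distributed_init_method=dist_init_method,
                   is_driver_worker=True)
worker.init_device()                    # select TPU devices, init vLLM distributed state
worker.load_model()                     # delegate weight loading to TPUModelRunner
specs = worker.get_kv_cache_spec()      # may override num_gpu_blocks for the RPA page size
worker.initialize_from_config(kv_cache_config)   # allocate the KV cache
worker.compile_or_warm_up_model()       # capture/compile the model, then reset the RNG
out = worker.execute_model(scheduler_output)     # per step; non-driver workers return None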