tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/worker/tpu_worker.py
@@ -0,0 +1,468 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ import os
+ import tempfile
+ from dataclasses import dataclass, field
+ from typing import Callable, Dict, Optional, Tuple
+
+ import jax
+ import jaxlib
+ import jaxtyping
+ import vllm.envs as vllm_envs
+ from vllm.config import VllmConfig, set_current_vllm_config
+ from vllm.distributed import get_pp_group
+ from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                           has_kv_transfer_group)
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.lora.request import LoRARequest
+ from vllm.tasks import SupportedTask
+ from vllm.v1 import utils as vllm_utils
+ from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                          get_uniform_page_size)
+ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
+ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+
+ from tpu_inference import envs, utils
+ from tpu_inference.distributed import jax_parallel_state
+ from tpu_inference.distributed.utils import (get_device_topology_order_id,
+                                              get_host_ip, get_kv_transfer_port)
+ from tpu_inference.layers.common.sharding import ShardingConfigManager
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
+ from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+ logger = init_logger(__name__)
+
+
+ @dataclass
+ class PPConfig:
+     rank: int
+     ip: str
+     prev_worker_ip: str
+     pp_world_size: int
+
+     # Default env vars for
+     # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+     # if PP is used on a single host.
+     default_tpu_process_bounds: str = field(init=False)
+     default_tpu_chips_per_process_bounds: str = field(init=False)
+     default_tpu_visible_chips: str = field(init=False)
+
+     def __post_init__(self):
+         self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+         self.default_tpu_chips_per_process_bounds = "1,1,1"
+         self.default_tpu_visible_chips = f"{self.rank}"
+
+
+ class TPUWorker:
+
+     def __init__(
+         self,
+         vllm_config: VllmConfig,
+         local_rank: int,
+         rank: int,
+         distributed_init_method: str,
+         is_driver_worker: bool = False,
+         devices=None,
+         ip: str = "localhost",
+         prev_worker_ip: str = "localhost",
+     ):
+         self.vllm_config = vllm_config
+         self.model_config = vllm_config.model_config
+         self.parallel_config = vllm_config.parallel_config
+         self.cache_config = vllm_config.cache_config
+         self.local_rank = local_rank
+         self.rank = rank
+         self.distributed_init_method = distributed_init_method
+         self.is_driver_worker = is_driver_worker
+         self.devices = devices if devices is not None else []
+         self.device_ranks = set(device.id for device in self.devices
+                                 if isinstance(device, jaxlib._jax.Device))
+         self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                   self.parallel_config.pipeline_parallel_size)
+
+         if self.model_config.trust_remote_code:
+             # Note: lazy import to avoid importing torch before initializing.
+             from vllm.utils.import_utils import init_cached_hf_modules
+
+             init_cached_hf_modules()
+
+         # Delay profiler initialization to the start of the profiling.
+         # This is because in vLLM V1, the MP runtime is initialized before the
+         # TPU worker is initialized. The profiler server needs to start after
+         # the MP runtime is initialized.
+         self.profile_dir = None
+         if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
+             if not self.devices or 0 in self.device_ranks:
+                 # For TPU, we can only have 1 active profiler session per
+                 # profiler server, so we only profile on rank 0.
+                 self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
+                 logger.info("Profiling enabled. Traces will be saved to: %s",
+                             self.profile_dir)
+
+         # For PP, we use MPMD, so we want to profile every worker.
+         if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+             self.profile_dir = os.path.join(
+                 vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                 f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+             )
+             os.makedirs(self.profile_dir, exist_ok=True)
+
+         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
+         # Only one instance of the profiler is allowed.
+         if use_jax_profiler_server and self.rank < 1:
+             if not self.devices or 0 in self.device_ranks:
+                 jax_profiler_server_port = int(
+                     os.getenv("JAX_PROFILER_SERVER_PORT", 9999))
+                 logger.info(
+                     f"Starting JAX profiler server on port {jax_profiler_server_port}"
+                 )
+                 jax.profiler.start_server(jax_profiler_server_port)
+
+         # step_counter is used to calculate the uuid for transferring intermediate tensors.
+         self.step_counter = 0
+
+     def initialize_cache(self, num_gpu_blocks: int,
+                          num_cpu_blocks: int) -> None:
+         self.cache_config.num_gpu_blocks = num_gpu_blocks
+         self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+     def init_device(self,
+                     tpu_process_bounds="",
+                     tpu_chips_per_process_bounds="",
+                     tpu_visible_chips=""):
+         # Set TPU visible devices for the JAX runtime in single-host PP.
+         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+         if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+             tpu_ports = [
+                 jax_parallel_state.BASE_JAX_PORT + i
+                 for i in range(self.pp_config.pp_world_size)
+             ]
+             os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                 [f"localhost:{port}" for port in tpu_ports])
+             os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+             os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+             # Note: below is the setting for a v6e8 host (8 chips of v6e).
+             # Replace with your own topology.
+             # There are 2 ways of subslicing a v6e:
+             # 1) 2 slices with 4 TPU chips each; we can do PP=2, TP=1/2/3/4
+             #    TPU_PROCESS_BOUNDS = "1,1,1"
+             #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+             #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+             # 2) 1 chip for each subslice, with at most 8 subslices;
+             #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+             os.environ[
+                 "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                 if tpu_process_bounds \
+                 else self.pp_config.default_tpu_process_bounds
+             os.environ[
+                 "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                 if tpu_chips_per_process_bounds \
+                 else self.pp_config.default_tpu_chips_per_process_bounds
+             os.environ[
+                 "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                 if tpu_visible_chips \
+                 else self.pp_config.default_tpu_visible_chips
+
+         if not self.devices:
+             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
+             device_indexes = sharding_config.device_indexes
+             if device_indexes is not None and len(device_indexes) > 0:
+                 # Enforce the device sequence to be consistent with the specified device indexes.
+                 all_local_devices = jax.local_devices()
+                 device_dict = {
+                     device.id: device
+                     for device in all_local_devices
+                 }
+                 self.devices = []
+                 for device_index in device_indexes:
+                     device = device_dict[device_index]
+                     if device is None:
+                         raise KeyError(
+                             f"Device index {device_index} not found in "
+                             f"jax.local_devices() with IDs {list(device_dict.keys())}!"
+                         )
+                     self.devices.append(device)
+                 assert len(self.devices) >= sharding_config.total_devices
+                 self.devices = self.devices[:sharding_config.total_devices]
+             else:
+                 if self.pp_config.pp_world_size > 1:
+                     # We only support a mixed TP + PP scenario where the TP size is
+                     # smaller than or equal to the total TPUs in one node.
+                     # Say we have 4 nodes with 4 TPUs each: we can do pp:4, tp:4, but not pp:2, tp:8.
+                     assert jax.local_device_count(
+                     ) >= sharding_config.total_devices
+                     self.devices = jax.local_devices()[:sharding_config.
+                                                        total_devices]
+                 else:
+                     # In a multi-host distributed env (say, Ray), the local device count may be
+                     # smaller than the total devices; we just choose the smaller set here.
+                     self.devices = jax.devices()[:sharding_config.
+                                                  total_devices]
+
+         # Initialize the vLLM distribution layer as a single-chip environment;
+         # we'll swap the model's parallel modules with TPU SPMD equivalents.
+         with set_current_vllm_config(self.vllm_config):
+             temp_file = tempfile.mkstemp()[1]
+             init_distributed_environment(
+                 world_size=1,
+                 rank=0,
+                 local_rank=0,
+                 distributed_init_method=f"file://{temp_file}",
+                 backend="gloo",
+             )
+             ensure_model_parallel_initialized(
+                 tensor_model_parallel_size=1,
+                 pipeline_model_parallel_size=1,
+             )
+
+         jax_parallel_state.init_pp_distributed_environment(
+             self.pp_config.ip,
+             self.rank,
+             self.parallel_config.pipeline_parallel_size,
+             self.devices[0],
+             need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
+         ensure_kv_transfer_initialized(self.vllm_config)
+
+         is_first_rank = True
+         is_last_rank = True
+         self.topology_order_id = self.rank
+         if self.parallel_config.pipeline_parallel_size > 1:
+             is_first_rank = self.rank == 0
+             is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+         else:
+             # topology_order_id is used to determine the KV cache
+             # mapping between P/D workers.
+             if multihost_backend == "ray":
+                 self.topology_order_id = get_device_topology_order_id(
+                     jax.local_devices(), jax.devices())
+
+         self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                            self.rank, is_first_rank,
+                                            is_last_rank)
+         logger.info(f"Init worker | "
+                     f"rank={self.rank} | "
+                     f"is_first_rank={is_first_rank} | "
+                     f"is_last_rank={is_last_rank} | "
+                     f"topology_order_id={self.topology_order_id} | "
+                     f"is_driver_worker={self.is_driver_worker} | "
+                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB | "
+                     f"self.devices={self.devices} | "
+                     f"total devices={jax.devices()} | "
+                     f"local_devices={jax.local_devices()}")
+         vllm_utils.report_usage_stats(self.vllm_config)
+
+     def initialize_pp_transfer_connect(self):
+         if self.rank == 0:
+             return
+         jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                    self.rank - 1)
+
+     def determine_available_memory(self) -> int:
+         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
+         hbm_usage = utils.hbm_usage_bytes(self.devices)
+         total_hbm_limit = total_hbm_used = 0
+         for used, limit in hbm_usage:
+             total_hbm_used += used
+             total_hbm_limit += limit
+
+         total_hbm_limit_cap = total_hbm_limit * gpu_memory_utilization
+         total_hbm_avail = int(total_hbm_limit_cap - total_hbm_used)
+
+         total_hbm_limit_gb = round(total_hbm_limit / utils.GBYTES, 2)
+         total_hbm_limit_cap_gb = round(total_hbm_limit_cap / utils.GBYTES, 2)
+         total_hbm_used_gb = round(total_hbm_used / utils.GBYTES, 2)
+         total_hbm_avail_gb = round(total_hbm_avail / utils.GBYTES, 2)
+
+         logger.info(f"Memory statistics | "
+                     f"{total_hbm_limit_gb=}GiB | "
+                     f"{total_hbm_limit_cap_gb=}GiB | "
+                     f"{total_hbm_used_gb=}GiB | "
+                     f"{total_hbm_avail_gb=}GiB")
+
+         if total_hbm_avail <= 0:
+             raise ValueError(f"{total_hbm_used_gb=}GiB exceeds "
+                              f"{total_hbm_limit_cap_gb=}GiB by "
+                              f"{-total_hbm_avail_gb}GiB. Please consider "
+                              f"increasing --gpu-memory-utilization from "
+                              f"{gpu_memory_utilization} to a larger value.")
+         return total_hbm_avail
+
+     def execute_model(
+         self,
+         scheduler_output: SchedulerOutput,
+     ) -> Optional[ModelRunnerOutput]:
+         # NOTE: This method intentionally returns a concrete vLLM type, which
+         # violates the pure abstract contract of the base class. This is a
+         # deliberate, temporary compromise for the same reasons outlined in
+         # the `get_kv_cache_spec` method.
+
+         if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+             intermediate_tensors = None
+         else:
+             # Receive intermediate tensors.
+             uuid = self.model_runner.get_uuid_for_jax_transfer(
+                 scheduler_output, self.rank - 1, self.step_counter)
+             # TODO: this method might only work for vLLM models; not sure about JAX models.
+             tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                 scheduler_output.total_num_scheduled_tokens)
+             intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                 uuid, tensor_spec)
+             intermediate_tensors = JaxIntermediateTensors(
+                 intermediate_tensors_dict)
+
+         output = self.model_runner.execute_model(scheduler_output,
+                                                  intermediate_tensors)
+
+         if isinstance(output, JaxIntermediateTensors):
+             assert self.parallel_config.pipeline_parallel_size > 1
+             assert not get_pp_group().is_last_rank
+             # Send intermediate tensors.
+             uuid = self.model_runner.get_uuid_for_jax_transfer(
+                 scheduler_output, self.rank, self.step_counter)
+             get_pp_group().send_tensor_dict(uuid, output.tensors)
+             self.step_counter += 1
+             return None
+         else:
+             self.step_counter += 1
+             # With a connector, the scheduler expects output from all workers.
+             # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+             if has_kv_transfer_group():
+                 return output
+             return output if self.is_driver_worker else None
+
+     def sample_tokens(self,
+                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
+         return self.model_runner.sample_tokens(grammar_output)
+
+     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+         return self.model_runner.take_draft_token_ids()
+
+     def add_lora(
+         self,
+         lora_request: LoRARequest,
+     ) -> bool:
+         raise NotImplementedError(
+             "LoRA is not supported by the JAX worker yet.")
+
+     def profile(self, is_start: bool = True):
+         if is_start:
+             options = jax.profiler.ProfileOptions()
+             # Defaults: https://docs.jax.dev/en/latest/profiling.html#general-options
+             options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
+             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
+             jax.profiler.start_trace(self.profile_dir,
+                                      profiler_options=options)
+         else:
+             jax.profiler.stop_trace()
+
+     def load_model(self) -> None:
+         self.model_runner.load_model()
+
+     def compile_or_warm_up_model(self) -> None:
+         self.model_runner.capture_model()
+         # Reset the seed to ensure that the random state is not affected by
+         # the model initialization and profiling.
+         self.model_runner._init_random()
+
+     def reset_mm_cache(self) -> None:
+         pass
+
+     def get_model(self):
+         return self.model_runner.get_model()
+
+     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+         return self.model_runner.get_supported_tasks()
+
+     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
+         # NOTE: This method intentionally returns a concrete vLLM type, which
+         # violates the pure abstract contract of the base class. This is a
+         # deliberate, temporary compromise.
+         #
+         # The vLLM executor that calls this method expects the concrete
+         # `vllm.KVCacheSpec` object to perform its own internal logic. If we
+         # returned an abstract adapter, the vLLM code would break.
+         #
+         # The ideal long-term solution is for the vLLM DI container to be
+         # responsible for this translation. When vLLM can be modified, this
+         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
+         # and the vLLM side should be updated to handle the translation.
+         kv_cache_spec = self.model_runner.get_kv_cache_spec()
+
+         if len(kv_cache_spec) == 0:
+             return kv_cache_spec
+
+         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
+         # a feature that allows overriding page_size_bytes of KVCacheSpec.
+         vllm_page_size_bytes = get_uniform_page_size(
+             list(kv_cache_spec.values()))
+         attention_page_size_bytes = get_attention_page_size_bytes(
+             self.model_runner.mesh, kv_cache_spec)
+
+         if vllm_page_size_bytes != attention_page_size_bytes:
+             logger.info(
+                 f"Page size calculated by vLLM ({vllm_page_size_bytes} bytes) "
+                 f"does not match the actual page size used by the kernel "
+                 f"({attention_page_size_bytes} bytes). Recalculating the number of "
+                 f"KV blocks using the actual page size.")
+
+             kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                   kv_cache_spec)
+             group_size = max(
+                 len(group.layer_names) for group in kv_cache_groups)
+             available_memory = self.determine_available_memory()
+             num_blocks = get_num_blocks(self.vllm_config, group_size,
+                                         available_memory,
+                                         attention_page_size_bytes)
+             cache_config = self.vllm_config.cache_config
+             cache_config.num_gpu_blocks_override = num_blocks
+
+         return kv_cache_spec
+
+     def initialize_from_config(
+         self,
+         kv_cache_config: KVCacheConfig,
+     ) -> None:
+         """Allocate the KV cache with the specified kv_cache_config."""
+         # Precompile functions with large vocab_size tensors before allocating the KV cache to avoid OOM.
+         self.model_runner.compilation_manager._precompile_sampling()
+         self.model_runner.compilation_manager._precompile_gather_logprobs()
+         self.model_runner.initialize_kv_cache(kv_cache_config,
+                                               self.topology_order_id)
+
+     def get_node_kv_ip_port(self) -> tuple[int, str, int]:
+         ip = get_host_ip()
+         port = get_kv_transfer_port()
+         return (int(self.topology_order_id), ip, int(port))
+
+     def check_health(self) -> None:
+         # The worker will always be healthy as long as it's running.
+         return
+
+     def sync_weights(
+         self,
+         updated_weights: jaxtyping.PyTree,
+         mappings: Dict[str, Tuple[str, Tuple[str]]],
+         transpose_keys: Dict[str, Tuple[int]],
+         reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
+                              jaxtyping.PyTree] = None
+     ) -> None:
+         """Sync the updated weights to the model runner."""
+         return self.model_runner._sync_weights(updated_weights=updated_weights,
+                                                mappings=mappings,
+                                                transpose_keys=transpose_keys,
+                                                reshard_fn=reshard_fn)
+
+     def shutdown(self) -> None:
+         return
+
+     # The Ray executor does not need handshake metadata
+     # because we pass the kv_parameters through the proxy server.
+     def get_kv_connector_handshake_metadata(self) -> None:
+         pass
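
The worker above is normally constructed and driven by vLLM's executor rather than used directly. The following is a rough, hypothetical sketch of that lifecycle, written only to show how the methods in this file relate to one another; the call order and the pre-built `vllm_config`, `kv_cache_config`, and `scheduler_output` objects are assumptions, not a documented vLLM call sequence.

# Hypothetical driver sketch -- not part of the package; it only strings
# together the TPUWorker methods shown above in a plausible order.
from tpu_inference.worker.tpu_worker import TPUWorker

def run_one_step(vllm_config, kv_cache_config, scheduler_output):
    # vllm_config, kv_cache_config and scheduler_output are assumed to be
    # built by the vLLM engine and scheduler; constructing them is out of scope.
    worker = TPUWorker(vllm_config,
                       local_rank=0,
                       rank=0,
                       distributed_init_method="file:///tmp/init",
                       is_driver_worker=True)
    worker.init_device()                   # pick JAX devices, init vLLM/PP state
    worker.load_model()                    # delegates to TPUModelRunner
    worker.determine_available_memory()    # HBM budget after the utilization cap
    worker.get_kv_cache_spec()             # may override num_gpu_blocks for the kernel page size
    worker.initialize_from_config(kv_cache_config)  # precompile sampling, allocate the KV cache
    worker.compile_or_warm_up_model()
    return worker.execute_model(scheduler_output)   # the driver worker returns ModelRunnerOutput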
tpu_inference-0.12.0.dev20251222.dist-info/METADATA
@@ -0,0 +1,106 @@
+ Metadata-Version: 2.4
+ Name: tpu_inference
+ Version: 0.12.0.dev20251222
+ Author: tpu_inference Contributors
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: tpu-info==0.7.1
+ Requires-Dist: yapf==0.43.0
+ Requires-Dist: pytest
+ Requires-Dist: pytest-mock
+ Requires-Dist: absl-py
+ Requires-Dist: numpy
+ Requires-Dist: google-cloud-storage
+ Requires-Dist: jax[tpu]==0.8.0
+ Requires-Dist: jaxlib==0.8.0
+ Requires-Dist: jaxtyping
+ Requires-Dist: flax==0.11.1
+ Requires-Dist: torchax==0.0.10
+ Requires-Dist: qwix==0.1.1
+ Requires-Dist: torchvision==0.24.0
+ Requires-Dist: pathwaysutils
+ Requires-Dist: parameterized
+ Requires-Dist: numba==0.62.1
+ Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
+ Dynamic: author
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+
+ <p align="center">
+ <!-- This image will ONLY show up in GitHub's dark mode -->
+ <img src="docs/assets/tpu_inference_dark_mode_short.png#gh-dark-mode-only" alt="vLLM TPU" style="width: 86%;">
+ <!-- This image will ONLY show up in GitHub's light mode (and on other platforms) -->
+ <img src="docs/assets/tpu_inference_light_mode_short.png#gh-light-mode-only" alt="vLLM TPU" style="width: 86%;">
+ </p>
+
+ <p align="center">
+ | <a href="https://docs.vllm.ai/projects/tpu/en/latest/"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> (#sig-tpu) |
+ </p>
+
+ ---
+
+ _Latest News_ 🔥
+
+ - [PyTorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify): Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+ - Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+ - Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
+ - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
+
+ <details>
+ <summary><i>Previous News</i> 🔥</summary>
+
+ </details>
+
+ ---
+ ## About
+
+ vLLM TPU is now powered by `tpu-inference`, an expressive and powerful new hardware plugin unifying JAX and PyTorch under a single lowering path within the vLLM project. The new backend now provides a framework for developers to:
+
+ - Push the limits of TPU hardware performance in open source.
+ - Provide more flexibility to JAX and PyTorch users by running PyTorch model definitions performantly on TPU without any additional code changes, while also extending native support to JAX.
+ - Retain vLLM standardization: keep the same user experience, telemetry, and interface.
+
+ ## Recommended models and features
+
+ Although vLLM TPU’s new unified backend makes out-of-the-box, high-performance serving possible with any model supported in vLLM, the reality is that we're still in the process of implementing a few core components.
+
+ For this reason, we’ve provided a **[Recommended Models and Features](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/)** page detailing the models and features that are validated through unit, integration, and performance testing.
+
+ ## Get started
+
+ Get started with vLLM on TPUs by following the [quickstart guide](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/quickstart/).
+
+ Visit our [documentation](https://docs.vllm.ai/projects/tpu/en/latest/) to learn more.
+
+ **Compatible TPU Generations**
+ - Recommended: v5e, v6e
+ - Experimental: v3, v4, v5p
+
+ *Check out a few v6e recipes [here](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/inference/trillium/vLLM)!*
+
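As a purely illustrative sketch that is not taken from the package's documentation: once tpu-inference is installed on a TPU VM, serving goes through vLLM's standard Python API (or the equivalent `vllm serve` CLI). The model name and `tensor_parallel_size` below are placeholders; consult the quickstart and the Recommended Models and Features page for supported models and settings.

# Minimal sketch, assuming a TPU VM with tpu-inference installed.
# Model name and tensor_parallel_size are placeholders, not recommendations.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
    tensor_parallel_size=8,              # commonly set to the number of TPU chips
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
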
+ ## Contribute
+
+ We're always looking for ways to partner with the community to accelerate vLLM TPU development. If you're interested in contributing to this effort, check out the [Contributing guide](https://github.com/vllm-project/tpu-inference/blob/main/CONTRIBUTING.md) and [Issues](https://github.com/vllm-project/tpu-inference/issues) to start. We recommend filtering Issues on the [**good first issue** tag](https://github.com/vllm-project/tpu-inference/issues?q=is%3Aissue+state%3Aopen+label%3A%22good+first+issue%22) if it's your first time contributing.
+
+ ## Contact us
+
+ - For technical questions, open a GitHub [Issue](https://github.com/vllm-project/tpu-inference/issues)
+ - For feature requests, please open one on GitHub [here](https://github.com/vllm-project/tpu-inference/issues/new/choose)
+ - For discussing with fellow users, use the [TPU support topic in the vLLM Forum](https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27)
+ - For coordinating contributions and development, use the [Developer Slack](https://join.slack.com/share/enQtOTY2OTUxMDIyNjY1OS00M2MxYWQwZjAyMGZjM2MyZjRjNTA0ZjRkNjkzOTRhMzg0NDM2OTlkZDAxOTAzYmJmNzdkNDc4OGZjYTUwMmRh)
+ - For collaborations and partnerships, contact us at [vllm-tpu@google.com](mailto:vllm-tpu@google.com)