tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/worker/tpu_worker.py
@@ -0,0 +1,458 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ import os
+ import tempfile
+ from dataclasses import dataclass, field
+ from typing import Callable, Dict, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import jaxlib
+ import jaxtyping
+ import vllm.envs as vllm_envs
+ from vllm.config import VllmConfig, set_current_vllm_config
+ from vllm.distributed import get_pp_group
+ from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                           has_kv_transfer_group)
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.lora.request import LoRARequest
+ from vllm.tasks import SupportedTask
+ from vllm.v1 import utils as vllm_utils
+ from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
+ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
+ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
+
+ from tpu_inference import envs, utils
+ from tpu_inference.distributed import jax_parallel_state
+ from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
+                                              get_node_id)
+ from tpu_inference.layers.common.sharding import ShardingConfigManager
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
+ from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+ logger = init_logger(__name__)
+
+ _DTYPE: dict[str, jnp.dtype] = {
+     "bfloat16": jnp.bfloat16,
+     "float": jnp.float32,
+     "float32": jnp.float32,
+ }
+
+
+ @dataclass
+ class PPConfig:
+     rank: int
+     ip: str
+     prev_worker_ip: str
+     pp_world_size: int
+
+     # Default env vars for TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS
+     # and TPU_VISIBLE_CHIPS when PP is used on a single host.
+     default_tpu_process_bounds: str = field(init=False)
+     default_tpu_chips_per_process_bounds: str = field(init=False)
+     default_tpu_visible_chips: str = field(init=False)
+
+     def __post_init__(self):
+         self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+         self.default_tpu_chips_per_process_bounds = "1,1,1"
+         self.default_tpu_visible_chips = f"{self.rank}"
+
+
+ class TPUWorker:
+
+     def __init__(
+         self,
+         vllm_config: VllmConfig,
+         local_rank: int,
+         rank: int,
+         distributed_init_method: str,
+         is_driver_worker: bool = False,
+         devices=None,
+         ip: str = "localhost",
+         prev_worker_ip: str = "localhost",
+     ):
+         # If we use vLLM's PyTorch model implementation, the dtype should stay
+         # as the torch version; otherwise convert it to a JAX dtype.
+         impl = envs.MODEL_IMPL_TYPE
+         if impl != "vllm":  # the vllm (PyTorch) implementation does not need this conversion
+             # NOTE(wenlong): sometimes multimodal preprocessing still needs torch.
+             if not isinstance(vllm_config.model_config.dtype, str):
+                 logger.warning(
+                     "The model dtype is not properly set for the JAX backend. "
+                     "Overwriting it to jnp.bfloat16")
+                 vllm_config.model_config.dtype = jnp.bfloat16
+             else:
+                 vllm_config.model_config.dtype = _DTYPE.get(
+                     vllm_config.model_config.dtype, jnp.bfloat16)
+
+         self.vllm_config = vllm_config
+         self.model_config = vllm_config.model_config
+         self.parallel_config = vllm_config.parallel_config
+         self.cache_config = vllm_config.cache_config
+         self.local_rank = local_rank
+         self.rank = rank
+         self.distributed_init_method = distributed_init_method
+         self.is_driver_worker = is_driver_worker
+         self.devices = devices if devices is not None else []
+         self.device_ranks = set(device.id for device in self.devices
+                                 if isinstance(device, jaxlib._jax.Device))
+         self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                   self.parallel_config.pipeline_parallel_size)
+
+         if self.model_config.trust_remote_code:
+             # Note: lazy import to avoid importing torch before initializing.
+             from vllm.utils.import_utils import init_cached_hf_modules
+             init_cached_hf_modules()
+
+         # Delay profiler initialization to the start of profiling. In vLLM V1
+         # the MP runtime is initialized before the TPU worker, and the
+         # profiler server needs to start after the MP runtime is initialized.
+         self.profile_dir = None
+         if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
+             if not self.devices or 0 in self.device_ranks:
+                 # TPU allows only one active profiler session per profiler
+                 # server, so we only profile on rank 0.
+                 self.profile_dir = vllm_envs.VLLM_TORCH_PROFILER_DIR
+                 logger.info("Profiling enabled. Traces will be saved to: %s",
+                             self.profile_dir)
+
+         # For PP we use MPMD, so we want to profile every worker.
+         if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+             self.profile_dir = os.path.join(
+                 vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                 f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+             )
+             os.makedirs(self.profile_dir, exist_ok=True)
+
+         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
+         # Only one instance of the profiler is allowed.
+         if use_jax_profiler_server and self.rank < 1:
+             if not self.devices or 0 in self.device_ranks:
+                 jax_profiler_server_port = int(
+                     os.getenv("JAX_PROFILER_SERVER_PORT", 9999))
+                 logger.info(
+                     f"Starting JAX profiler server on port {jax_profiler_server_port}"
+                 )
+                 jax.profiler.start_server(jax_profiler_server_port)
+
+         # step_counter is used to derive the uuid for transferring
+         # intermediate tensors.
+         self.step_counter = 0
+
+     def initialize_cache(self, num_gpu_blocks: int,
+                          num_cpu_blocks: int) -> None:
+         self.cache_config.num_gpu_blocks = num_gpu_blocks
+         self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+     def init_device(self,
+                     tpu_process_bounds="",
+                     tpu_chips_per_process_bounds="",
+                     tpu_visible_chips=""):
+         # Set TPU visible devices for the JAX runtime in single-host PP.
+         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+         if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+             tpu_ports = [
+                 jax_parallel_state.BASE_JAX_PORT + i
+                 for i in range(self.pp_config.pp_world_size)
+             ]
+             os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                 [f"localhost:{port}" for port in tpu_ports])
+             os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+             os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+             # Note: below is the setting for a v6e-8 host (8 chips of v6e).
+             # Replace with your own topology. There are two ways of
+             # subslicing a v6e:
+             # 1) 2 slices with 4 TPU chips each; we can do PP=2, TP=1/2/3/4:
+             #    TPU_PROCESS_BOUNDS = "1,1,1"
+             #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+             #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+             # 2) 1 chip per subslice, with at most 8 subslices; we can do
+             #    TP=1, PP=1/2/3/4/5/6/7/8.
+             os.environ["TPU_PROCESS_BOUNDS"] = (
+                 tpu_process_bounds if tpu_process_bounds else
+                 self.pp_config.default_tpu_process_bounds)
+             os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = (
+                 tpu_chips_per_process_bounds if tpu_chips_per_process_bounds
+                 else self.pp_config.default_tpu_chips_per_process_bounds)
+             os.environ["TPU_VISIBLE_CHIPS"] = (
+                 tpu_visible_chips if tpu_visible_chips else
+                 self.pp_config.default_tpu_visible_chips)
+
+         if not self.devices:
+             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
+             device_indexes = sharding_config.device_indexes
+             if device_indexes is not None and len(device_indexes) > 0:
+                 # Enforce the device sequence to be consistent with the
+                 # specified device indexes.
+                 all_local_devices = jax.local_devices()
+                 device_dict = {
+                     device.id: device
+                     for device in all_local_devices
+                 }
+                 self.devices = []
+                 for device_index in device_indexes:
+                     device = device_dict.get(device_index)
+                     if device is None:
+                         raise KeyError(
+                             f"Device index {device_index} not found in "
+                             f"jax.local_devices() with IDs {list(device_dict.keys())}!"
+                         )
+                     self.devices.append(device)
+                 assert len(self.devices) >= sharding_config.total_devices
+                 self.devices = self.devices[:sharding_config.total_devices]
+             else:
+                 if self.pp_config.pp_world_size > 1:
+                     # We only support mixed TP + PP when the TP size is
+                     # smaller than or equal to the number of TPUs in one
+                     # node. Say we have 4 nodes with 4 TPUs each: we can do
+                     # PP=4, TP=4, but not PP=2, TP=8.
+                     assert jax.local_device_count() >= sharding_config.total_devices
+                     self.devices = jax.local_devices()[:sharding_config.total_devices]
+                 else:
+                     # In a multi-host distributed env (e.g. Ray), the local
+                     # device count may be smaller than the total device
+                     # count, so we choose the smaller set here.
+                     self.devices = jax.devices()[:sharding_config.total_devices]
+
+         # Initialize the vLLM distribution layer as a single-chip environment;
+         # we'll swap the model's parallel modules with TPU SPMD equivalents.
+         with set_current_vllm_config(self.vllm_config):
+             temp_file = tempfile.mkstemp()[1]
+             init_distributed_environment(
+                 world_size=1,
+                 rank=0,
+                 local_rank=0,
+                 distributed_init_method=f"file://{temp_file}",
+                 backend="gloo",
+             )
+             ensure_model_parallel_initialized(
+                 tensor_model_parallel_size=1,
+                 pipeline_model_parallel_size=1,
+             )
+
+         jax_parallel_state.init_pp_distributed_environment(
+             self.pp_config.ip,
+             self.rank,
+             self.parallel_config.pipeline_parallel_size,
+             self.devices[0],
+             need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
+         ensure_kv_transfer_initialized(self.vllm_config)
+         self.model_runner = TPUModelRunner(
+             self.vllm_config, self.devices, self.rank, self.rank == 0,
+             self.rank == self.pp_config.pp_world_size - 1)
+         logger.info(f"Init worker | "
+                     f"rank={self.rank} | "
+                     f"node_id={get_node_id()} | "
+                     f"is_driver_worker={self.is_driver_worker} | "
+                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
+         vllm_utils.report_usage_stats(self.vllm_config)
+
+     def initialize_pp_transfer_connect(self):
+         if self.rank == 0:
+             return
+         jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                    self.rank - 1)
+
+     def determine_available_memory(self) -> int:
+         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
+         hbm_usage = utils.hbm_usage_bytes(self.devices)
+         total_hbm_limit = total_hbm_used = 0
+         for used, limit in hbm_usage:
+             total_hbm_used += used
+             total_hbm_limit += limit
+
+         total_hbm_limit_cap = total_hbm_limit * gpu_memory_utilization
+         total_hbm_avail = int(total_hbm_limit_cap - total_hbm_used)
+
+         total_hbm_limit_gb = round(total_hbm_limit / utils.GBYTES, 2)
+         total_hbm_limit_cap_gb = round(total_hbm_limit_cap / utils.GBYTES, 2)
+         total_hbm_used_gb = round(total_hbm_used / utils.GBYTES, 2)
+         total_hbm_avail_gb = round(total_hbm_avail / utils.GBYTES, 2)
+
+         logger.info(f"Memory statistics | "
+                     f"{total_hbm_limit_gb=}GiB | "
+                     f"{total_hbm_limit_cap_gb=}GiB | "
+                     f"{total_hbm_used_gb=}GiB | "
+                     f"{total_hbm_avail_gb=}GiB")
+
+         if total_hbm_avail <= 0:
+             raise ValueError(f"{total_hbm_used_gb=}GiB exceeds "
+                              f"{total_hbm_limit_cap_gb=}GiB by "
+                              f"{-total_hbm_avail_gb}GiB. Please consider "
+                              f"increasing --gpu-memory-utilization from "
+                              f"{gpu_memory_utilization} to a larger value.")
+         return total_hbm_avail
+
+     def execute_model(
+         self,
+         scheduler_output: SchedulerOutput,
+     ) -> Optional[ModelRunnerOutput]:
+         # NOTE: This method intentionally returns a concrete vLLM type, which
+         # violates the pure abstract contract of the base class. This is a
+         # deliberate, temporary compromise for the same reasons outlined in
+         # the `get_kv_cache_spec` method.
+
+         if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+             intermediate_tensors = None
+         else:
+             # Receive intermediate tensors.
+             uuid = self.model_runner.get_uuid_for_jax_transfer(
+                 scheduler_output, self.rank - 1, self.step_counter)
+             # TODO: this method might only work for vLLM models; not sure
+             # about JAX models.
+             tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                 scheduler_output.total_num_scheduled_tokens)
+             intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                 uuid, tensor_spec)
+             intermediate_tensors = JaxIntermediateTensors(
+                 intermediate_tensors_dict)
+
+         output = self.model_runner.execute_model(scheduler_output,
+                                                  intermediate_tensors)
+
+         if isinstance(output, JaxIntermediateTensors):
+             assert self.parallel_config.pipeline_parallel_size > 1
+             assert not get_pp_group().is_last_rank
+             # Send intermediate tensors.
+             uuid = self.model_runner.get_uuid_for_jax_transfer(
+                 scheduler_output, self.rank, self.step_counter)
+             get_pp_group().send_tensor_dict(uuid, output.tensors)
+             self.step_counter += 1
+             return None
+         else:
+             self.step_counter += 1
+             # With a connector, the scheduler expects output from all workers.
+             # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+             if has_kv_transfer_group():
+                 return output
+             return output if self.is_driver_worker else None
+
+     def sample_tokens(self,
+                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
+         return self.model_runner.sample_tokens(grammar_output)
+
+     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+         return self.model_runner.take_draft_token_ids()
+
+     def add_lora(
+         self,
+         lora_request: LoRARequest,
+     ) -> bool:
+         raise NotImplementedError(
+             "LoRA is not supported by the JAX worker yet.")
+
+     def profile(self, is_start: bool = True):
+         if is_start:
+             options = jax.profiler.ProfileOptions()
+             # Defaults: https://docs.jax.dev/en/latest/profiling.html#general-options
+             options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
+             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
+             jax.profiler.start_trace(self.profile_dir,
+                                      profiler_options=options)
+         else:
+             jax.profiler.stop_trace()
+
+     def load_model(self) -> None:
+         self.model_runner.load_model()
+
+     def compile_or_warm_up_model(self) -> None:
+         self.model_runner.capture_model()
+         # Reset the seed to ensure that the random state is not affected by
+         # model initialization and profiling.
+         self.model_runner._init_random()
+
+     def reset_mm_cache(self) -> None:
+         pass
+
+     def get_model(self):
+         return self.model_runner.get_model()
+
+     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+         return self.model_runner.get_supported_tasks()
+
+     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
+         # NOTE: This method intentionally returns a concrete vLLM type, which
+         # violates the pure abstract contract of the base class. This is a
+         # deliberate, temporary compromise.
+         #
+         # The vLLM executor that calls this method expects the concrete
+         # `vllm.KVCacheSpec` object to perform its own internal logic. If we
+         # returned an abstract adapter, the vLLM code would break.
+         #
+         # The ideal long-term solution is for the vLLM DI container to be
+         # responsible for this translation. When vLLM can be modified, this
+         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
+         # and the vLLM side should be updated to handle the translation.
+         kv_cache_specs = self.model_runner.get_kv_cache_spec()
+
+         if len(kv_cache_specs) == 0:
+             return kv_cache_specs
+
+         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
+         # a feature that allows overriding page_size_bytes of KVCacheSpec.
+         vllm_page_size_bytes = get_uniform_page_size(
+             list(kv_cache_specs.values()))
+         rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
+                                                       kv_cache_specs)
+
+         if vllm_page_size_bytes != rpa_page_size_bytes:
+             logger.info(
+                 f"KV cache page size calculated by vLLM "
+                 f"({vllm_page_size_bytes} bytes) does not match the actual "
+                 f"page size used by the RPA kernel ({rpa_page_size_bytes} "
+                 f"bytes). Recalculating the number of KV blocks using the "
+                 f"actual page size.")
+
+             available_memory = self.determine_available_memory()
+             num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
+                                         available_memory, rpa_page_size_bytes)
+
+             cache_config = self.vllm_config.cache_config
+             cache_config.num_gpu_blocks_override = num_blocks
+
+         return kv_cache_specs
+
+     def initialize_from_config(
+         self,
+         kv_cache_config: KVCacheConfig,
+     ) -> None:
+         """Allocate the TPU KV cache with the specified kv_cache_config."""
+         self.model_runner.initialize_kv_cache(kv_cache_config)
+
+     def get_node_kv_ip_port(self) -> tuple[int, str, int]:
+         node_id = get_node_id()
+         ip = get_host_ip()
+         port = get_kv_transfer_port()
+         return (int(node_id), ip, int(port))
+
+     def check_health(self) -> None:
+         # The worker is always healthy as long as it's running.
+         return
+
+     def sync_weights(
+         self,
+         updated_weights: jaxtyping.PyTree,
+         mappings: Dict[str, Tuple[str, Tuple[str]]],
+         transpose_keys: Dict[str, Tuple[int]],
+         reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
+                              jaxtyping.PyTree] = None
+     ) -> None:
+         """Sync the updated weights to the model runner."""
+         return self.model_runner._sync_weights(updated_weights=updated_weights,
+                                                mappings=mappings,
+                                                transpose_keys=transpose_keys,
+                                                reshard_fn=reshard_fn)
+
+     def shutdown(self) -> None:
+         return
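
As a reading aid (not part of the package contents above), the HBM budget that `determine_available_memory` hands to the KV cache reduces to the arithmetic below. `hbm_usage`, the chip count, and the utilization value are illustrative stand-ins for the per-device (used, limit) pairs returned by `utils.hbm_usage_bytes`:

```python
# Illustrative restatement of the HBM budgeting in determine_available_memory().
# The hbm_usage values and chip count below are made up for the example.
hbm_usage = [(12 * 2**30, 32 * 2**30)] * 4   # 4 chips: 12 GiB used of a 32 GiB limit each
gpu_memory_utilization = 0.9                 # vLLM's --gpu-memory-utilization

total_used = sum(used for used, _ in hbm_usage)
total_limit = sum(limit for _, limit in hbm_usage)

# Cap total HBM at the requested utilization, subtract what is already used
# (weights, activations), and hand the remainder to the KV cache.
available = int(total_limit * gpu_memory_utilization - total_used)
print(f"{available / 2**30:.1f} GiB available for KV cache")  # 67.2 GiB in this example
```

A non-positive result corresponds to the `ValueError` raised above, whose suggested fix is to increase `--gpu-memory-utilization`.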
tpu_inference-0.0.1rc1.dist-info/METADATA
@@ -0,0 +1,108 @@
+ Metadata-Version: 2.4
+ Name: tpu_inference
+ Version: 0.0.1rc1
+ Author: tpu_inference Contributors
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Education
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: tpu-info==0.7.1
+ Requires-Dist: yapf==0.43.0
+ Requires-Dist: pytest
+ Requires-Dist: pytest-mock
+ Requires-Dist: absl-py
+ Requires-Dist: numpy
+ Requires-Dist: google-cloud-storage
+ Requires-Dist: jax[tpu]==0.8.0
+ Requires-Dist: jaxlib==0.8.0
+ Requires-Dist: jaxtyping
+ Requires-Dist: flax==0.11.1
+ Requires-Dist: torchax==0.0.7
+ Requires-Dist: qwix==0.1.1
+ Requires-Dist: torchvision==0.24.0
+ Requires-Dist: pathwaysutils
+ Requires-Dist: parameterized
+ Requires-Dist: numba==0.62.1
+ Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
+ Dynamic: author
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+
+ <p align="center">
+ <!-- This image will ONLY show up in GitHub's dark mode -->
+ <img src="docs/assets/tpu_inference_dark_mode_short.png#gh-dark-mode-only" alt="vLLM TPU" style="width: 86%;">
+ <!-- This image will ONLY show up in GitHub's light mode (and on other platforms) -->
+ <img src="docs/assets/tpu_inference_light_mode_short.png#gh-light-mode-only" alt="vLLM TPU" style="width: 86%;">
+ </p>
+
+ <p align="center">
+ | <a href="https://docs.vllm.ai/projects/tpu/en/latest/"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> (#sig-tpu) |
+ </p>
+
+ ---
+
+ _Upcoming Events_ 🔥
+
+ - Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
+ - Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+ - Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
+ _Latest News_ 🔥
+
+ - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
+
+ <details>
+ <summary><i>Previous News</i> 🔥</summary>
+
+ </details>
+
+ ---
+ ## About
+
+ vLLM TPU is now powered by `tpu-inference`, an expressive and powerful new hardware plugin unifying JAX and PyTorch under a single lowering path within the vLLM project. The new backend provides a framework for developers to:
+
+ - Push the limits of TPU hardware performance in open source.
+ - Provide more flexibility to JAX and PyTorch users by running PyTorch model definitions performantly on TPU without any additional code changes, while also extending native support to JAX.
+ - Retain vLLM standardization: keep the same user experience, telemetry, and interface.
+
+ ## Recommended models and features
+
+ Although vLLM TPU's new unified backend makes out-of-the-box high-performance serving possible with any model supported in vLLM, we are still in the process of implementing a few core components.
+
+ For this reason, we've provided a **[Recommended Models and Features](https://docs.vllm.ai/projects/tpu/en/latest/recommended_models_features/)** page detailing the models and features that are validated through unit, integration, and performance testing.
+
+ ## Get started
+
+ Get started with vLLM on TPUs by following the [quickstart guide](https://docs.vllm.ai/projects/tpu/en/latest/getting_started/quickstart/).
+
+ Visit our [documentation](https://docs.vllm.ai/projects/tpu/en/latest/) to learn more.
+
+ **Compatible TPU Generations**
+ - Recommended: v5e, v6e
+ - Experimental: v3, v4, v5p
+
+ *Check out a few v6e recipes [here](https://github.com/AI-Hypercomputer/tpu-recipes/tree/main/inference/trillium/vLLM)!*
+
+ ## Contribute
+
+ We're always looking for ways to partner with the community to accelerate vLLM TPU development. If you're interested in contributing to this effort, check out the [Contributing guide](https://github.com/vllm-project/tpu-inference/blob/main/CONTRIBUTING.md) and [Issues](https://github.com/vllm-project/tpu-inference/issues) to start. We recommend filtering Issues on the [**good first issue** tag](https://github.com/vllm-project/tpu-inference/issues?q=is%3Aissue+state%3Aopen+label%3A%22good+first+issue%22) if it's your first time contributing.
+
+ ## Contact us
+
+ - For technical questions, open a GitHub [Issue](https://github.com/vllm-project/tpu-inference/issues)
+ - For feature requests, open one on GitHub [here](https://github.com/vllm-project/tpu-inference/issues/new/choose)
+ - To discuss with fellow users, use the [TPU support topic in the vLLM Forum](https://discuss.vllm.ai/c/hardware-support/google-tpu-support/27)
+ - For coordinating contributions and development, use the [Developer Slack](https://join.slack.com/share/enQtOTY2OTUxMDIyNjY1OS00M2MxYWQwZjAyMGZjM2MyZjRjNTA0ZjRkNjkzOTRhMzg0NDM2OTlkZDAxOTAzYmJmNzdkNDc4OGZjYTUwMmRh)
+ - For collaborations and partnerships, contact us at [vllm-tpu@google.com](mailto:vllm-tpu@google.com)
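
For orientation only (this block is not part of the package files shown above), here is a minimal offline-inference sketch against vLLM's public Python API, the entry point the quickstart guide linked in the README builds on. The model name and sampling settings are placeholders:

```python
# Minimal offline-inference sketch with vLLM's Python API; run on a TPU VM
# with tpu-inference installed. Model and sampling values are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")             # any supported model
sampling = SamplingParams(temperature=0.7, max_tokens=64)

outputs = llm.generate(["Why are TPUs well suited to LLM inference?"], sampling)
for out in outputs:
    print(out.outputs[0].text)
```

For online serving, the quickstart describes the equivalent `vllm serve` flow.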