tpu-inference 0.12.0.dev20251222 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
--- /dev/null
+++ b/tpu_inference/executors/ray_distributed_executor.py
@@ -0,0 +1,399 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from array import array
+ from typing import Any, Dict, List, Optional
+
+ import ray
+ import vllm.envs as envs
+ from ray.util.placement_group import PlacementGroup
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+ from vllm.multimodal.inputs import MultiModalKwargs
+ from vllm.platforms import current_platform
+ from vllm.ray.ray_env import get_env_vars_to_copy
+ from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
+ from vllm.utils.network_utils import (get_distributed_init_method, get_ip,
+                                       get_open_port)
+ from vllm.v1.core.sched.output import SchedulerOutput
+ from vllm.v1.executor.ray_distributed_executor import \
+     RayDistributedExecutor as RayDistributedExecutorV1
+ from vllm.v1.executor.ray_executor import RayWorkerMetaData
+ from vllm.v1.executor.ray_utils import RayWorkerWrapper, _wait_until_pg_ready
+
+ from tpu_inference.logger import init_logger
+
+ try:
+     from ray._private.state import available_resources_per_node
+ except ImportError:
+     # Ray 2.9.x doesn't expose `available_resources_per_node`
+     from ray._private.state import state as _state
+     available_resources_per_node = _state._available_resources_per_node
+
+ import asyncio
+ from collections import defaultdict
+
+ import msgspec
+ from vllm.v1.outputs import SamplerOutput
+
+ from tpu_inference.distributed.utils import set_node_kv_ip_port
+
+ logger = init_logger(__name__)
+
+
+ def _encode_hook(obj: Any) -> Any:
+     """Custom msgspec enc hook that supports array types and MultiModalKwargs.
+
+     See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
+     """
+     if isinstance(obj, array):
+         assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
+             f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
+             f"Given array has a type code of {obj.typecode}.")
+         return obj.tobytes()
+     if isinstance(obj, MultiModalKwargs):
+         return dict(obj)
+
+
+ class RayDistributedExecutor(RayDistributedExecutorV1):
+     """Ray-based distributed executor for TPU.
+
+     The implementation is similar to vllm/executor/ray_distributed_executor.py
+     with these major differences:
+
+     1. self._init_executor():
+        Sets VLLM_USE_RAY_SPMD_WORKER=1, so the driver worker is treated the
+        same as the other workers.
+     2. self._initialize_ray_cluster():
+        Sets placement_group_specs for TPU. In vLLM one GPU maps to one
+        placement group, while here one TPU node with all of its chips maps
+        to one placement group.
+     3. self._init_workers_ray():
+        Sets the TPU resources when creating each worker and omits the
+        driver-worker-related logic.
+     """
+
+     def _init_executor(self) -> None:
+         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
+
+         os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
+
+         # Currently, this requires USE_RAY_SPMD_WORKER=True.
+         self.use_ray_compiled_dag = True
+         # If it is true, then we do not distinguish between the
+         # "driver worker" vs other workers. Also, the rank 0 worker will
+         # be executed in a remote Ray worker. Currently this requires
+         # USE_RAY_COMPILED_DAG=True.
+         self.use_ray_spmd_worker = True
+
+         assert self.uses_ray
+         self._initialize_ray_cluster()
+         placement_group = self.parallel_config.placement_group
+
+         # Disable Ray usage stats collection.
+         ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+         if ray_usage != "1":
+             os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+
+         # Create the parallel TPU workers.
+         self._init_workers_ray(placement_group)
+
+         self.input_encoder = msgspec.msgpack.Encoder(enc_hook=_encode_hook)
+         self.output_decoder = msgspec.msgpack.Decoder(
+             Optional[List[SamplerOutput]])
+
+         self.pp_locks: Optional[List[asyncio.Lock]] = None
+
+         self.scheduler_output: SchedulerOutput | None = None
+
+         # KV connector setup
+         self.has_connector = self.vllm_config.kv_transfer_config is not None
+         if self.has_connector:
+             ip_port = self.collective_rpc("get_node_kv_ip_port")
+             for item in ip_port:
+                 set_node_kv_ip_port(item)
+         self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
+             self.vllm_config.ec_transfer_config is None
+             or not self.vllm_config.ec_transfer_config.is_ec_producer)
+
+     def _initialize_ray_cluster(self) -> None:
+         """Initialize the distributed cluster with Ray.
+
+         It connects to the Ray cluster and creates a placement group
+         for the workers, which includes the specification of the resources
+         for each distributed worker.
+         """
+         from vllm.platforms import current_platform
+
+         if ray.is_initialized():
+             logger.info(
+                 "Ray is already initialized. Skipping Ray initialization.")
+         else:
+             logger.warning("Ray is not initialized, this is mainly for test.")
+             ray.init()
+
+         device_str = current_platform.ray_device_key
+         if not device_str:
+             raise ValueError(
+                 f"current platform {current_platform.device_name} does not "
+                 "support ray.")
+
+         pp_size = self.parallel_config.pipeline_parallel_size
+         placement_group_specs: List[Dict[str, float]] = []
+
+         ray_nodes = ray.nodes()
+         logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
+
+         if pp_size == 1:
+             placement_group_specs = [{
+                 device_str: node['Resources'][device_str]
+             } for node in ray_nodes]
+         else:
+             assert pp_size == len(
+                 ray_nodes
+             ), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
+             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
+             placement_group_specs = [{
+                 device_str: num_devices_per_pp_rank
+             } for _ in range(pp_size)]
+
+         # The vLLM engine is also a worker that executes the model on an
+         # accelerator, so it needs a device on the current node. Check if
+         # the current node has at least one device.
+         current_ip = get_ip()
+         current_node_id = ray.get_runtime_context().get_node_id()
+         current_node_resource = available_resources_per_node()[current_node_id]
+         if current_node_resource.get(device_str, 0) < 1:
+             raise ValueError(
+                 f"Current node has no {device_str} available. "
+                 f"{current_node_resource=}. vLLM engine cannot start without "
+                 f"{device_str}. Make sure you have at least 1 {device_str} "
+                 f"available in a node {current_node_id=} {current_ip=}.")
+         # This way, at least one bundle is required to be created on the
+         # current node.
+         placement_group_specs[0][f"node:{current_ip}"] = 0.001
+         logger.info(
+             f"RayDistributedExecutor | placement_group_specs={placement_group_specs}"
+         )
+
+         # By default, Ray packs resources as much as possible.
+         current_placement_group = ray.util.placement_group(
+             placement_group_specs, strategy="PACK")
+         _wait_until_pg_ready(current_placement_group)
+
+         assert current_placement_group is not None
+         # Set the placement group in the parallel config.
+         self.parallel_config.placement_group = current_placement_group
+
+     def _init_workers_ray(self, placement_group: "PlacementGroup",
+                           **ray_remote_kwargs):
+         # The workers are the actual Ray actors.
+         self.workers: List[RayWorkerWrapper] = []
+
+         # Used in Ray compiled DAG: indexed first by PP rank,
+         # and then TP rank. In other words, the inner list is
+         # the TP group of workers for a PP rank.
+         self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
+
+         if self.parallel_config.ray_workers_use_nsight:
+             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
+                 ray_remote_kwargs)
+
+         # Create the workers.
+         bundle_indices: List[int]
+         if envs.VLLM_RAY_BUNDLE_INDICES:
+             # Use the bundle indices specified by the user.
+             bundle_indices = list(
+                 map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
+             assert len(bundle_indices) == self.parallel_config.world_size, \
+                 ("VLLM_RAY_BUNDLE_INDICES must have the same size"
+                  f" as the world size, but got {bundle_indices=} "
+                  f"and {self.parallel_config.world_size=}")
+             assert len(set(bundle_indices)) == len(bundle_indices), \
+                 ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
+                  f" but got {bundle_indices=}")
+         else:
+             bundle_indices = []
+             for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+                 if bundle.get(current_platform.ray_device_key, 0):
+                     bundle_indices.append(bundle_id)
+
+         worker_metadata: List[RayWorkerMetaData] = []
+         driver_ip = get_ip()
+         num_tpu_per_worker = placement_group.bundle_specs[0].get(
+             current_platform.ray_device_key, 0)
+         for rank, bundle_id in enumerate(bundle_indices):
+             scheduling_strategy = PlacementGroupSchedulingStrategy(
+                 placement_group=placement_group,
+                 placement_group_capture_child_tasks=True,
+                 placement_group_bundle_index=bundle_id,
+             )
+             worker = ray.remote(
+                 num_cpus=0,
+                 num_gpus=0,
+                 resources={
+                     current_platform.ray_device_key: num_tpu_per_worker
+                 },
+                 scheduling_strategy=scheduling_strategy,
+                 **ray_remote_kwargs,
+             )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
+                                        rpc_rank=rank)
+             worker_metadata.append(
+                 RayWorkerMetaData(worker=worker, created_rank=rank))
+
+         worker_ips = ray.get([
+             each.worker.get_node_ip.remote()  # type: ignore[attr-defined]
+             for each in worker_metadata
+         ])
+
+         for each, ip in zip(worker_metadata, worker_ips):
+             each.ip = ip
+
+         logger.debug(f"Initialized worker_metadata: {worker_metadata}")
+
+         ip_counts: Dict[str, int] = {}
+         for ip in worker_ips:
+             ip_counts[ip] = ip_counts.get(ip, 0) + 1
+
+         def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
+             """
+             Sort the workers based on 3 properties:
+             1. If the worker is on the same node as the driver (vLLM engine),
+                it should be placed first.
+             2. Then, if the worker is on a node with fewer workers, it should
+                be placed first.
+             3. Finally, if the worker is on a node with a smaller IP address,
+                it should be placed first.
+             """
+             ip = item.ip
+             return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
+
+         # After sorting, the workers on the same node will be
+         # close to each other, and the workers on the driver
+         # node will be placed first.
+         sorted_worker_metadata = sorted(worker_metadata,
+                                         key=sort_by_driver_then_worker_ip)
+         start_rank = 0
+         for i, item in enumerate(sorted_worker_metadata):
+             item.adjusted_rank = i + start_rank
+         logger.info(
+             f"Initialized sorted worker_metadata: {sorted_worker_metadata}")
+
+         self.workers = [item.worker for item in sorted_worker_metadata]
+         rerank_mapping = {
+             item.created_rank: item.adjusted_rank
+             for item in sorted_worker_metadata
+         }
+         self.collective_rpc("adjust_rank", args=(rerank_mapping, ))
+
+         # Get the set of TPU IDs used on each node.
+         worker_node_and_tpu_ids = []
+         for worker in self.workers:
+             worker_node_and_tpu_ids.append(
+                 ray.get(worker.get_node_and_gpu_ids.remote()) \
+             )  # type: ignore
+
+         node_workers = defaultdict(list)  # node id -> list of worker ranks
+         node_tpus = defaultdict(list)  # node id -> list of tpu ids
+
+         for i, (node_id, tpu_ids) in enumerate(worker_node_and_tpu_ids):
+             node_workers[node_id].append(i)
+             # `tpu_ids` can be a list of strings or integers.
+             # Convert them to integers for consistency.
+             tpu_ids = [int(x) for x in tpu_ids]
+             node_tpus[node_id].extend(tpu_ids)
+         for node_id, tpu_ids in node_tpus.items():
+             node_tpus[node_id] = sorted(tpu_ids)
+         logger.info(
+             f"RayDistributedExecutor | node_workers={node_workers} | node_tpus={node_tpus}"
+         )
+
+         all_ips = set(worker_ips + [driver_ip])
+         n_ips = len(all_ips)
+         n_nodes = len(node_workers)
+
+         if n_nodes != n_ips:
+             logger.warning(
+                 f"Got {n_nodes} nodes but with {n_ips} IP addresses. "
+                 "This is not a typical production setup, where the "
+                 "number of nodes and IPs is equal. This setup may "
+                 "lead to unexpected behaviors.")
+
+         # Set environment variables for the driver and workers.
+         all_args_to_update_environment_variables = [{
+             current_platform.device_control_env_var:
+             ",".join(map(str, node_tpus[node_id])),
+         } for (node_id, _) in worker_node_and_tpu_ids]
+
+         # Environment variables to copy from the driver to the workers.
+         env_vars_to_copy = get_env_vars_to_copy(
+             exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
+             additional_vars=set(current_platform.additional_env_vars),
+             destination="workers")
+
+         # Copy existing env vars to each worker's args.
+         for args in all_args_to_update_environment_variables:
+             for name in env_vars_to_copy:
+                 if name in os.environ:
+                     args[name] = os.environ[name]
+
+         self._env_vars_for_all_workers = (
+             all_args_to_update_environment_variables)
+
+         self.collective_rpc("update_environment_variables",
+                             args=(self._get_env_vars_to_be_updated(), ))
+
+         distributed_init_method = get_distributed_init_method(
+             driver_ip, get_open_port())
+
+         # Initialize the actual workers inside the worker wrapper.
+         all_kwargs = []
+         for rank, (node_id, _) in enumerate(worker_node_and_tpu_ids):
+             local_rank = node_workers[node_id].index(rank)
+             ip = sorted_worker_metadata[rank].ip
+             prev_ip = sorted_worker_metadata[rank - 1].ip if rank > 0 else ""
+             kwargs = dict(
+                 vllm_config=self.vllm_config,
+                 local_rank=local_rank,
+                 rank=rank,
+                 distributed_init_method=distributed_init_method,
+                 is_driver_worker=(not self.parallel_config)
+                 or (rank % self.parallel_config.tensor_parallel_size == 0),
+                 ip=ip,
+                 prev_worker_ip=prev_ip,
+             )
+             all_kwargs.append(kwargs)
+         self.collective_rpc("init_worker", args=(all_kwargs, ))
+         self.collective_rpc("init_device")
+         if self.parallel_config.pipeline_parallel_size > 1:
+             self.collective_rpc("initialize_pp_transfer_connect")
+         self.collective_rpc("load_model")
+
+         if self.use_ray_spmd_worker:
+             for pp_rank in range(self.parallel_config.pipeline_parallel_size):
+                 self.pp_tp_workers.append([])
+                 num_tp_workers = int(
+                     self.parallel_config.tensor_parallel_size //
+                     num_tpu_per_worker)
+                 for tp_rank in range(num_tp_workers):
+                     # e.g. PP=2, TP=4, num_tpu_per_worker=2
+                     # => pp_tp_workers = [[0, 1], [2, 3]]
+                     rank = (pp_rank * num_tp_workers) + tp_rank
+                     assert len(self.pp_tp_workers[pp_rank]) == tp_rank
+                     assert pp_rank < len(self.pp_tp_workers)
+                     self.pp_tp_workers[pp_rank].append(self.workers[rank])
+
+     # The Ray executor does not need handshake metadata
+     # because the kv_parameters are passed through the proxy server.
+     def get_kv_connector_handshake_metadata(self) -> None:
+         pass
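A note on the serialization path in this file: _encode_hook only converts token-id arrays and MultiModalKwargs into msgpack-friendly values, and msgspec handles everything else natively. Below is a minimal sketch of that round trip, assuming a stand-in typecode and a hypothetical payload (the real typecode is VLLM_TOKEN_ID_ARRAY_TYPE from vllm.sequence, and the executor's output_decoder is typed as Optional[List[SamplerOutput]]):

    from array import array

    import msgspec

    TOKEN_TYPECODE = "l"  # assumption: stand-in for VLLM_TOKEN_ID_ARRAY_TYPE

    def encode_hook(obj):
        # Same idea as _encode_hook: serialize token-id arrays as raw bytes.
        if isinstance(obj, array):
            assert obj.typecode == TOKEN_TYPECODE
            return obj.tobytes()
        raise NotImplementedError(f"Unsupported type: {type(obj)}")

    encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
    payload = {"token_ids": array(TOKEN_TYPECODE, [1, 2, 3])}  # hypothetical
    blob = encoder.encode(payload)

    # The receiving side gets plain bytes back and rebuilds the array itself.
    decoded = msgspec.msgpack.Decoder().decode(blob)
    tokens = array(TOKEN_TYPECODE, decoded["token_ids"])
    assert list(tokens) == [1, 2, 3]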
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
--- /dev/null
+++ b/tpu_inference/experimental/llama3_jax_stashed.py
@@ -0,0 +1,272 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # TODO: Update documentation
+
+ from typing import List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+ from vllm.config import VllmConfig
+
+ from tpu_inference.layers.jax.attention.attention import (Attention,
+                                                            AttentionMetadata)
+ from tpu_inference.layers.jax.constants import KVCacheType
+ from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+ from tpu_inference.layers.jax.transformer_block import TransformerBlock
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
+                                                           load_hf_weights)
+
+ logger = init_logger(__name__)
+
+
+ class LlamaForCausalLM(nnx.Module):
+
+     def __init__(self,
+                  vllm_config: VllmConfig,
+                  rng: jax.Array,
+                  mesh: Mesh,
+                  force_random_weights: bool = False):
+         assert mesh is not None
+
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng)
+         self.mesh = mesh
+
+         model_name = self.vllm_config.model_config.model.lower()
+         if "70b" in model_name:
+             logger.info("Initializing Llama3 70B model variant.")
+             self.hidden_size = 8192
+             num_layers = 80
+             self.num_attention_heads = 64
+             self.num_key_value_heads = 8
+             intermediate_size = 28672
+         elif "8b" in model_name:
+             logger.info("Initializing Llama3 8B model variant.")
+             self.hidden_size = 4096
+             num_layers = 32
+             self.num_attention_heads = 32
+             self.num_key_value_heads = 8
+             intermediate_size = 14336
+         else:
+             raise ValueError(
+                 f"Could not determine Llama3 variant (8B or 70B) from model name: '{model_name}'. "
+                 "Please ensure '8b' or '70b' is in the model path.")
+
+         dtype = jnp.bfloat16
+         self.head_dim = 128
+         rope_theta = 500000.0
+         vocab_size = 128256
+         rms_norm_eps = 1e-5
+
+         self.embedder = Embedder(vocab_size=vocab_size,
+                                  hidden_size=self.hidden_size,
+                                  dtype=dtype,
+                                  rngs=self.rng,
+                                  random_init=force_random_weights,
+                                  vd_sharding=("model", None))
+
+         self.layers = []
+         kv_cache_dtype = self.vllm_config.cache_config.cache_dtype
+         for _ in range(num_layers):
+             self.layers.append(
+                 TransformerBlock(
+                     pre_attention_norm=RMSNorm(
+                         dims=self.hidden_size,
+                         random_init=force_random_weights,
+                         epsilon=rms_norm_eps,
+                         rngs=self.rng,
+                         with_scale=True,
+                         dtype=dtype,
+                     ),
+                     pre_mlp_norm=RMSNorm(
+                         dims=self.hidden_size,
+                         rngs=self.rng,
+                         random_init=force_random_weights,
+                         epsilon=rms_norm_eps,
+                         with_scale=True,
+                         dtype=dtype,
+                     ),
+                     attn=Attention(
+                         hidden_size=self.hidden_size,
+                         num_attention_heads=self.num_attention_heads,
+                         num_key_value_heads=self.num_key_value_heads,
+                         head_dim=self.head_dim,
+                         rope_theta=rope_theta,
+                         rope_scaling={},
+                         rngs=self.rng,
+                         dtype=dtype,
+                         # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                         kv_cache_dtype=kv_cache_dtype,
+                         mesh=self.mesh,
+                         random_init=force_random_weights,
+                         dnh_sharding=(None, "model", None),
+                         dkh_sharding=(None, "model", None),
+                         nhd_sharding=("model", None, None),
+                         query_tnh=P(None, "model", None),
+                         keyvalue_skh=P(None, "model", None),
+                         attn_o_tnh=P(None, "model", None),
+                     ),
+                     custom_module=DenseFFW(dtype=dtype,
+                                            hidden_act="silu",
+                                            hidden_size=self.hidden_size,
+                                            intermediate_size=intermediate_size,
+                                            rngs=self.rng,
+                                            df_sharding=(None, "model"),
+                                            fd_sharding=("model", None),
+                                            random_init=force_random_weights),
+                 ))
+
+         self.final_norm = RMSNorm(
+             dims=self.hidden_size,
+             rngs=self.rng,
+             random_init=force_random_weights,
+             epsilon=rms_norm_eps,
+             with_scale=True,
+             dtype=dtype,
+         )
+
+         self.lm_head = LMhead(vocab_size=vocab_size,
+                               hidden_size=self.hidden_size,
+                               dtype=dtype,
+                               rngs=self.rng,
+                               dv_sharding=(None, 'model'),
+                               random_init=force_random_weights)
+
+     def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+         # NOTE: Since we are using nnx.eval_shape to init the model,
+         # we have to pass dynamic arrays here for __call__'s usage.
+         self.rng = nnx.Rngs(rng)
+         weight_loader = Llama3WeightLoader(
+             vllm_config=self.vllm_config,
+             hidden_size=self.hidden_size,
+             attn_heads=self.num_attention_heads,
+             num_key_value_heads=self.num_key_value_heads,
+             attn_head_dim=self.head_dim)
+
+         weight_loader.load_weights(self)
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+         *args,
+     ) -> Tuple[List[KVCacheType], jax.Array]:
+         is_prefill = False
+         with jax.named_scope("llama_embed_input"):  # Embedding
+             x_TD = self.embedder.encode(input_ids)
+
+         with jax.named_scope("llama_model_transformer_blocks"):
+             for (i, layer) in enumerate(self.layers):
+                 kv_cache = kv_caches[i]
+
+                 # The first layer is unscoped to avoid JAX tracing issues.
+                 # JAX's profiler may incorrectly apply the scope name from the first
+                 # layer's kernel compilation to all subsequent layers. Skipping the
+                 # first layer ensures distinct scope names for the remaining layers.
+                 if i == 0:
+                     new_kv_cache, x_TD = layer(x_TD, is_prefill, kv_cache,
+                                                attention_metadata)
+                 else:
+                     with jax.named_scope(f'layer_{i}'):
+                         new_kv_cache, x_TD = layer(x_TD, is_prefill, kv_cache,
+                                                    attention_metadata)
+
+                 kv_caches[i] = new_kv_cache
+
+         with jax.named_scope(
+                 "llama_final_norm"):  # Norm after last transformer block
+             final_activation_TD = self.final_norm(x_TD)
+
+         return kv_caches, final_activation_TD, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         with jax.named_scope("llama_lm_head_projection"
+                              ):  # LM head projection to produce logits
+             logits_TV = jnp.dot(hidden_states,
+                                 self.lm_head.input_embedding_table_DV.value)
+
+         return logits_TV
+
+
+ class Llama3WeightLoader:
+
+     def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                  num_key_value_heads, attn_head_dim):
+         self._transpose_map = {
+             "lm_head": (1, 0),
+             "gate_proj": (1, 0),
+             "up_proj": (1, 0),
+             "down_proj": (1, 0),
+             "q_proj": (2, 0, 1),
+             "k_proj": (2, 0, 1),
+             "v_proj": (2, 0, 1),
+             "o_proj": (1, 2, 0),
+         }
+         self._weight_shape_map = {
+             "q_proj": (attn_heads, -1, hidden_size),
+             "k_proj": (num_key_value_heads, -1, hidden_size),
+             "v_proj": (num_key_value_heads, -1, hidden_size),
+             "o_proj": (hidden_size, attn_heads, -1),
+         }
+         self._bias_shape_map = {
+             "q_proj.bias": (attn_heads, attn_head_dim),
+             "k_proj.bias": (num_key_value_heads, attn_head_dim),
+             "v_proj.bias": (num_key_value_heads, attn_head_dim)
+         }
+
+         # Set the mappings from loaded parameter keys to standardized names.
+         self._loaded_to_standardized_keys = {
+             "model.embed_tokens": "embedder.input_embedding_table_VD",
+             "model.layers.*.input_layernorm":
+             "layers.*.pre_attention_norm.scale",
+             "model.layers.*.mlp.down_proj":
+             "layers.*.custom_module.kernel_down_proj_FD",
+             "model.layers.*.mlp.gate_proj":
+             "layers.*.custom_module.kernel_gating_DF",
+             "model.layers.*.mlp.up_proj":
+             "layers.*.custom_module.kernel_up_proj_DF",
+             "model.layers.*.post_attention_layernorm":
+             "layers.*.pre_mlp_norm.scale",
+             "model.layers.*.self_attn.k_proj":
+             "layers.*.attn.kernel_k_proj_DKH",
+             "model.layers.*.self_attn.o_proj":
+             "layers.*.attn.kernel_o_proj_NHD",
+             "model.layers.*.self_attn.q_proj":
+             "layers.*.attn.kernel_q_proj_DNH",
+             "model.layers.*.self_attn.v_proj":
+             "layers.*.attn.kernel_v_proj_DKH",
+             "model.norm": "final_norm.scale",
+             "lm_head": "lm_head.input_embedding_table_DV"
+         }
+         self.vllm_config = vllm_config
+
+     def load_weights(self, model_for_loading: nnx.Module):
+         model_params = nnx.state(model_for_loading)
+         metadata_map = MetadataMap(name_map=self._loaded_to_standardized_keys,
+                                    reshape_map=self._weight_shape_map,
+                                    bias_reshape_map=self._bias_shape_map,
+                                    transpose_map=self._transpose_map)
+         load_hf_weights(vllm_config=self.vllm_config,
+                         model=model_for_loading,
+                         metadata_map=metadata_map,
+                         mesh=model_for_loading.mesh)
+
+         # TODO: validate that all of the model_params were accounted for as well.
+         nnx.update(model_for_loading, model_params)
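The _loaded_to_standardized_keys map in Llama3WeightLoader uses '*' as a layer-index wildcard that load_hf_weights resolves against checkpoint names such as model.layers.17.self_attn.q_proj. Below is a minimal sketch of that kind of wildcard lookup, with a hypothetical resolve_name helper and a trimmed-down name map (the real resolution lives in tpu_inference.models.jax.utils.weight_utils and also applies the reshape and transpose maps):

    import re

    # Trimmed-down copy of the name map; '*' stands for the layer index.
    NAME_MAP = {
        "model.embed_tokens": "embedder.input_embedding_table_VD",
        "model.layers.*.self_attn.q_proj": "layers.*.attn.kernel_q_proj_DNH",
        "model.layers.*.mlp.gate_proj": "layers.*.custom_module.kernel_gating_DF",
    }

    def resolve_name(hf_key: str):
        # Hypothetical helper: map an HF checkpoint key to the standardized
        # parameter path, carrying the matched layer index across the '*'.
        for pattern, target in NAME_MAP.items():
            regex = "^" + re.escape(pattern).replace(r"\*", r"(\d+)") + "$"
            match = re.match(regex, hf_key)
            if match:
                return (target.replace("*", match.group(1))
                        if match.groups() else target)
        return None

    assert resolve_name("model.layers.17.self_attn.q_proj") == \
        "layers.17.attn.kernel_q_proj_DNH"
    assert resolve_name("model.embed_tokens") == "embedder.input_embedding_table_VD"

The reshape and transpose maps are then applied to each loaded tensor before it is placed: for example, q_proj is reshaped to (attn_heads, -1, hidden_size) and transposed with (2, 0, 1), which matches the kernel_q_proj_DNH layout the model expects.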