tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,112 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ from vllm.utils.network_utils import get_ip
18
+
19
+ from tpu_inference import envs
20
+ from tpu_inference.logger import init_logger
21
+
22
logger = init_logger(__name__)

# Multi-host only: node id -> (kv_ip, kv_port) for every node, filled in by
# set_node_kv_ip_port() and read back in node-id order by the KV getters.
_NODES_KV_IP_PORT: dict[int, tuple[str, int]] = {}
26
+
27
+
28
def set_node_kv_ip_port(ip_port: tuple[int, str, int]):
    """Register the KV-transfer address advertised by one node (multi-host).

    Args:
        ip_port: A ``(node_id, ip, port)`` triple. Registering the same
            ``node_id`` again overwrites its previously stored address.
    """
    node_id, ip, port = ip_port
    # The module-level dict is mutated in place, so no `global` is required.
    _NODES_KV_IP_PORT.update({node_id: (ip, port)})
32
+
33
+
34
def _collect_node_kv_field(field: int) -> list:
    """Return item ``field`` of every registered node's (ip, port), by node id.

    Node ids are looked up as ``0 .. N-1`` where ``N`` is the number of
    registered nodes, so a gap in the registered ids raises ``KeyError``.
    """
    return [
        _NODES_KV_IP_PORT[node_id][field]
        for node_id in range(len(_NODES_KV_IP_PORT))
    ]


def get_kv_ips() -> str | list[str]:
    """Return the IP address(es) used for KV-cache transfer.

    Returns:
        With the Ray multi-host backend (``envs.TPU_MULTIHOST_BACKEND ==
        "ray"``), a list holding one IP per node, ordered by node id, taken
        from the addresses registered via ``set_node_kv_ip_port``.
        Otherwise, this host's IP as a single string (``get_host_ip``).
    """
    if envs.TPU_MULTIHOST_BACKEND == "ray":
        return _collect_node_kv_field(0)
    return get_host_ip()


def get_kv_ports() -> str | list[int]:
    """Return the port(s) used for KV-cache transfer.

    Returns:
        With the Ray multi-host backend, a list holding one port per node,
        ordered by node id, taken from ``set_node_kv_ip_port``
        registrations. Otherwise, the local KV-transfer port as the raw
        string returned by ``get_kv_transfer_port``.
    """
    if envs.TPU_MULTIHOST_BACKEND == "ray":
        return _collect_node_kv_field(1)
    return get_kv_transfer_port()
54
+
55
+
56
def get_host_ip() -> str:
    """Return this host's IP address.

    Honors ``VLLM_HOST_IP`` when it is set; otherwise falls back to the IP of
    the default network interface, as resolved by vLLM's ``get_ip``.
    """
    host_ip = get_ip()
    return host_ip
59
+
60
+
61
def get_kv_transfer_port() -> str:
    """Return the KV-transfer port from ``TPU_KV_TRANSFER_PORT``.

    Defaults to ``"9100"`` when the variable is unset. The value is returned
    as the raw environment string; it is neither converted to int nor
    validated.
    """
    return os.getenv("TPU_KV_TRANSFER_PORT", "9100")
64
+
65
+
66
def get_side_channel_port() -> str:
    """Return the side-channel port from ``TPU_SIDE_CHANNEL_PORT``.

    Defaults to ``"9600"`` when the variable is unset. The value is returned
    as the raw environment string; it is neither converted to int nor
    validated.
    """
    return os.getenv("TPU_SIDE_CHANNEL_PORT", "9600")
69
+
70
+
71
def get_device_topology_order_id(local_devices, global_devices) -> int:
    """
    Calculates the topology order ID for the local device set within the global topology.

    This function determines the rank of the current host/process based on the
    coordinate of its TPU devices relative to all devices in the topology.

    Each process is represented by its "anchor": the minimum ``coords`` value
    among its devices. Anchors are sorted, and the position of the local
    anchor in that order is the returned rank.

    Args:
        local_devices: A list of TpuDevice objects available to the current process.
            Each must expose a comparable ``coords`` attribute.
        global_devices: A list of all TpuDevice objects in the global topology.
            Each must expose ``coords`` and ``process_index``.

    Returns:
        The topology order ID (rank) of the local devices.

    Raises:
        ValueError: if either list is empty, or if the local anchor is not the
            anchor of any process in ``global_devices``.
    """
    if not local_devices:
        raise ValueError("local_devices cannot be empty")
    if not global_devices:
        raise ValueError("global_devices cannot be empty")

    # 1. Find the 'anchor' (minimum coordinate) for the local devices.
    # This represents the physical top-left corner of the local machine.
    local_anchor = min(d.coords for d in local_devices)

    # 2. Group global devices by process to find the anchor for EVERY process.
    process_anchors = {}
    for d in global_devices:
        pid = d.process_index
        # Keep the smallest coordinate seen so far for this process.
        if pid not in process_anchors or d.coords < process_anchors[pid]:
            process_anchors[pid] = d.coords

    # 3. Sort the anchors to establish the canonical topology order.
    # Coordinates compare lexicographically (first axis, then second, ...).
    # NOTE(review): if two processes share the same anchor (e.g. coords that
    # are only unique within a slice), both would get the same rank here —
    # confirm callers never mix such devices.
    sorted_anchors = sorted(process_anchors.values())

    # 4. Return the index (rank) of the local anchor in the sorted list.
    try:
        return sorted_anchors.index(local_anchor)
    except ValueError:
        # Suppress the internal list.index() error so callers see one clear
        # exception instead of a "during handling of the above" chain.
        raise ValueError(
            f"Local devices: {local_devices} do not exist in the global device: {global_devices} list."
        ) from None
@@ -0,0 +1,9 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
3
+
4
+ import os
5
+
6
# Turn off vLLM's CUDA-specific shared-experts stream when running on TPU.
# Without this, vLLM may attempt to create CUDA streams, which fails on TPU
# hardware; the behavior was introduced by vllm-project/vllm#26440.
# The override is unconditional and replaces any value already set.
os.environ.update({"VLLM_DISABLE_SHARED_EXPERTS_STREAM": "1"})
tpu_inference/envs.py ADDED
@@ -0,0 +1,191 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
3
+
4
+ import functools
5
+ import os
6
+ from collections.abc import Callable
7
+ from typing import TYPE_CHECKING, Any
8
+
9
# Static declarations so type checkers and IDEs know the names that the
# module-level __getattr__ resolves lazily at runtime. This branch does not
# execute at runtime (TYPE_CHECKING is False), so the values below are only
# placeholders; the real getters live in `environment_variables`. Keep the two
# lists in sync when adding a variable.
if TYPE_CHECKING:
    JAX_PLATFORMS: str = ""
    TPU_ACCELERATOR_TYPE: str | None = None
    TPU_NAME: str | None = None
    TPU_WORKER_ID: str | None = None
    TPU_MULTIHOST_BACKEND: str = ""
    PREFILL_SLICES: str = ""
    DECODE_SLICES: str = ""
    SKIP_JAX_PRECOMPILE: bool = False
    VLLM_XLA_CHECK_RECOMPILATION: bool = False
    MODEL_IMPL_TYPE: str = "auto"
    NEW_MODEL_DESIGN: bool = False
    PHASED_PROFILING_DIR: str = ""
    PYTHON_TRACER_LEVEL: int = 1
    USE_MOE_EP_KERNEL: bool = False
    NUM_SLICES: int = 1
    RAY_USAGE_STATS_ENABLED: str = "0"
    VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
    ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
28
+
29
+
30
+ def env_with_choices(
31
+ env_name: str,
32
+ default: str | None,
33
+ choices: list[str] | Callable[[], list[str]],
34
+ case_sensitive: bool = True,
35
+ ) -> Callable[[], str | None]:
36
+ """
37
+ Create a lambda that validates environment variable against allowed choices
38
+
39
+ Args:
40
+ env_name: Name of the environment variable
41
+ default: Default value if not set (can be None)
42
+ choices: List of valid string options or callable that returns list
43
+ case_sensitive: Whether validation should be case sensitive
44
+
45
+ Returns:
46
+ Lambda function for environment_variables dict
47
+ """
48
+
49
+ def _get_validated_env() -> str | None:
50
+ value = os.getenv(env_name)
51
+ if value is None:
52
+ return default
53
+
54
+ # Resolve choices if it's a callable (for lazy loading)
55
+ actual_choices = choices() if callable(choices) else choices
56
+
57
+ if not case_sensitive:
58
+ check_value = value.lower()
59
+ check_choices = [choice.lower() for choice in actual_choices]
60
+ else:
61
+ check_value = value
62
+ check_choices = actual_choices
63
+
64
+ if check_value not in check_choices:
65
+ raise ValueError(f"Invalid value '{value}' for {env_name}. "
66
+ f"Valid options: {actual_choices}.")
67
+
68
+ return value
69
+
70
+ return _get_validated_env
71
+
72
+
73
def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
    """
    Create a getter that parses an environment variable as a boolean.

    Recognized values (case-insensitive): "1"/"true" -> True and
    "0"/"false" -> False. An unset or empty variable yields ``default``.

    Args:
        env_name: Name of the environment variable
        default: Boolean returned when the variable is unset or empty

    Returns:
        Zero-argument function that raises ValueError for any other value.
    """
    truthy = ("true", "1")
    falsy = ("false", "0")

    def _get_bool_env() -> bool:
        raw = os.getenv(env_name)
        # None and "" are both "not configured".
        if not raw:
            return default

        normalized = raw.lower()
        if normalized in truthy:
            return True
        if normalized in falsy:
            return False
        raise ValueError(
            f"Invalid boolean value '{raw}' for {env_name}. "
            f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")

    return _get_bool_env
99
+
100
+
101
# Registry of every supported variable: name -> zero-argument getter that
# reads (and where applicable validates/parses) the current os.environ value.
environment_variables: dict[str, Callable[[], Any]] = {
    # JAX platform selection (e.g., "tpu", "cpu", "proxy"); lowercased
    "JAX_PLATFORMS":
    lambda: os.getenv("JAX_PLATFORMS", "").lower(),
    # TPU accelerator type (e.g., "v5litepod-16", "v4-8")
    "TPU_ACCELERATOR_TYPE":
    lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
    # Name of the TPU resource
    "TPU_NAME":
    lambda: os.getenv("TPU_NAME", None),
    # Worker ID for multi-host TPU setups
    "TPU_WORKER_ID":
    lambda: os.getenv("TPU_WORKER_ID", None),
    # Backend for multi-host communication on TPU ("" or "ray")
    "TPU_MULTIHOST_BACKEND":
    env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
    # Slice configuration for disaggregated prefill workers
    "PREFILL_SLICES":
    lambda: os.getenv("PREFILL_SLICES", ""),
    # Slice configuration for disaggregated decode workers
    "DECODE_SLICES":
    lambda: os.getenv("DECODE_SLICES", ""),
    # Skip JAX precompilation step during initialization
    "SKIP_JAX_PRECOMPILE":
    env_bool("SKIP_JAX_PRECOMPILE", default=False),
    # Check for XLA recompilation during execution
    "VLLM_XLA_CHECK_RECOMPILATION":
    env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
    # Model implementation type: "auto", "vllm", "flax_nnx" or "jetpack"
    "MODEL_IMPL_TYPE":
    env_with_choices("MODEL_IMPL_TYPE", "auto",
                     ["auto", "vllm", "flax_nnx", "jetpack"]),
    # Enable new experimental model design
    "NEW_MODEL_DESIGN":
    env_bool("NEW_MODEL_DESIGN", default=False),
    # Directory to store phased profiling output
    "PHASED_PROFILING_DIR":
    lambda: os.getenv("PHASED_PROFILING_DIR", ""),
    # Python tracer level for profiling (unset/empty -> 1)
    "PYTHON_TRACER_LEVEL":
    lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
    # Use custom expert-parallel kernel for MoE (Mixture of Experts)
    "USE_MOE_EP_KERNEL":
    env_bool("USE_MOE_EP_KERNEL", default=False),
    # Number of TPU slices for multi-slice mesh (unset/empty -> 1)
    "NUM_SLICES":
    lambda: int(os.getenv("NUM_SLICES") or "1"),
    # Enable/disable Ray usage statistics collection
    "RAY_USAGE_STATS_ENABLED":
    lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
    # Ray compiled DAG channel type for TPU (only "shm" is accepted)
    "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
    env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
    # Enable the quantized matmul kernel. Parsed like every other boolean flag
    # ("0"/"1"/"true"/"false", case-insensitive; unset/empty -> False) rather
    # than via int(), which rejected "true" with an opaque int-parse error.
    "ENABLE_QUANTIZED_MATMUL_KERNEL":
    env_bool("ENABLE_QUANTIZED_MATMUL_KERNEL", default=False),
}
157
+
158
+
159
def __getattr__(name: str) -> Any:
    """
    Resolve a module attribute by reading its environment variable on demand.

    Each access re-runs the registered getter, so changes to os.environ are
    observed, until enable_envs_cache() rebinds this name to a cached wrapper.

    Raises:
        AttributeError: if ``name`` is not a registered environment variable.
    """
    getter = environment_variables.get(name)
    if getter is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getter()
169
+
170
+
171
def enable_envs_cache() -> None:
    """
    Freeze environment-variable lookups by memoizing the module __getattr__.

    Rebinds the module-level ``__getattr__`` to a ``functools.cache`` wrapper
    and then resolves every registered variable once, so the cache is fully
    populated. Any getter that raises (e.g. on an invalid value) propagates
    out of this call.

    NOTE: Call this after service initialization. Afterwards, values no longer
    reflect changes to os.environ until the process is restarted.
    """
    global __getattr__
    cached_getattr = functools.cache(__getattr__)
    __getattr__ = cached_getattr

    # Warm the cache for every known variable.
    for var_name in environment_variables:
        cached_getattr(var_name)
188
+
189
+
190
def __dir__() -> list[str]:
    """Expose the lazily-resolved variable names to dir() and autocompletion."""
    return [*environment_variables]
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.