tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/utils.py ADDED
@@ -0,0 +1,345 @@
+ # SPDX-License-Identifier: Apache-2.0
+ import time
+ from collections import defaultdict
+ from collections.abc import Sequence
+ from functools import wraps
+ from typing import Any, Callable, List, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import torch
+ from jax._src import dtypes
+ from jax._src import mesh as mesh_lib
+ from jax._src import xla_bridge as xb
+ from jax._src.lib import xla_client as xc
+ from jax._src.numpy.scalar_types import _ScalarMeta
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torchax.ops.mappings import j2t_dtype, t2j_dtype
+ from vllm import envs as vllm_envs
+ from vllm import utils
+
+ from tpu_inference import envs
+ from tpu_inference.logger import init_logger
+
+ GBYTES = 1024 * 1024 * 1024
+ TPU_HEAD_SIZE_ALIGNMENT = 128
+ TPU_SECOND_LAST_MINOR = 8
+
+ # Map vLLM dtype strings that do not exactly match a jax dtype string name.
+ _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+     "fp8": jnp.float8_e4m3fn.dtype,
+     "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+     "fp8_e5m2": jnp.float8_e5m2.dtype,
+ }
+
+
+ def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+     if isinstance(dtype, str):
+         if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+             return dict_dtype
+         return jnp.dtype(dtype)
+     elif isinstance(dtype, torch.dtype):
+         return t2j_dtype(dtype)
+     elif isinstance(dtype, jnp.dtype):
+         return dtype
+     elif isinstance(dtype, _ScalarMeta):
+         return dtype.dtype
+     else:
+         raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+ def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+     # Use the jax dtype as an intermediate representation, which is then
+     # converted into the torch dtype.
+     dtype = to_jax_dtype(dtype)
+     return j2t_dtype(dtype)
+
+
+ _megacore = False
+ logger = init_logger(__name__)
+
+
+ def align_to(unpadded_dim, pad_multiple):
+     return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
+ def enable_megacore() -> None:
+     global _megacore
+     _megacore = True
+
+
+ def get_megacore() -> bool:
+     return _megacore
+
+
+ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
+     if tp_size <= num_kv_heads:
+         assert num_kv_heads % tp_size == 0
+         return num_kv_heads
+     else:
+         assert tp_size % num_kv_heads == 0
+         return tp_size
+
+
+ def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
+     usage = []
+     if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+         return pathways_hbm_usage_gb(devices)
+
+     multihost_backend = envs.TPU_MULTIHOST_BACKEND
+     if multihost_backend == "ray":
+         # MemoryStats is only supported for addressable PjRt devices.
+         # Assume all the devices have similar memory usage for now.
+         # TODO(ranlihao): find a proper way to get the memory usage of each device.
+         for device in devices:
+             try:
+                 hbm_used = device.memory_stats()["bytes_in_use"]
+                 hbm_limit = device.memory_stats()["bytes_limit"]
+                 logger.info(
+                     "Get memory stats for device %s. Assuming all devices have the same usage.",
+                     device)
+                 usage.extend([(hbm_used, hbm_limit)] * len(devices))
+                 break
+             except Exception as e:
+                 logger.warning(
+                     "Failed to get memory stats for device %s: %s. ", device,
+                     e)
+     else:
+         for device in devices:
+             hbm_used = device.memory_stats()["bytes_in_use"]
+             hbm_limit = device.memory_stats()["bytes_limit"]
+             usage.append((hbm_used, hbm_limit))
+
+     return usage
+
+
+ def get_device_name(num_devices: int | None = None):
+     kind = jax.devices()[0].device_kind
+     if 'TPU' not in kind:
+         raise RuntimeError('Expected TPU devices')
+     suffix = ''
+     if kind.endswith(' lite'):
+         kind = kind[:-len(' lite')]
+         suffix = 'e'
+     elif kind.endswith('e'):
+         kind = kind[:-1]
+         suffix = 'e'
+     elif kind.endswith('p'):
+         kind = kind[:-1]
+         suffix = 'p'
+     elif kind == 'TPU7x':
+         kind = 'TPU v7'
+     assert kind[:-1] == 'TPU v', kind
+     kind += suffix
+     if num_devices is not None:
+         kind += f'-{num_devices}'
+     return kind
+
+
+ def get_device_hbm_limit() -> int:
+     device_kind = get_device_name()
+     if device_kind == "TPU v5p" or device_kind == "TPU v5":
+         return 95 * GBYTES
+     elif device_kind == "TPU v5e":
+         return 16 * GBYTES
+     elif device_kind == "TPU v6e" or device_kind == "TPU v4":
+         return 32 * GBYTES
+     elif device_kind == "TPU v7":
+         # 192 * GBYTES / 2 because each JAX device (v7x core) has
+         # 1/2 of the total chip HBM
+         return 96 * GBYTES
+     else:
+         raise ValueError(f"Unknown device kind: {device_kind}")
+
+
+ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
+     live_arrays = jax.live_arrays()
+     hbm_used = defaultdict(int)
+     hbm_limit = get_device_hbm_limit()
+     for array in live_arrays:
+         for buffer in array.addressable_shards:
+             hbm_used[buffer.data.device] += buffer.data.nbytes
+     return [(hbm_used[device], hbm_limit) for device in devices]
+
+
+ def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
+     usage = hbm_usage_bytes(devices)
+     usage = [(round(used / GBYTES, 2), round(limit / GBYTES, 2))
+              for used, limit in usage]
+     return usage
+
+
+ def get_padded_head_dim(head_dim: int) -> int:
+     """Pads head_dim up to the nearest multiple of 128 for kernel performance."""
+     # When head_dim == 64, we use a kernel specifically optimized for it, which
+     # does not require any padding.
+     if head_dim == 64:
+         return 64
+     return (head_dim + 127) // 128 * 128
+
+
+ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
+     if num_heads >= sharding_size:
+         assert num_heads % sharding_size == 0
+     else:
+         assert sharding_size % num_heads == 0
+         num_heads = sharding_size
+     return num_heads
+
+
+ def get_dtype_packing(dtype):
+     bits = (dtypes.bit_width(dtype)
+             if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
+     return 32 // bits
+
+
+ def make_optimized_mesh(axis_shapes: Sequence[int],
+                         axis_names: Sequence[str],
+                         *,
+                         devices: Sequence[xc.Device] | None = None):
+     if devices is None:
+         devices = xb.devices()
+     # Sort the devices in case they are passed in an arbitrary order.
+     devices = sorted(devices, key=lambda x: x.coords)
+
+     def _is_1D(axis_shapes):
+         return sum(x > 1 for x in axis_shapes) == 1
+
+     if _is_1D(axis_shapes):
+         dev_kind = devices[0].device_kind
+         device_num = len(devices)
+         if dev_kind == "TPU v6 lite":
+             ordered_devices = None
+             # NOTE(chengjiyao):
+             # The coords of v6e-8 are
+             # (0,0,0)
+             # (1,0,0)
+             # (0,1,0)
+             # (1,1,0)
+             # (0,2,0)
+             # (1,2,0)
+             # (0,3,0)
+             # (1,3,0)
+             if device_num == 8:
+                 ordered_devices = np.array([
+                     devices[0],
+                     devices[1],
+                     devices[2],
+                     devices[3],
+                     devices[7],
+                     devices[6],
+                     devices[5],
+                     devices[4],
+                 ])
+             # NOTE(chengjiyao):
+             # The coords of v6e-4 are
+             # (0,0,0)
+             # (1,0,0)
+             # (0,1,0)
+             # (1,1,0)
+             elif device_num == 4:
+                 ordered_devices = np.array([
+                     devices[0],
+                     devices[1],
+                     devices[3],
+                     devices[2],
+                 ])
+             if ordered_devices is not None:
+                 ordered_devices = np.array(ordered_devices)
+                 ordered_devices = ordered_devices.reshape(axis_shapes)
+                 mesh = mesh_lib.Mesh(ordered_devices, axis_names)
+                 logger.info("Use customized mesh: %s", mesh)
+                 return mesh
+
+     return jax.make_mesh(axis_shapes, axis_names, devices=devices)
+
+
+ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
+     """
+     Create a device array with the specified mesh and sharding.
+
+     Args:
+         mesh: The JAX mesh to use for device placement.
+         *args: Positional arguments to pass to jax.device_put.
+         sharding: Optional sharding specification. If None, uses PartitionSpec(None).
+         **kwargs: Keyword arguments to pass to jax.device_put.
+
+     Returns:
+         A JAX array placed on the specified devices.
+     """
+     if sharding is None:
+         sharding = NamedSharding(mesh, PartitionSpec(None))
+     return jax.device_put(*args, device=sharding, **kwargs)
+
+
+ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
+     """
+     A wrapper around vllm.utils.hashing.get_hash_fn_by_name that also supports the builtin hash.
+     """
+     if hash_fn_name == "builtin":
+         return hash
+     return utils.hashing.get_hash_fn_by_name(hash_fn_name)
+
+
+ def quantize_kv(key: jax.Array, value: jax.Array,
+                 kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
+                 v_scale: float) -> Tuple[jax.Array, jax.Array]:
+     """
+     Quantize the key and value tensors.
+
+     Args:
+         key: The key tensor to quantize.
+         value: The value tensor to quantize.
+         kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
+         k_scale: The scale to quantize the key tensor by.
+         v_scale: The scale to quantize the value tensor by.
+
+     Returns:
+         Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
+     """
+     dtype_info = jnp.finfo(kv_cache_quantized_dtype)
+     minval, maxval = float(dtype_info.min), float(dtype_info.max)
+     key = key.astype(jnp.float32) / k_scale
+     key = jnp.clip(key, minval, maxval)
+     key = key.astype(kv_cache_quantized_dtype)
+     value = value.astype(jnp.float32) / v_scale
+     value = jnp.clip(value, minval, maxval)
+     value = value.astype(kv_cache_quantized_dtype)
+
+     return key, value
+
+
+ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
+     """
+     Get the JAX dtype from a string dtype.
+
+     Args:
+         str_dtype: The string dtype to get the JAX dtype from.
+
+     Returns:
+         jnp.dtype: The JAX dtype.
+     """
+     # TODO(kyuyeunk): Replace all references to this function with TpuDtype.
+     return to_jax_dtype(str_dtype)
+
+
+ def time_function(func):
+     """
+     A decorator to measure the execution time of a function.
+     """
+
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         start_time = time.perf_counter()
+         result = func(*args, **kwargs)
+         end_time = time.perf_counter()
+         execution_time = end_time - start_time
+         logger.debug(
+             f"Function '{func.__name__}' executed in {execution_time:.4f} seconds."
+         )
+         return result
+
+     return wrapper
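
For readers skimming the new module, here is a minimal, standalone sketch (not part of the wheel) of the per-tensor quantization math that quantize_kv applies above: scale, clip to the target dtype's finite range, then cast. The helper name quantize_one is hypothetical; only jax is needed to run it.

    import jax.numpy as jnp

    def quantize_one(x, scale, dtype=jnp.float8_e4m3fn):
        # Mirrors quantize_kv for a single tensor: divide by the per-tensor
        # scale, clip to the dtype's finite range, then cast down.
        info = jnp.finfo(dtype)
        x = x.astype(jnp.float32) / scale
        x = jnp.clip(x, float(info.min), float(info.max))
        return x.astype(dtype)

    key = jnp.full((2, 4), 3.0, dtype=jnp.bfloat16)
    q_key = quantize_one(key, scale=0.5)         # stored fp8 values ~= 6.0
    deq_key = q_key.astype(jnp.float32) * 0.5    # approximate dequantization ~= 3.0

The padding helpers follow the same round-up rule: align_to(100, 128) == 128 and get_padded_head_dim(96) == 128, while get_padded_head_dim(64) returns 64 because the dedicated head-dim-64 kernel needs no padding.
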
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.