tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/utils.py
ADDED
@@ -0,0 +1,345 @@
+# SPDX-License-Identifier: Apache-2.0
+import time
+from collections import defaultdict
+from collections.abc import Sequence
+from functools import wraps
+from typing import Any, Callable, List, Tuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import torch
+from jax._src import dtypes
+from jax._src import mesh as mesh_lib
+from jax._src import xla_bridge as xb
+from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
+from vllm import envs as vllm_envs
+from vllm import utils
+
+from tpu_inference import envs
+from tpu_inference.logger import init_logger
+
+GBYTES = 1024 * 1024 * 1024
+TPU_HEAD_SIZE_ALIGNMENT = 128
+TPU_SECOND_LAST_MINOR = 8
+
+# Map vllm dtype strings that don't exactly match jax dtype string names.
+_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+    "fp8": jnp.float8_e4m3fn.dtype,
+    "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+    "fp8_e5m2": jnp.float8_e5m2.dtype,
+}
+
+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+    if isinstance(dtype, str):
+        if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+            return dict_dtype
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+    else:
+        raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+    # Use the jax dtype as an intermediate dtype, which will then be
+    # converted into the torch dtype.
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+
+
+_megacore = False
+logger = init_logger(__name__)
+
+
+def align_to(unpadded_dim, pad_multiple):
+    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
+def enable_megacore() -> None:
+    global _megacore
+    _megacore = True
+
+
+def get_megacore() -> bool:
+    return _megacore
+
+
+def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
+    if tp_size <= num_kv_heads:
+        assert num_kv_heads % tp_size == 0
+        return num_kv_heads
+    else:
+        assert tp_size % num_kv_heads == 0
+        return tp_size
+
+
+def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
+    usage = []
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+        return pathways_hbm_usage_gb(devices)
+
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
+    if multihost_backend == "ray":
+        # MemoryStats is only supported for addressable PjRt devices.
+        # Assume all the devices have similar memory usage for now.
+        # TODO(ranlihao): find a proper way to get the memory usage of each device.
+        for device in devices:
+            try:
+                hbm_used = device.memory_stats()["bytes_in_use"]
+                hbm_limit = device.memory_stats()["bytes_limit"]
+                logger.info(
+                    "Got memory stats for device %s. Assuming all devices have the same usage.",
+                    device)
+                usage.extend([(hbm_used, hbm_limit)] * len(devices))
+                break
+            except Exception as e:
+                logger.warning(
+                    "Failed to get memory stats for device %s: %s. ", device,
+                    e)
+    else:
+        for device in devices:
+            hbm_used = device.memory_stats()["bytes_in_use"]
+            hbm_limit = device.memory_stats()["bytes_limit"]
+            usage.append((hbm_used, hbm_limit))
+
+    return usage
+
+
+def get_device_name(num_devices: int | None = None):
+    kind = jax.devices()[0].device_kind
+    if 'TPU' not in kind:
+        raise RuntimeError('Expected TPU devices')
+    suffix = ''
+    if kind.endswith(' lite'):
+        kind = kind[:-len(' lite')]
+        suffix = 'e'
+    elif kind.endswith('e'):
+        kind = kind[:-1]
+        suffix = 'e'
+    elif kind.endswith('p'):
+        kind = kind[:-1]
+        suffix = 'p'
+    elif kind == 'TPU7x':
+        kind = 'TPU v7'
+    assert kind[:-1] == 'TPU v', kind
+    kind += suffix
+    if num_devices is not None:
+        kind += f'-{num_devices}'
+    return kind
+
+
+def get_device_hbm_limit() -> int:
+
+    device_kind = get_device_name()
+    if device_kind == "TPU v5p" or device_kind == "TPU v5":
+        return 95 * GBYTES
+    elif device_kind == "TPU v5e":
+        return 16 * GBYTES
+    elif device_kind == "TPU v6e" or device_kind == "TPU v4":
+        return 32 * GBYTES
+    elif device_kind == "TPU v7":
+        # 192 * GBYTES / 2 because each JAX device (v7x core) has
+        # 1/2 of the total chip HBM
+        return 96 * GBYTES
+    else:
+        raise ValueError(f"Unknown device kind: {device_kind}")
+
+
+def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
+    live_arrays = jax.live_arrays()
+    hbm_used = defaultdict(int)
+    hbm_limit = get_device_hbm_limit()
+    for array in live_arrays:
+        for buffer in array.addressable_shards:
+            hbm_used[buffer.data.device] += buffer.data.nbytes
+    return [(hbm_used[device], hbm_limit) for device in devices]
+
+
+def hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
+    usage = hbm_usage_bytes(devices)
+    usage = [(round(used / GBYTES, 2), round(limit / GBYTES, 2))
+             for used, limit in usage]
+    return usage
+
+
+def get_padded_head_dim(head_dim: int) -> int:
+    """Pads head_dim up to the nearest multiple of 128 for kernel performance."""
+    # When head_dim == 64, we use a kernel specifically optimized for it,
+    # which does not require any padding.
+    if head_dim == 64:
+        return 64
+    return (head_dim + 127) // 128 * 128
+
+
+def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
+    if num_heads >= sharding_size:
+        assert num_heads % sharding_size == 0
+    else:
+        assert sharding_size % num_heads == 0
+        num_heads = sharding_size
+    return num_heads
+
+
+def get_dtype_packing(dtype):
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
+    return 32 // bits
+
+
+def make_optimized_mesh(axis_shapes: Sequence[int],
+                        axis_names: Sequence[str],
+                        *,
+                        devices: Sequence[xc.Device] | None = None):
+    if devices is None:
+        devices = xb.devices()
+    # Sort the devices in case they are passed in an arbitrary order.
+    devices = sorted(devices, key=lambda x: x.coords)
+
+    def _is_1D(axis_shapes):
+        return sum(x > 1 for x in axis_shapes) == 1
+
+    if _is_1D(axis_shapes):
+        dev_kind = devices[0].device_kind
+        device_num = len(devices)
+        if dev_kind == "TPU v6 lite":
+            ordered_devices = None
+            # NOTE(chengjiyao):
+            # The coords of v6e-8 are
+            # (0,0,0)
+            # (1,0,0)
+            # (0,1,0)
+            # (1,1,0)
+            # (0,2,0)
+            # (1,2,0)
+            # (0,3,0)
+            # (1,3,0)
+            if device_num == 8:
+                ordered_devices = np.array([
+                    devices[0],
+                    devices[1],
+                    devices[2],
+                    devices[3],
+                    devices[7],
+                    devices[6],
+                    devices[5],
+                    devices[4],
+                ])
+            # NOTE(chengjiyao):
+            # The coords of v6e-4 are
+            # (0,0,0)
+            # (1,0,0)
+            # (0,1,0)
+            # (1,1,0)
+            elif device_num == 4:
+                ordered_devices = np.array([
+                    devices[0],
+                    devices[1],
+                    devices[3],
+                    devices[2],
+                ])
+            if ordered_devices is not None:
+                ordered_devices = np.array(ordered_devices)
+                ordered_devices = ordered_devices.reshape(axis_shapes)
+                mesh = mesh_lib.Mesh(ordered_devices, axis_names)
+                logger.info("Use customized mesh: %s", mesh)
+                return mesh
+
+    return jax.make_mesh(axis_shapes, axis_names, devices=devices)
+
+
+def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
+    """
+    Create a device array with the specified mesh and sharding.
+
+    Args:
+        mesh: The JAX mesh to use for device placement
+        *args: Positional arguments to pass to jax.device_put
+        sharding: Optional sharding specification. If None, uses PartitionSpec(None)
+        **kwargs: Keyword arguments to pass to jax.device_put
+
+    Returns:
+        A JAX array placed on the specified devices
+    """
+    if sharding is None:
+        sharding = NamedSharding(mesh, PartitionSpec(None))
+    return jax.device_put(*args, device=sharding, **kwargs)
+
+
+def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
+    """
+    A wrapper around vllm.utils.hashing.get_hash_fn_by_name that also supports "builtin".
+    """
+    if hash_fn_name == "builtin":
+        return hash
+    return utils.hashing.get_hash_fn_by_name(hash_fn_name)
+
+
+def quantize_kv(key: jax.Array, value: jax.Array,
+                kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
+                v_scale: float) -> Tuple[jax.Array, jax.Array]:
+    """
+    Quantize the key and value tensors.
+
+    Args:
+        key: The key tensor to quantize.
+        value: The value tensor to quantize.
+        kv_cache_quantized_dtype: The dtype to quantize the key and value
+            tensors to.
+        k_scale: The scale to quantize the key tensor by.
+        v_scale: The scale to quantize the value tensor by.
+
+    Returns:
+        Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
+    """
+    dtype_info = jnp.finfo(kv_cache_quantized_dtype)
+    minval, maxval = float(dtype_info.min), float(dtype_info.max)
+    key = key.astype(jnp.float32) / k_scale
+    key = jnp.clip(key, minval, maxval)
+    key = key.astype(kv_cache_quantized_dtype)
+    value = value.astype(jnp.float32) / v_scale
+    value = jnp.clip(value, minval, maxval)
+    value = value.astype(kv_cache_quantized_dtype)
+
+    return key, value
+
+
+def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
+    """
+    Get the JAX dtype from a string dtype.
+
+    Args:
+        str_dtype: The string dtype to get the JAX dtype from.
+
+    Returns:
+        jnp.dtype: The JAX dtype.
+    """
+    # TODO(kyuyeunk): Replace all references of this function with TpuDtype.
+    return to_jax_dtype(str_dtype)
+
+
+def time_function(func):
+    """
+    A decorator to measure the execution time of a function.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        start_time = time.perf_counter()
+        result = func(*args, **kwargs)
+        end_time = time.perf_counter()
+        execution_time = end_time - start_time
+        logger.debug(
+            f"Function '{func.__name__}' executed in {execution_time:.4f} seconds."
+        )
+        return result
+
+    return wrapper
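For orientation, here is a minimal usage sketch of a few of the helpers above. It is illustrative only and not part of the wheel: it assumes tpu_inference and its import-time dependencies (jax, torch, torchax, vllm) are installed, and it exercises only the dtype, padding, and KV-quantization helpers, which do not need TPU hardware (unlike the HBM and mesh utilities).

import jax.numpy as jnp
import torch

from tpu_inference.utils import (align_to, get_padded_head_dim, quantize_kv,
                                 to_jax_dtype, to_torch_dtype)

# dtype normalization: vllm-style strings, torch dtypes, and jnp dtypes all
# funnel through to_jax_dtype.
print(to_jax_dtype("fp8"))           # float8_e4m3fn, via the alias table
print(to_jax_dtype(torch.bfloat16))  # bfloat16, via torchax's t2j_dtype
print(to_torch_dtype("float32"))     # torch.float32, round-trips through jax

# Padding helpers round dimensions up to TPU-friendly multiples.
assert align_to(100, 128) == 128       # ceil(100 / 128) * 128
assert get_padded_head_dim(96) == 128  # padded up to the 128-lane width
assert get_padded_head_dim(64) == 64   # the hd64 kernel needs no padding

# quantize_kv divides by the per-tensor scale, clips to the target dtype's
# representable range, then casts.
key = jnp.full((2, 4), 2.0, dtype=jnp.bfloat16)
value = jnp.full((2, 4), 3.0, dtype=jnp.bfloat16)
qk, qv = quantize_kv(key, value, jnp.float8_e4m3fn.dtype,
                     k_scale=0.5, v_scale=1.5)
assert qk.dtype == jnp.float8_e4m3fn.dtype and qv.dtype == qk.dtype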
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.