tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/test_envs.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the tpu-inference project
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
|
|
6
|
+
import tpu_inference.envs as envs
|
|
7
|
+
from tpu_inference.envs import enable_envs_cache, environment_variables
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_getattr_without_cache(monkeypatch: pytest.MonkeyPatch):
    """Check that uncached ``envs`` lookups follow live environment changes.

    The variables under test are removed from the environment first so that
    the default-value assertions do not depend on whatever the host running
    the suite happens to export.

    Args:
        monkeypatch: pytest fixture used to edit ``os.environ``; every change
            is reverted automatically when the test finishes.
    """
    for name in ("JAX_PLATFORMS", "PHASED_PROFILING_DIR", "TPU_NAME",
                 "TPU_ACCELERATOR_TYPE"):
        monkeypatch.delenv(name, raising=False)

    # String-valued settings: unset reads as "", then tracks the new value.
    assert envs.JAX_PLATFORMS == ""
    assert envs.PHASED_PROFILING_DIR == ""
    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
    monkeypatch.setenv("PHASED_PROFILING_DIR", "/tmp/profiling")
    assert envs.JAX_PLATFORMS == "tpu"
    assert envs.PHASED_PROFILING_DIR == "/tmp/profiling"

    # Settings whose unset value is None.
    assert envs.TPU_NAME is None
    assert envs.TPU_ACCELERATOR_TYPE is None
    monkeypatch.setenv("TPU_NAME", "my-tpu")
    monkeypatch.setenv("TPU_ACCELERATOR_TYPE", "v5litepod-16")
    assert envs.TPU_NAME == "my-tpu"
    assert envs.TPU_ACCELERATOR_TYPE == "v5litepod-16"

    # __getattr__ is not decorated with functools.cache
    assert not hasattr(envs.__getattr__, "cache_info")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
    """Check that ``enable_envs_cache`` wraps ``envs.__getattr__`` in a cache.

    The uncached ``__getattr__`` is restored in a ``finally`` clause so that a
    failing assertion cannot leave the cached accessor installed for the
    tests that run afterwards.

    Args:
        monkeypatch: pytest fixture used to set the environment variables
            read by this test.
    """
    monkeypatch.setenv("JAX_PLATFORMS", "tpu")
    monkeypatch.setenv("TPU_NAME", "my-tpu")

    # __getattr__ is not decorated with functools.cache
    assert not hasattr(envs.__getattr__, "cache_info")

    enable_envs_cache()
    try:
        # __getattr__ is decorated with functools.cache
        assert hasattr(envs.__getattr__, "cache_info")
        start_hits = envs.__getattr__.cache_info().hits

        # 2 more hits due to JAX_PLATFORMS and TPU_NAME accesses
        assert envs.JAX_PLATFORMS == "tpu"
        assert envs.TPU_NAME == "my-tpu"
        assert envs.__getattr__.cache_info().hits == start_hits + 2

        # All environment variables are cached
        for environment_variable in environment_variables:
            envs.__getattr__(environment_variable)
        assert envs.__getattr__.cache_info().hits == (
            start_hits + 2 + len(environment_variables))
    finally:
        # Reset envs.__getattr__ back to the non-cached version to avoid
        # affecting other tests. Fall back to the current accessor if it was
        # never wrapped, so cleanup does not mask the real failure.
        envs.__getattr__ = getattr(envs.__getattr__, "__wrapped__",
                                   envs.__getattr__)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
    """Check that "0" and "1" are read back as ``False`` and ``True``.

    Args:
        monkeypatch: pytest fixture used to set the environment variables.
    """
    # (variable name, whether the value is also flipped back to "0")
    cases = (
        ("SKIP_JAX_PRECOMPILE", True),
        ("VLLM_XLA_CHECK_RECOMPILATION", True),
        ("NEW_MODEL_DESIGN", False),
        ("USE_MOE_EP_KERNEL", False),
        ("ENABLE_QUANTIZED_MATMUL_KERNEL", False),
    )

    # Ensure clean environment for boolean vars by setting to default "0"
    for name, _ in cases:
        monkeypatch.setenv(name, "0")

    for name, flip_back in cases:
        assert getattr(envs, name) is False
        monkeypatch.setenv(name, "1")
        assert getattr(envs, name) is True
        if flip_back:
            monkeypatch.setenv(name, "0")
            assert getattr(envs, name) is False
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
    """Test that boolean env vars accept string values like 'True' and 'False'"""
    # (variable name, raw environment value, expected parsed value)
    cases = (
        ("NEW_MODEL_DESIGN", "True", True),
        ("NEW_MODEL_DESIGN", "true", True),
        ("NEW_MODEL_DESIGN", "False", False),
        ("NEW_MODEL_DESIGN", "false", False),
        ("SKIP_JAX_PRECOMPILE", "True", True),
        ("SKIP_JAX_PRECOMPILE", "false", False),
        ("VLLM_XLA_CHECK_RECOMPILATION", "TRUE", True),
        ("VLLM_XLA_CHECK_RECOMPILATION", "FALSE", False),
        ("USE_MOE_EP_KERNEL", "true", True),
        ("USE_MOE_EP_KERNEL", "False", False),
    )

    for name, raw, expected in cases:
        monkeypatch.setenv(name, raw)
        assert getattr(envs, name) is expected
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
    """Test that boolean env vars raise errors for invalid values"""
    # (variable name, rejected value, expected error-message pattern)
    cases = (
        ("NEW_MODEL_DESIGN", "yes",
         "Invalid boolean value 'yes' for NEW_MODEL_DESIGN"),
        ("NEW_MODEL_DESIGN", "2",
         "Invalid boolean value '2' for NEW_MODEL_DESIGN"),
        ("SKIP_JAX_PRECOMPILE", "invalid",
         "Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"),
    )

    for name, bad_value, pattern in cases:
        monkeypatch.setenv(name, bad_value)
        # The error is expected when the attribute is read.
        with pytest.raises(ValueError, match=pattern):
            getattr(envs, name)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
    """Test that empty string returns default value"""
    for name in ("NEW_MODEL_DESIGN", "SKIP_JAX_PRECOMPILE"):
        monkeypatch.setenv(name, "")
        # Should return default
        assert getattr(envs, name) is False
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
    """Check that integer-valued settings are read back as ``int``.

    Args:
        monkeypatch: pytest fixture used to set the environment variables.
    """
    # Ensure clean environment for integer vars by setting to defaults
    monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
    monkeypatch.setenv("NUM_SLICES", "1")

    # (variable name, values assigned in turn after the initial 1)
    sequences = (
        ("PYTHON_TRACER_LEVEL", (3, 0)),
        ("NUM_SLICES", (2, 4)),
    )

    for name, later_values in sequences:
        assert getattr(envs, name) == 1
        for value in later_values:
            monkeypatch.setenv(name, str(value))
            assert getattr(envs, name) == value
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
    """Check that each exercised ``MODEL_IMPL_TYPE`` value is read back as-is."""
    # Test case sensitive choices
    for choice in ("flax_nnx", "vllm"):
        monkeypatch.setenv("MODEL_IMPL_TYPE", choice)
        assert envs.MODEL_IMPL_TYPE == choice
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
    """Check that unset string settings read back as the empty string.

    Every variable asserted below is removed from the environment first,
    including ``PHASED_PROFILING_DIR``, so the result does not depend on the
    environment the suite is launched from.

    Args:
        monkeypatch: pytest fixture used to remove the environment variables.
    """
    names = ("JAX_PLATFORMS", "PREFILL_SLICES", "DECODE_SLICES",
             "PHASED_PROFILING_DIR")
    for name in names:
        monkeypatch.delenv(name, raising=False)

    assert envs.JAX_PLATFORMS == ""
    assert envs.PREFILL_SLICES == ""
    assert envs.DECODE_SLICES == ""
    assert envs.PHASED_PROFILING_DIR == ""
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def test_none_default_env_vars(monkeypatch: pytest.MonkeyPatch):
    """Check that the TPU identity settings read back as ``None`` when unset."""
    names = ("TPU_ACCELERATOR_TYPE", "TPU_NAME", "TPU_WORKER_ID")

    # Remove all of them before reading any.
    for name in names:
        monkeypatch.delenv(name, raising=False)

    for name in names:
        assert getattr(envs, name) is None
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
    """Check the Ray-related settings exposed through ``envs``.

    Both variables are removed from the environment before their unset values
    are asserted, so the test does not depend on the environment the suite is
    launched from.

    Args:
        monkeypatch: pytest fixture used to edit the environment variables.
    """
    monkeypatch.delenv("RAY_USAGE_STATS_ENABLED", raising=False)
    monkeypatch.delenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE",
                       raising=False)

    assert envs.RAY_USAGE_STATS_ENABLED == "0"
    monkeypatch.setenv("RAY_USAGE_STATS_ENABLED", "1")
    assert envs.RAY_USAGE_STATS_ENABLED == "1"

    assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def test_invalid_attribute_raises_error():
    """Check that reading an unknown name from ``envs`` raises AttributeError."""
    expected_message = "has no attribute 'NONEXISTENT_VAR'"
    with pytest.raises(AttributeError, match=expected_message):
        _ = envs.NONEXISTENT_VAR
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def test_dir_returns_all_env_vars():
    """Check that ``envs.__dir__()`` lists one entry per known variable."""
    listed = envs.__dir__()

    assert isinstance(listed, list)
    assert len(listed) == len(environment_variables)

    # Spot-check a handful of names.
    for name in ("JAX_PLATFORMS", "TPU_NAME", "SKIP_JAX_PRECOMPILE",
                 "VLLM_XLA_CHECK_RECOMPILATION", "MODEL_IMPL_TYPE"):
        assert name in listed
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
    """Check that the multi-host settings are read back as the strings set."""
    for name, value in (("TPU_WORKER_ID", "0"),
                        ("TPU_MULTIHOST_BACKEND", "ray")):
        monkeypatch.setenv(name, value)
        assert getattr(envs, name) == value
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
    """Check that the slice-list settings are returned as the raw strings set."""
    for name, value in (("PREFILL_SLICES", "0,1,2,3"),
                        ("DECODE_SLICES", "4,5,6,7")):
        monkeypatch.setenv(name, value)
        assert getattr(envs, name) == value
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
    """Check that ``MODEL_IMPL_TYPE`` reads back as "auto" when unset."""
    monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
    model_impl_type = envs.MODEL_IMPL_TYPE
    assert model_impl_type == "auto"
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def test_cache_preserves_values_across_env_changes(
        monkeypatch: pytest.MonkeyPatch):
    """Check that cached values stay fixed while the environment changes.

    The uncached ``__getattr__`` is restored in a ``finally`` clause so that a
    failing assertion cannot leave the cached accessor installed for the
    tests that run afterwards.

    Args:
        monkeypatch: pytest fixture used to set ``JAX_PLATFORMS``.
    """
    monkeypatch.setenv("JAX_PLATFORMS", "tpu")

    enable_envs_cache()
    try:
        assert envs.JAX_PLATFORMS == "tpu"

        # Change environment variable
        monkeypatch.setenv("JAX_PLATFORMS", "cpu")

        # Cached value should still be "tpu"
        assert envs.JAX_PLATFORMS == "tpu"
    finally:
        # Reset envs.__getattr__ back to the non-cached version. Fall back to
        # the current accessor if it was never wrapped, so cleanup does not
        # mask the real failure.
        envs.__getattr__ = getattr(envs.__getattr__, "__wrapped__",
                                   envs.__getattr__)

    # Now it should reflect the new value
    assert envs.JAX_PLATFORMS == "cpu"
|
tests/test_tpu_info.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
from unittest.mock import MagicMock, patch
|
|
17
|
+
|
|
18
|
+
import pytest
|
|
19
|
+
import requests
|
|
20
|
+
|
|
21
|
+
from tpu_inference.tpu_info import (get_node_name, get_node_worker_id,
|
|
22
|
+
get_num_chips, get_num_cores_per_chip,
|
|
23
|
+
get_tpu_metadata, get_tpu_type)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Mock requests.get for get_tpu_metadata tests
|
|
27
|
+
@patch("tpu_inference.tpu_info.requests.get")
def test_get_tpu_metadata_success(mock_get):
    """Check that a mocked 200 response's text is returned by get_tpu_metadata."""
    # Build the fake response in one step instead of attribute by attribute.
    mock_get.return_value = MagicMock(status_code=200,
                                      text="test_metadata_value")

    metadata = get_tpu_metadata(key="test-key")

    assert metadata == "test_metadata_value"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@patch("tpu_inference.tpu_info.requests.get")
def test_get_tpu_metadata_request_error(mock_get):
    """Check that get_tpu_metadata returns None when the request raises."""
    failure = requests.RequestException("Test RequestException")
    mock_get.side_effect = failure

    metadata = get_tpu_metadata(key="test-key")

    assert metadata is None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Test get_tpu_type
|
|
45
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata")
@patch.dict(os.environ, {"TPU_ACCELERATOR_TYPE": "env_tpu_type"})
def test_get_tpu_type_from_env(mock_get_tpu_metadata):
    """Check get_tpu_type when TPU_ACCELERATOR_TYPE is set in the environment."""
    tpu_type = get_tpu_type()

    assert tpu_type == "env_tpu_type"
    # The env var value is used, so the metadata lookup must not run.
    mock_get_tpu_metadata.assert_not_called()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@patch.dict(os.environ, {}, clear=True)
@patch("tpu_inference.tpu_info.get_tpu_metadata",
       return_value="metadata_tpu_type")
def test_get_tpu_type_from_metadata(mock_get_tpu_metadata):
    """Check get_tpu_type when the environment is empty."""
    tpu_type = get_tpu_type()

    assert tpu_type == "metadata_tpu_type"
    # The value must come from exactly one lookup of this metadata key.
    mock_get_tpu_metadata.assert_called_once_with(key="accelerator-type")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Test get_node_name
|
|
64
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata")
@patch.dict(os.environ, {"TPU_NAME": "env_tpu_name"})
def test_get_node_name_from_env(mock_get_tpu_metadata):
    """Check get_node_name when TPU_NAME is set in the environment."""
    node_name = get_node_name()

    assert node_name == "env_tpu_name"
    # The env var value is used, so the metadata lookup must not run.
    mock_get_tpu_metadata.assert_not_called()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@patch.dict(os.environ, {}, clear=True)
@patch("tpu_inference.tpu_info.get_tpu_metadata",
       return_value="metadata_tpu_name")
def test_get_node_name_from_metadata(mock_get_tpu_metadata):
    """Check get_node_name when the environment is empty."""
    node_name = get_node_name()

    assert node_name == "metadata_tpu_name"
    # The value must come from exactly one lookup of this metadata key.
    mock_get_tpu_metadata.assert_called_once_with(key="instance-id")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Test get_node_worker_id
|
|
82
|
+
@patch("tpu_inference.tpu_info.get_tpu_metadata")
@patch.dict(os.environ, {"TPU_WORKER_ID": "5"})
def test_get_node_worker_id_from_env(mock_get_tpu_metadata):
    """Check get_node_worker_id when TPU_WORKER_ID is set in the environment."""
    worker_id = get_node_worker_id()

    # The string "5" from the environment is expected back as the integer 5.
    assert worker_id == 5
    mock_get_tpu_metadata.assert_not_called()
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@patch.dict(os.environ, {}, clear=True)
@patch("tpu_inference.tpu_info.get_tpu_metadata", return_value="10")
def test_get_node_worker_id_from_metadata(mock_get_tpu_metadata):
    """Check get_node_worker_id when the environment is empty."""
    worker_id = get_node_worker_id()

    # The string "10" from the mocked metadata is expected back as integer 10.
    assert worker_id == 10
    mock_get_tpu_metadata.assert_called_once_with(key="agent-worker-number")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# Test get_num_cores_per_chip
|
|
99
|
+
@pytest.mark.parametrize(
    "tpu_type, expected",
    [
        ("v5litepod-4", 1),
        ("v6e-8", 1),
        ("v4-8", 2),
        ("v5p-16", 2),
        ("unknown-type", 2),  # Default case
    ])
@patch("tpu_inference.tpu_info.get_tpu_type")
def test_get_num_cores_per_chip(mock_get_tpu_type, tpu_type, expected):
    """Check get_num_cores_per_chip for each parametrized TPU type string."""
    # Make the mocked get_tpu_type report the parametrized TPU type.
    mock_get_tpu_type.return_value = tpu_type

    cores_per_chip = get_num_cores_per_chip()

    assert cores_per_chip == expected
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Test get_num_chips
|
|
116
|
+
@patch("tpu_inference.tpu_info.glob.glob",
       return_value=["/dev/accel0", "/dev/accel1"])
def test_get_num_chips_from_accel(mock_glob):
    """Check get_num_chips when the mocked glob reports two accel devices."""
    num_chips = get_num_chips()

    assert num_chips == 2
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
@patch("tpu_inference.tpu_info.os.listdir", return_value=["0", "1", "2"])
def test_get_num_chips_from_vfio(mock_listdir, mock_glob):
    """Check get_num_chips when glob finds nothing but listdir has 3 entries."""
    num_chips = get_num_chips()

    # With no glob matches, the count must equal the mocked listdir entries.
    assert num_chips == 3
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@patch("tpu_inference.tpu_info.glob.glob", return_value=[])
@patch("tpu_inference.tpu_info.os.listdir", side_effect=FileNotFoundError)
def test_get_num_chips_not_found(mock_listdir, mock_glob, caplog):
    """Check that zero chips are reported when no device source is available.

    The glob returns no matches and the directory listing raises
    FileNotFoundError. The ``caplog`` fixture is requested but its captured
    records are not inspected here.
    """
    chip_count = get_num_chips()

    assert chip_count == 0
|
tests/test_utils.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
import os
|
|
3
|
+
from unittest.mock import MagicMock, patch
|
|
4
|
+
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
# Import the functions to be tested
|
|
9
|
+
from tpu_inference.utils import (GBYTES, enable_megacore,
|
|
10
|
+
get_jax_dtype_from_str_dtype, get_megacore,
|
|
11
|
+
get_padded_head_dim, hbm_usage_bytes,
|
|
12
|
+
hbm_usage_gb)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_enable_and_get_megacore():
    """Check that get_megacore() flips from falsy to truthy once enabled."""
    # NOTE(review): this presumably relies on megacore not having been enabled
    # earlier in the same process - confirm test ordering/isolation.
    state_before = get_megacore()
    assert not state_before

    enable_megacore()

    state_after = get_megacore()
    assert state_after
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@patch.dict(os.environ, {"TPU_MULTIHOST_BACKEND": "ray"})
def test_hbm_usage_bytes_ray_backend():
    """Check hbm_usage_bytes with TPU_MULTIHOST_BACKEND set to "ray".

    One device reports memory stats and the other raises; both entries of the
    result are expected to carry the numbers of the device that succeeded.
    """
    used_bytes = 100 * GBYTES
    limit_bytes = 128 * GBYTES

    healthy_device = MagicMock()
    healthy_device.memory_stats.return_value = {
        "bytes_in_use": used_bytes,
        "bytes_limit": limit_bytes,
    }

    # Second device cannot report its stats at all.
    failing_device = MagicMock()
    failing_device.memory_stats.side_effect = Exception("Memory stats failed")

    usage = hbm_usage_bytes([healthy_device, failing_device])

    assert usage == [(used_bytes, limit_bytes), (used_bytes, limit_bytes)]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
def test_hbm_usage_bytes_pathways_disabled():
    """Check hbm_usage_bytes with VLLM_TPU_USING_PATHWAYS patched to False.

    Each mock device reports its own memory stats, and the result is expected
    to list one (bytes_in_use, bytes_limit) pair per device, in order.
    """
    limit_bytes = 128 * GBYTES
    in_use_per_device = (100 * GBYTES, 50 * GBYTES)

    devices = []
    for in_use in in_use_per_device:
        device = MagicMock()
        device.memory_stats.return_value = {
            "bytes_in_use": in_use,
            "bytes_limit": limit_bytes,
        }
        devices.append(device)

    usage = hbm_usage_bytes(devices)

    assert usage == [(in_use, limit_bytes) for in_use in in_use_per_device]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
@patch("jax.live_arrays")
@patch("jax.devices")
def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
    """Check hbm_usage_bytes with VLLM_TPU_USING_PATHWAYS patched to True.

    Usage is expected to be the per-device sum of shard sizes across the
    mocked live arrays, paired with a 95 GB limit for a "TPU v5p" device kind.
    """

    def make_shard(device, nbytes):
        """Return a mock shard whose ``data`` sits on *device* with *nbytes*."""
        data = MagicMock()
        data.device = device
        data.nbytes = nbytes
        shard = MagicMock()
        shard.data = data
        return shard

    # jax.devices() reports a single TPU v5p device.
    v5p_device = MagicMock()
    v5p_device.device_kind = "TPU v5p"
    mock_devices.return_value = [v5p_device]

    # The two devices whose usage is being queried.
    first_device = MagicMock()
    second_device = MagicMock()

    # First array: 2000 bytes on each device.
    sharded_array = MagicMock()
    sharded_array.addressable_shards = [
        make_shard(first_device, 2000),
        make_shard(second_device, 2000),
    ]

    # Second array: 1000 bytes on the first device only.
    single_device_array = MagicMock()
    single_device_array.addressable_shards = [make_shard(first_device, 1000)]

    mock_live_arrays.return_value = [sharded_array, single_device_array]

    usage = hbm_usage_bytes([first_device, second_device])

    # first device: 2000 + 1000 = 3000 bytes; second device: 2000 bytes.
    hbm_limit = 95 * GBYTES
    assert usage == [(3000, hbm_limit), (2000, hbm_limit)]
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", False)
def test_hbm_usage_gb_pathways_disabled():
    """Check hbm_usage_gb with VLLM_TPU_USING_PATHWAYS patched to False.

    Device stats are given in bytes (multiples of GBYTES), and the result is
    expected to be the same figures expressed in gigabytes.
    """
    # (bytes_in_use, bytes_limit) per device; the second device uses floats.
    stats_per_device = [
        (100 * GBYTES, 128 * GBYTES),
        (50.5 * GBYTES, 128.0 * GBYTES),
    ]

    devices = []
    for in_use, limit in stats_per_device:
        device = MagicMock()
        device.memory_stats.return_value = {
            "bytes_in_use": in_use,
            "bytes_limit": limit,
        }
        devices.append(device)

    usage = hbm_usage_gb(devices)

    assert usage == [(100.0, 128.0), (50.5, 128.0)]
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@patch("vllm.envs.VLLM_TPU_USING_PATHWAYS", True)
@patch("jax.live_arrays")
@patch("jax.devices")
def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
    """Check hbm_usage_bytes under Pathways when there are no live arrays.

    Every device is expected to report zero bytes in use, paired with a 32 GB
    limit for a "TPU v6e" device kind.
    """
    # jax.devices() reports a single TPU v6e device.
    v6e_device = MagicMock()
    v6e_device.device_kind = "TPU v6e"
    mock_devices.return_value = [v6e_device]

    # Nothing is alive, so there is nothing to attribute to any device.
    mock_live_arrays.return_value = []

    queried_devices = [MagicMock(), MagicMock()]
    usage = hbm_usage_bytes(queried_devices)

    hbm_limit = 32 * GBYTES
    assert usage == [(0, hbm_limit), (0, hbm_limit)]
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@pytest.mark.parametrize(
    "head_dim, expected_padded_head_dim",
    [
        (1, 128),
        (64, 64),
        (127, 128),
        (128, 128),
        (129, 256),
        (255, 256),
        (256, 256),
        (0, 0),  # Although head_dim is usually positive, testing boundary
    ],
)
def test_get_padded_head_dim(head_dim, expected_padded_head_dim):
    """Check the padded head dimension returned for each input size.

    In this table 64 is expected to stay 64, while the other values are
    expected to round up to the next multiple of 128 (0 stays 0).
    """
    padded = get_padded_head_dim(head_dim)

    assert padded == expected_padded_head_dim
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def test_get_jax_dtype_from_str_dtype():
    """Check the string-name to JAX dtype mapping for the supported names."""
    # Checked in insertion order; "fp8" and "fp8_e4m3" share one dtype.
    expected_by_name = {
        "int8": jnp.int8,
        "bfloat16": jnp.bfloat16,
        "fp8": jnp.float8_e4m3fn,
        "fp8_e4m3": jnp.float8_e4m3fn,
        "fp8_e5m2": jnp.float8_e5m2,
    }

    for dtype_name, expected_dtype in expected_by_name.items():
        assert get_jax_dtype_from_str_dtype(dtype_name) == expected_dtype
|
tests/worker/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|