tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference has been flagged as potentially problematic; see the package registry's release page for details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
3
|
+
|
|
4
|
+
from unittest import mock
|
|
5
|
+
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pytest
|
|
10
|
+
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
|
11
|
+
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
|
12
|
+
VllmConfig)
|
|
13
|
+
from vllm.config.load import LoadConfig
|
|
14
|
+
|
|
15
|
+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
16
|
+
from tpu_inference.runner import utils as runner_utils
|
|
17
|
+
from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
|
|
18
|
+
|
|
19
|
+
# Target and EAGLE3 draft model identifiers. They are only used to populate
# the vLLM configs; model loading and execution are mocked in the tests below.
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _create_proposer(
    method: str,
    num_speculative_tokens: int,
) -> Eagle3Proposer:
    """Build an ``Eagle3Proposer`` from real configs and a mocked runner.

    The target model (``model_dir``) and draft model (``eagle3_dir``) names
    feed the vLLM configs. The runner is a ``MagicMock`` that carries only
    the attributes assigned below, plus a real single-axis JAX mesh over
    every visible device.

    Args:
        method: Speculative decoding method forwarded to ``SpeculativeConfig``.
        num_speculative_tokens: Number of draft tokens proposed per step.

    Returns:
        A newly constructed ``Eagle3Proposer``.
    """
    max_len = 8192

    target_model_cfg = ModelConfig(
        model=model_dir,
        runner="generate",
        max_model_len=max_len,
        seed=42,
    )
    spec_cfg = SpeculativeConfig(
        target_model_config=target_model_cfg,
        target_parallel_config=ParallelConfig(),
        model=eagle3_dir,
        method=method,
        num_speculative_tokens=num_speculative_tokens,
    )

    # Sub-configs are built in the same order in which they are passed to
    # VllmConfig below.
    cache_cfg = CacheConfig(block_size=16)
    device_cfg = DeviceConfig(device="tpu")
    parallel_cfg = ParallelConfig(pipeline_parallel_size=1,
                                  tensor_parallel_size=1)
    load_cfg = LoadConfig()
    scheduler_cfg = SchedulerConfig(
        max_num_batched_tokens=8192,
        max_num_seqs=128,
        max_model_len=target_model_cfg.max_model_len,
        is_encoder_decoder=False,
    )
    vllm_cfg = VllmConfig(
        model_config=target_model_cfg,
        cache_config=cache_cfg,
        speculative_config=spec_cfg,
        device_config=device_cfg,
        parallel_config=parallel_cfg,
        load_config=load_cfg,
        scheduler_config=scheduler_cfg,
    )

    # The proposer takes a runner at construction time; a MagicMock stands
    # in for the real one.
    runner = mock.MagicMock()
    # Real mesh (one 'model' axis over all devices) so sharding-related
    # logic runs against actual devices.
    runner.mesh = jax.sharding.Mesh(np.array(jax.devices()),
                                    axis_names=('model', ))
    runner.max_num_tokens = 8192
    runner.max_model_len = max_len
    runner.kv_cache_config.kv_cache_groups = [mock.MagicMock()]
    runner.input_batch = mock.MagicMock()

    return Eagle3Proposer(vllm_config=vllm_cfg, runner=runner)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def test_prepare_inputs():
    """Check ``Eagle3Proposer.prepare_inputs`` on a three-request batch.

    This is a JAX port of the GPU ``prepare_inputs`` test. With per-request
    query lengths ``a, b, c`` and rejected-token counts ``n1, n2, n3``:

    - cu_target_query_lens: [0, a, a + b, a + b + c]
    - num_tokens_per_req:   [a - n1, b - n2, c - n3]
    - cu_num_tokens:        [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    - token_indices:        [0, ..., a - n1 - 1, a, ..., a + b - n2 - 1, ...]

    Unused request slots (up to ``max_num_seqs``) are expected to get a
    query length of 1 in the rebuilt query-start locations.
    """
    max_num_seqs = 128
    num_reqs = 3
    max_num_blocks_per_req = 10  # arbitrary; the mocked block table is zeros
    total_tokens = 16
    hidden_size = 128

    def padded(values, fill=0):
        """Return a length-``max_num_seqs`` int32 vector: ``values``, then ``fill``."""
        out = np.full(max_num_seqs, fill, dtype=np.int32)
        out[:len(values)] = values
        return out

    def start_locs(lens):
        """Return the int32 prefix sums ``[0, l0, l0 + l1, ...]`` of ``lens``."""
        out = np.zeros(lens.shape[0] + 1, dtype=np.int32)
        out[1:] = np.cumsum(lens)
        return out

    proposer = _create_proposer("eagle3", 1)
    runner = proposer.runner

    # Runner state consulted by prepare_inputs.
    runner.input_batch.num_reqs = num_reqs
    runner.num_tokens_paddings = runner_utils.get_token_paddings(
        min_token_size=16, max_token_size=1024, padding_gap=0)

    # Stubs needed by the _prepare_draft_inputs helper:
    # - an identity hidden-state combiner,
    # - an empty model state,
    # - a block table whose CPU tensor is a 2-D zero matrix.
    proposer.combine_hidden_states_fn = lambda state, h: h
    proposer.state = None
    block_table = mock.MagicMock()
    block_table.get_cpu_tensor.return_value = jnp.zeros(
        (num_reqs, max_num_blocks_per_req), dtype=jnp.int32)
    runner.input_batch.block_table = [block_table]

    # Query lengths (a, b, c) = (4, 7, 5); the sequence lengths match them.
    qsl_cpu = start_locs(padded([4, 7, 5]))
    sl_cpu = padded([4, 7, 5])

    # Token ids and the three aux hidden states carry one extra row, because
    # token indices gathered for padded requests can reach ``total_tokens``.
    input_ids = jnp.arange(total_tokens + 1)
    aux_hidden_states = tuple(
        jnp.ones((total_tokens + 1, hidden_size)) for _ in range(3))

    num_rejected_tokens = jnp.array(padded([1, 3, 2]))
    # Only read by the _prepare_input_ids helper. It must span all
    # max_num_seqs slots to line up with the mask used in jnp.where;
    # only the first num_reqs entries are real tokens.
    next_token_ids = jnp.array(padded([1, 2, 3]))

    attn_metadata = AttentionMetadata(
        seq_lens=jnp.array(sl_cpu),
        input_positions=jnp.arange(total_tokens),
        query_start_loc=jnp.array(qsl_cpu),
        block_tables=jnp.array([]),  # placeholder; the mocked table is used
        request_distribution=None,
    )
    attn_metadata.query_start_loc_cpu = qsl_cpu
    attn_metadata.seq_lens_cpu = sl_cpu

    # Accepted tokens per request: (4 - 1, 7 - 3, 5 - 2) = (3, 4, 3).
    # Each unused slot contributes a query length of 1.
    expected_qsl = start_locs(padded([3, 4, 3], fill=1))
    expected_seq_lens = padded([3, 4, 3])
    expected_num_tokens = runner_utils.get_padded_token_len(
        runner.num_tokens_paddings, int(expected_qsl[-1]))
    expected_last_token_indices = jnp.array(expected_qsl[1:] - 1)

    (target_hidden_states, draft_input_ids, last_token_indices,
     updated_metadata) = proposer.prepare_inputs(attn_metadata, input_ids,
                                                 aux_hidden_states,
                                                 next_token_ids,
                                                 num_rejected_tokens)

    assert jnp.array_equal(updated_metadata.query_start_loc,
                           jnp.array(expected_qsl))
    assert jnp.array_equal(updated_metadata.seq_lens,
                           jnp.array(expected_seq_lens))
    assert jnp.array_equal(last_token_indices, expected_last_token_indices)

    # The returned token ids are not checked value-by-value for padded
    # requests, because the expected tensor is awkward to build. The shape
    # checks plus the qsl/seq_lens checks cover the bookkeeping. The three
    # aux hidden states are concatenated on the feature axis, which gives
    # hidden_size * 3.
    assert draft_input_ids.shape == (expected_num_tokens, )
    assert target_hidden_states.shape == (expected_num_tokens,
                                          hidden_size * 3)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@pytest.mark.parametrize("method", ["eagle3"])
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, num_speculative_tokens):
    """Check that ``propose`` chains draft steps into the expected tokens.

    The model is replaced by a deterministic stub that writes a token id into
    column 0 of its hidden states, and ``compute_logits_fn`` one-hot encodes
    that column. Picking the highest logit therefore recovers the written
    id; the expected values below rely on this.

    - The first call, over all tokens, writes ``base_token_ids`` at each
      request's last token.
    - Every later call writes ``previous draft token + 1``.

    Request ``i`` should therefore draft ``base_token_ids[i] + j`` for ``j``
    in ``range(num_speculative_tokens)``.
    """
    proposer = _create_proposer(method, num_speculative_tokens)

    hidden_size = 128
    vocab_size = 100
    seq_lens = (5, 3)
    batch_size = len(seq_lens)
    total_tokens = sum(seq_lens)
    base_token_ids = [42, 60]

    def _encode(num_tokens, rows, values):
        """Return a zero (num_tokens, hidden_size) array with ``values`` at [rows, 0]."""
        return jnp.zeros((num_tokens, hidden_size)).at[rows, 0].set(values)

    def mock_model_fn(state, kv_caches, input_ids, target_hidden_states,
                      attn_metadata):
        """Stub model returning ``(kv_caches, hidden_states, (residual, ))``.

        A call spanning all ``total_tokens`` is treated as the first step and
        writes ``base_token_ids`` at each request's last token. Any other
        call writes ``input_ids + 1`` on every row.
        """
        num_tokens = input_ids.shape[0]
        if num_tokens == total_tokens:
            rows = attn_metadata.query_start_loc[1:] - 1
            values = jnp.array(base_token_ids)
        else:
            rows = slice(None)
            values = input_ids + 1
        # The logits input and the residual carry the same encoding but are
        # built as separate arrays.
        return (kv_caches, _encode(num_tokens, rows, values),
                (_encode(num_tokens, rows, values), ))

    def mock_compute_logits_fn(state, hidden_states, lora_metadata):
        # One-hot over the vocabulary at the id stored in column 0.
        return jax.nn.one_hot(hidden_states[:, 0].astype(jnp.int32),
                              vocab_size)

    def mock_combine_hidden_states_fn(state, hidden_states):
        # Identity: the stub model needs no hidden-state combination.
        return hidden_states

    proposer.model_fn = mock_model_fn
    proposer.compute_logits_fn = mock_compute_logits_fn
    proposer.combine_hidden_states_fn = mock_combine_hidden_states_fn
    proposer.state = None

    kv_caches = [None]
    block_tables_2d = jnp.zeros((batch_size, 10), dtype=jnp.int32)
    attn_metadata = AttentionMetadata(
        seq_lens=jnp.array(seq_lens),
        input_positions=jnp.concatenate([jnp.arange(n) for n in seq_lens]),
        query_start_loc=jnp.array([0, seq_lens[0], total_tokens]),
        # propose() receives the block table already flattened, as
        # prepare_inputs would produce it.
        block_tables=block_tables_2d.reshape(-1),
        request_distribution=None,
    )

    # Stand-ins for the outputs of prepare_inputs.
    target_token_ids = jnp.zeros(total_tokens, dtype=jnp.int32)
    target_hidden_states = jnp.zeros((total_tokens, hidden_size))
    last_token_indices = attn_metadata.query_start_loc[1:] - 1

    # propose() does not read this mock. It is set to the 2-D table anyway,
    # since that is what _prepare_draft_inputs (reached via prepare_inputs)
    # expects.
    block_table = mock.MagicMock()
    block_table.get_device_tensor.return_value = block_tables_2d
    proposer.runner.input_batch.num_reqs = batch_size
    proposer.runner.input_batch.block_table = [block_table]

    _, draft_token_ids = proposer.propose(
        kv_caches,
        target_token_ids,
        attn_metadata,
        last_token_indices,
        target_hidden_states,
    )

    # A 1-D result (one draft per request) is promoted to 2-D before the
    # comparison.
    if draft_token_ids.ndim == 1:
        draft_token_ids = draft_token_ids[:, None]
    assert draft_token_ids.shape == (batch_size, num_speculative_tokens)

    # Request i drafts base_token_ids[i], base_token_ids[i] + 1, ...
    expected_tokens = (np.asarray(base_token_ids, dtype=np.int64)[:, None] +
                       np.arange(num_speculative_tokens, dtype=np.int64))
    assert jnp.array_equal(draft_token_ids, jnp.array(expected_tokens))
|
tests/test_base.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import logging
|
|
2
16
|
import unittest
|
|
3
17
|
import warnings
|
tests/test_envs.py
CHANGED
|
@@ -56,6 +56,13 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
|
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
59
|
+
# Ensure clean environment for boolean vars by setting to default "0"
|
|
60
|
+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
|
|
61
|
+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
|
|
62
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
|
|
63
|
+
monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
|
|
64
|
+
monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
|
|
65
|
+
|
|
59
66
|
# Test SKIP_JAX_PRECOMPILE (default False)
|
|
60
67
|
assert envs.SKIP_JAX_PRECOMPILE is False
|
|
61
68
|
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
|
|
@@ -63,6 +70,13 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
63
70
|
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
|
|
64
71
|
assert envs.SKIP_JAX_PRECOMPILE is False
|
|
65
72
|
|
|
73
|
+
# Test VLLM_XLA_CHECK_RECOMPILATION (default False)
|
|
74
|
+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
|
|
75
|
+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "1")
|
|
76
|
+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
|
|
77
|
+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
|
|
78
|
+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
|
|
79
|
+
|
|
66
80
|
# Test NEW_MODEL_DESIGN (default False)
|
|
67
81
|
assert envs.NEW_MODEL_DESIGN is False
|
|
68
82
|
monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
|
|
@@ -73,22 +87,110 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
73
87
|
monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
|
|
74
88
|
assert envs.USE_MOE_EP_KERNEL is True
|
|
75
89
|
|
|
90
|
+
# Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
|
|
91
|
+
assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
|
|
92
|
+
monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
|
|
93
|
+
assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
|
|
97
|
+
"""Test that boolean env vars accept string values like 'True' and 'False'"""
|
|
98
|
+
|
|
99
|
+
# Test NEW_MODEL_DESIGN with string "True"
|
|
100
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "True")
|
|
101
|
+
assert envs.NEW_MODEL_DESIGN is True
|
|
102
|
+
|
|
103
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "true")
|
|
104
|
+
assert envs.NEW_MODEL_DESIGN is True
|
|
105
|
+
|
|
106
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "False")
|
|
107
|
+
assert envs.NEW_MODEL_DESIGN is False
|
|
108
|
+
|
|
109
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "false")
|
|
110
|
+
assert envs.NEW_MODEL_DESIGN is False
|
|
111
|
+
|
|
112
|
+
# Test SKIP_JAX_PRECOMPILE with string values
|
|
113
|
+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "True")
|
|
114
|
+
assert envs.SKIP_JAX_PRECOMPILE is True
|
|
115
|
+
|
|
116
|
+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "false")
|
|
117
|
+
assert envs.SKIP_JAX_PRECOMPILE is False
|
|
118
|
+
|
|
119
|
+
# Test VLLM_XLA_CHECK_RECOMPILATION with string values
|
|
120
|
+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "TRUE")
|
|
121
|
+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
|
|
122
|
+
|
|
123
|
+
monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "FALSE")
|
|
124
|
+
assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
|
|
125
|
+
|
|
126
|
+
# Test USE_MOE_EP_KERNEL with string values
|
|
127
|
+
monkeypatch.setenv("USE_MOE_EP_KERNEL", "true")
|
|
128
|
+
assert envs.USE_MOE_EP_KERNEL is True
|
|
129
|
+
|
|
130
|
+
monkeypatch.setenv("USE_MOE_EP_KERNEL", "False")
|
|
131
|
+
assert envs.USE_MOE_EP_KERNEL is False
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
|
|
135
|
+
"""Test that boolean env vars raise errors for invalid values"""
|
|
136
|
+
|
|
137
|
+
# Test invalid value for NEW_MODEL_DESIGN
|
|
138
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "yes")
|
|
139
|
+
with pytest.raises(
|
|
140
|
+
ValueError,
|
|
141
|
+
match="Invalid boolean value 'yes' for NEW_MODEL_DESIGN"):
|
|
142
|
+
_ = envs.NEW_MODEL_DESIGN
|
|
143
|
+
|
|
144
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "2")
|
|
145
|
+
with pytest.raises(ValueError,
|
|
146
|
+
match="Invalid boolean value '2' for NEW_MODEL_DESIGN"):
|
|
147
|
+
_ = envs.NEW_MODEL_DESIGN
|
|
148
|
+
|
|
149
|
+
# Test invalid value for SKIP_JAX_PRECOMPILE
|
|
150
|
+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "invalid")
|
|
151
|
+
with pytest.raises(
|
|
152
|
+
ValueError,
|
|
153
|
+
match="Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"):
|
|
154
|
+
_ = envs.SKIP_JAX_PRECOMPILE
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
|
|
158
|
+
"""Test that empty string returns default value"""
|
|
159
|
+
|
|
160
|
+
monkeypatch.setenv("NEW_MODEL_DESIGN", "")
|
|
161
|
+
assert envs.NEW_MODEL_DESIGN is False # Should return default
|
|
162
|
+
|
|
163
|
+
monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "")
|
|
164
|
+
assert envs.SKIP_JAX_PRECOMPILE is False # Should return default
|
|
165
|
+
|
|
76
166
|
|
|
77
167
|
def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
168
|
+
# Ensure clean environment for integer vars by setting to defaults
|
|
169
|
+
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")
|
|
170
|
+
monkeypatch.setenv("NUM_SLICES", "1")
|
|
171
|
+
|
|
78
172
|
assert envs.PYTHON_TRACER_LEVEL == 1
|
|
79
173
|
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "3")
|
|
80
174
|
assert envs.PYTHON_TRACER_LEVEL == 3
|
|
81
175
|
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "0")
|
|
82
176
|
assert envs.PYTHON_TRACER_LEVEL == 0
|
|
83
177
|
|
|
178
|
+
# Test NUM_SLICES (default 1)
|
|
179
|
+
assert envs.NUM_SLICES == 1
|
|
180
|
+
monkeypatch.setenv("NUM_SLICES", "2")
|
|
181
|
+
assert envs.NUM_SLICES == 2
|
|
182
|
+
monkeypatch.setenv("NUM_SLICES", "4")
|
|
183
|
+
assert envs.NUM_SLICES == 4
|
|
84
184
|
|
|
85
|
-
def test_lowercase_conversion(monkeypatch: pytest.MonkeyPatch):
|
|
86
|
-
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "GRPC")
|
|
87
|
-
assert envs.TPU_MULTIHOST_BACKEND == "grpc"
|
|
88
185
|
|
|
89
|
-
|
|
186
|
+
def test_model_impl_type_choices(monkeypatch: pytest.MonkeyPatch):
|
|
187
|
+
# Test case sensitive choices
|
|
188
|
+
monkeypatch.setenv("MODEL_IMPL_TYPE", "flax_nnx")
|
|
90
189
|
assert envs.MODEL_IMPL_TYPE == "flax_nnx"
|
|
91
190
|
|
|
191
|
+
monkeypatch.setenv("MODEL_IMPL_TYPE", "vllm")
|
|
192
|
+
assert envs.MODEL_IMPL_TYPE == "vllm"
|
|
193
|
+
|
|
92
194
|
|
|
93
195
|
def test_string_env_vars_defaults(monkeypatch: pytest.MonkeyPatch):
|
|
94
196
|
monkeypatch.delenv("JAX_PLATFORMS", raising=False)
|
|
@@ -117,8 +219,6 @@ def test_ray_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
117
219
|
assert envs.RAY_USAGE_STATS_ENABLED == "1"
|
|
118
220
|
|
|
119
221
|
assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "shm"
|
|
120
|
-
monkeypatch.setenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "nccl")
|
|
121
|
-
assert envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"
|
|
122
222
|
|
|
123
223
|
|
|
124
224
|
def test_invalid_attribute_raises_error():
|
|
@@ -134,6 +234,7 @@ def test_dir_returns_all_env_vars():
|
|
|
134
234
|
assert "JAX_PLATFORMS" in env_vars
|
|
135
235
|
assert "TPU_NAME" in env_vars
|
|
136
236
|
assert "SKIP_JAX_PRECOMPILE" in env_vars
|
|
237
|
+
assert "VLLM_XLA_CHECK_RECOMPILATION" in env_vars
|
|
137
238
|
assert "MODEL_IMPL_TYPE" in env_vars
|
|
138
239
|
|
|
139
240
|
|
|
@@ -141,11 +242,8 @@ def test_tpu_multihost_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
141
242
|
monkeypatch.setenv("TPU_WORKER_ID", "0")
|
|
142
243
|
assert envs.TPU_WORKER_ID == "0"
|
|
143
244
|
|
|
144
|
-
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "
|
|
145
|
-
assert envs.TPU_MULTIHOST_BACKEND == "
|
|
146
|
-
|
|
147
|
-
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "xla")
|
|
148
|
-
assert envs.TPU_MULTIHOST_BACKEND == "xla"
|
|
245
|
+
monkeypatch.setenv("TPU_MULTIHOST_BACKEND", "ray")
|
|
246
|
+
assert envs.TPU_MULTIHOST_BACKEND == "ray"
|
|
149
247
|
|
|
150
248
|
|
|
151
249
|
def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
@@ -158,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
|
|
|
158
256
|
|
|
159
257
|
def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
|
|
160
258
|
monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
|
|
161
|
-
assert envs.MODEL_IMPL_TYPE == "
|
|
259
|
+
assert envs.MODEL_IMPL_TYPE == "auto"
|
|
162
260
|
|
|
163
261
|
|
|
164
262
|
def test_cache_preserves_values_across_env_changes(
|
tests/test_tpu_info.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import os
|
|
2
16
|
from unittest.mock import MagicMock, patch
|
|
3
17
|
|
tests/test_utils.py
CHANGED
|
@@ -9,7 +9,7 @@ import pytest
|
|
|
9
9
|
from tpu_inference.utils import (GBYTES, enable_megacore,
|
|
10
10
|
get_jax_dtype_from_str_dtype, get_megacore,
|
|
11
11
|
get_padded_head_dim, hbm_usage_bytes,
|
|
12
|
-
hbm_usage_gb
|
|
12
|
+
hbm_usage_gb)
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def test_enable_and_get_megacore():
|
|
@@ -182,48 +182,6 @@ def test_get_padded_head_dim(head_dim, expected_padded_head_dim):
|
|
|
182
182
|
assert get_padded_head_dim(head_dim) == expected_padded_head_dim
|
|
183
183
|
|
|
184
184
|
|
|
185
|
-
def test_quantize_kv_float8_e4m3fn():
|
|
186
|
-
"""Tests the quantize_kv function with float8_e4m3fn dtype."""
|
|
187
|
-
key = jnp.array([-1.0, 0.5, 1.0, 1.5])
|
|
188
|
-
value = jnp.array([2.0, 0.0, -2.0, -3.0])
|
|
189
|
-
kv_cache_quantized_dtype = jnp.float8_e4m3fn
|
|
190
|
-
k_scale = 0.1
|
|
191
|
-
v_scale = 0.2
|
|
192
|
-
|
|
193
|
-
quantized_key, quantized_value = quantize_kv(key, value,
|
|
194
|
-
kv_cache_quantized_dtype,
|
|
195
|
-
k_scale, v_scale)
|
|
196
|
-
|
|
197
|
-
# Expected key: key / k_scale -> clip -> astype
|
|
198
|
-
# [-10., 5., 10., 15.] are within float8_e4m3fn range
|
|
199
|
-
expected_key = jnp.array([-10.0, 5.0, 10.0, 15.0], dtype=jnp.float8_e4m3fn)
|
|
200
|
-
|
|
201
|
-
# Expected value: value / v_scale -> clip -> astype
|
|
202
|
-
# [10., 0., -10., -15.] are within float8_e4m3fn range
|
|
203
|
-
expected_value = jnp.array([10.0, 0.0, -10.0, -15.0],
|
|
204
|
-
dtype=jnp.float8_e4m3fn)
|
|
205
|
-
|
|
206
|
-
assert jnp.array_equal(quantized_key, expected_key)
|
|
207
|
-
assert jnp.array_equal(quantized_value, expected_value)
|
|
208
|
-
|
|
209
|
-
# Test clipping
|
|
210
|
-
dtype_info = jnp.finfo(kv_cache_quantized_dtype)
|
|
211
|
-
minval, maxval = float(dtype_info.min), float(dtype_info.max)
|
|
212
|
-
|
|
213
|
-
# Values that will be outside the range after scaling
|
|
214
|
-
key_clip = jnp.array([minval * k_scale * 2, maxval * k_scale * 2])
|
|
215
|
-
value_clip = jnp.array([maxval * v_scale * 2, minval * v_scale * 2])
|
|
216
|
-
quantized_key_clip, quantized_value_clip = quantize_kv(
|
|
217
|
-
key_clip, value_clip, kv_cache_quantized_dtype, k_scale, v_scale)
|
|
218
|
-
|
|
219
|
-
# Values should be clipped to the min/max of the float8 dtype
|
|
220
|
-
expected_key_clip = jnp.array([minval, maxval], dtype=jnp.float8_e4m3fn)
|
|
221
|
-
expected_value_clip = jnp.array([maxval, minval], dtype=jnp.float8_e4m3fn)
|
|
222
|
-
|
|
223
|
-
assert jnp.array_equal(quantized_key_clip, expected_key_clip)
|
|
224
|
-
assert jnp.array_equal(quantized_value_clip, expected_value_clip)
|
|
225
|
-
|
|
226
|
-
|
|
227
185
|
def test_get_jax_dtype_from_str_dtype():
|
|
228
186
|
"""
|
|
229
187
|
Test the get_jax_dtype_from_str_dtype function
|
|
@@ -231,6 +189,5 @@ def test_get_jax_dtype_from_str_dtype():
|
|
|
231
189
|
assert get_jax_dtype_from_str_dtype("int8") == jnp.int8
|
|
232
190
|
assert get_jax_dtype_from_str_dtype("bfloat16") == jnp.bfloat16
|
|
233
191
|
assert get_jax_dtype_from_str_dtype("fp8") == jnp.float8_e4m3fn
|
|
234
|
-
assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.
|
|
192
|
+
assert get_jax_dtype_from_str_dtype("fp8_e4m3") == jnp.float8_e4m3fn
|
|
235
193
|
assert get_jax_dtype_from_str_dtype("fp8_e5m2") == jnp.float8_e5m2
|
|
236
|
-
assert get_jax_dtype_from_str_dtype("auto") is None
|
tests/worker/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|