tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/runner/test_tpu_runner_mesh.py
@@ -0,0 +1,200 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for TPUModelRunner mesh initialization."""
+import os
+from unittest.mock import Mock, patch
+
+import pytest
+
+from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+class TestTPUModelRunnerMeshInit:
+    """Test suite for TPUModelRunner._init_mesh and related methods."""
+
+    @pytest.fixture
+    def mock_vllm_config(self):
+        """Create a mock VllmConfig with sharding configuration."""
+        config = Mock()
+        config.sharding_config = Mock()
+        config.sharding_config.model_dp_size = 4
+        config.sharding_config.attn_dp_size = 2
+        config.sharding_config.expert_size = 1
+        config.sharding_config.tp_size = 8
+        config.sharding_config.device_indexes = None
+        config.sharding_config.total_dp_size = 4
+        return config
+
+    @pytest.fixture
+    def mock_devices(self):
+        """Create mock JAX devices."""
+        devices = [Mock(id=i) for i in range(64)]
+        return devices
+
+    @pytest.fixture
+    def runner_instance(self, mock_vllm_config, mock_devices):
+        """Create a minimal TPUModelRunner-like object for testing."""
+        # Create a minimal object that has the necessary attributes
+        runner = Mock(spec=TPUModelRunner)
+        runner.vllm_config = mock_vllm_config
+        runner.devices = mock_devices
+        runner.mesh = None
+
+        # Bind the actual methods to test (methods don't take sharding_strategy param)
+        runner._init_mesh = lambda: TPUModelRunner._init_mesh(runner)
+        runner._create_new_model_mesh = lambda: TPUModelRunner._create_new_model_mesh(
+            runner)
+        runner._create_2d_mesh = lambda: TPUModelRunner._create_2d_mesh(runner)
+        runner._create_single_slice_mesh = lambda: TPUModelRunner._create_single_slice_mesh(
+            runner)
+        runner._create_multi_slice_mesh = lambda ns: TPUModelRunner._create_multi_slice_mesh(
+            runner, ns)
+
+        return runner
+
+    def test_init_mesh_2d_model_without_device_order(self, runner_instance,
+                                                     mock_vllm_config):
+        """Test 2d mesh creation without enforced device order."""
+        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
+                patch('tpu_inference.runner.tpu_runner.make_optimized_mesh') as mock_make_mesh, \
+                patch('tpu_inference.runner.tpu_runner.logger'):
+
+            mock_mesh = Mock()
+            mock_make_mesh.return_value = mock_mesh
+
+            runner_instance._init_mesh()
+
+            mock_make_mesh.assert_called_once()
+            call_args = mock_make_mesh.call_args
+
+            # Verify mesh_shape
+            assert call_args[0][0] == (4, 8)  # (model_dp_size, tp_size)
+            # Verify axis_names
+            assert call_args[0][1] == ("data", "model")
+            # Verify devices
+            assert call_args[1]['devices'] == runner_instance.devices
+
+            assert runner_instance.mesh == mock_mesh
+
+    def test_init_mesh_2d_model_with_device_order(self, runner_instance,
+                                                  mock_vllm_config):
+        """Test 2d mesh creation with enforced device order."""
+        mock_vllm_config.sharding_config.device_indexes = [0, 1, 2, 3]
+
+        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': ''}), \
+                patch('jax.make_mesh') as mock_jax_mesh, \
+                patch('tpu_inference.runner.tpu_runner.logger'):
+
+            mock_mesh = Mock()
+            mock_jax_mesh.return_value = mock_mesh
+
+            runner_instance._init_mesh()
+
+            mock_jax_mesh.assert_called_once()
+            call_args = mock_jax_mesh.call_args
+
+            # Verify mesh_shape
+            assert call_args[0][0] == (4, 8)
+            # Verify axis_names
+            assert call_args[0][1] == ("data", "model")
+            # Verify devices
+            assert call_args[1]['devices'] == runner_instance.devices
+
+            assert runner_instance.mesh == mock_mesh
+
+    def test_init_mesh_new_model_single_slice(self, runner_instance,
+                                              mock_vllm_config):
+        """Test new model mesh creation with single slice."""
+        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': '1'}), \
+                patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
+                patch('jax.sharding.Mesh') as mock_jax_mesh, \
+                patch('tpu_inference.runner.tpu_runner.logger'):
+
+            mock_devices_array = Mock()
+            mock_mesh_utils.create_device_mesh.return_value = mock_devices_array
+            mock_mesh = Mock()
+            mock_jax_mesh.return_value = mock_mesh
+
+            runner_instance._init_mesh()
+
+            # Verify create_device_mesh was called
+            mock_mesh_utils.create_device_mesh.assert_called_once()
+            call_args = mock_mesh_utils.create_device_mesh.call_args
+
+            # Verify mesh_shape: (model_dp_size, attn_dp_size, expert_size, tp_size)
+            assert call_args[0][0] == (4, 2, 1, 8)
+            assert call_args[0][1] == runner_instance.devices
+            assert call_args[1]['allow_split_physical_axes'] is True
+
+            # Verify Mesh was created with correct axis names
+            mock_jax_mesh.assert_called_once_with(
+                mock_devices_array, ("data", "attn_dp", "expert", "model"))
+
+            assert runner_instance.mesh == mock_mesh
+
+    def test_init_mesh_new_model_multi_slice(self, runner_instance,
+                                             mock_vllm_config):
+        """Test new model mesh creation with multiple slices."""
+        num_slices = 2
+        with patch.dict(os.environ, {'NEW_MODEL_DESIGN': '1', 'NUM_SLICES': str(num_slices)}), \
+                patch('tpu_inference.runner.tpu_runner.mesh_utils') as mock_mesh_utils, \
+                patch('jax.sharding.Mesh') as mock_jax_mesh, \
+                patch('tpu_inference.runner.tpu_runner.logger'):
+
+            mock_devices_array = Mock()
+            mock_mesh_utils.create_hybrid_device_mesh.return_value = mock_devices_array
+            mock_mesh = Mock()
+            mock_jax_mesh.return_value = mock_mesh
+
+            runner_instance._init_mesh()
+
+            # Verify create_hybrid_device_mesh was called
+            mock_mesh_utils.create_hybrid_device_mesh.assert_called_once()
+            call_args = mock_mesh_utils.create_hybrid_device_mesh.call_args
+
+            # Verify intra_node_shape: (dp_inner, attn_dp_size, expert_size, tp_size)
+            # dp_inner = model_dp_size // num_slices = 4 // 2 = 2
+            assert call_args[1]['mesh_shape'] == (2, 2, 1, 8)
+            # Verify outer_node_shape: (num_slices, 1, 1, 1)
+            assert call_args[1]['dcn_mesh_shape'] == (2, 1, 1, 1)
+            assert call_args[1]['devices'] == runner_instance.devices
+            assert call_args[1]['allow_split_physical_axes'] is True
+
+            # Verify Mesh was created with correct axis names
+            mock_jax_mesh.assert_called_once_with(
+                mock_devices_array, ("data", "attn_dp", "expert", "model"))
+
+            assert runner_instance.mesh == mock_mesh
+
+    @pytest.mark.parametrize("num_slices,expected_dp_inner", [
+        (1, 4),
+        (2, 2),
+        (4, 1),
+    ])
+    def test_multi_slice_mesh_dp_inner_calculation(self, runner_instance,
+                                                   mock_vllm_config,
+                                                   num_slices,
+                                                   expected_dp_inner):
+        """Test dp_inner calculation for various num_slices values."""
+        with patch('tpu_inference.runner.tpu_runner.mesh_utils'
+                   ) as mock_mesh_utils:
+            mock_mesh_utils.create_hybrid_device_mesh.return_value = Mock()
+
+            runner_instance._create_multi_slice_mesh(num_slices)
+
+            call_args = mock_mesh_utils.create_hybrid_device_mesh.call_args
+            intra_node_shape = call_args[1]['mesh_shape']
+
+            # First dimension of intra_node_shape should be dp_inner
+            assert intra_node_shape[0] == expected_dp_inner
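The multi-slice assertions above pin down the mesh layout: data parallelism splits into an intra-slice factor (dp_inner = model_dp_size // num_slices) and an inter-slice factor carried over DCN. A minimal sketch of the construction those assertions describe, using the public JAX mesh utilities (parameter names are taken from the test fixtures; the real `_create_multi_slice_mesh` in `tpu_inference/runner/tpu_runner.py` may differ):

```python
from jax.experimental import mesh_utils
from jax.sharding import Mesh

AXIS_NAMES = ("data", "attn_dp", "expert", "model")


def create_multi_slice_mesh(devices, model_dp_size, attn_dp_size,
                            expert_size, tp_size, num_slices):
    # Split data parallelism: dp_inner replicas inside each slice,
    # num_slices replicas across the DCN (inter-slice) network.
    dp_inner = model_dp_size // num_slices
    device_array = mesh_utils.create_hybrid_device_mesh(
        mesh_shape=(dp_inner, attn_dp_size, expert_size, tp_size),
        dcn_mesh_shape=(num_slices, 1, 1, 1),
        devices=devices,
        allow_split_physical_axes=True,
    )
    return Mesh(device_array, AXIS_NAMES)
```

With the fixture values (model_dp_size=4, attn_dp_size=2, expert_size=1, tp_size=8, num_slices=2) this yields mesh_shape=(2, 2, 1, 8) and dcn_mesh_shape=(2, 1, 1, 1), exactly what the multi-slice test asserts.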
tests/runner/test_utils.py
@@ -0,0 +1,411 @@
+# SPDX-License-Identifier: Apache-2.0
+import io
+import logging
+import time
+from unittest.mock import MagicMock, mock_open, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from jax._src.interpreters import pxla
+
+from tpu_inference.runner.utils import (
+    PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR, ForbidCompile, InferencePhase,
+    LatencyTracker, PhasedBasedProfiler,
+    determine_phase_from_batch_composition_stats, get_batch_composition_stats,
+    get_padded_num_reqs_with_upper_limit, get_padded_token_len,
+    get_req_paddings, get_token_paddings)
+
+
+def test_get_padded_num_reqs_with_upper_limit():
+    """Tests the get_padded_num_reqs_with_upper_limit function."""
+    # From utils.py, MIN_NUM_SEQS = 8
+    assert get_padded_num_reqs_with_upper_limit(4, 128) == 8
+    assert get_padded_num_reqs_with_upper_limit(8, 128) == 8
+    assert get_padded_num_reqs_with_upper_limit(9, 128) == 16
+    assert get_padded_num_reqs_with_upper_limit(16, 128) == 16
+    assert get_padded_num_reqs_with_upper_limit(17, 128) == 32
+    assert get_padded_num_reqs_with_upper_limit(100, 64) == 64
+    assert get_padded_num_reqs_with_upper_limit(1, 128) == 8
+
+
+def test_get_paddings():
+    # Bucketed padding
+    min_token_size, max_token_size, padding_gap = 16, 512, 64
+    expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
+    actual_paddings = get_token_paddings(min_token_size, max_token_size,
+                                         padding_gap)
+
+    # Bucketed padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 192, 256, 320]
+    actual_paddings = get_token_paddings(min_token_size, max_token_size,
+                                         padding_gap)
+    assert actual_paddings == expected_paddings
+
+    # Exponential padding.
+    max_token_size, padding_gap = 1024, 0
+    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
+    actual_paddings = get_token_paddings(min_token_size, max_token_size,
+                                         padding_gap)
+    assert actual_paddings == expected_paddings
+    # Exponential padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 256, 512]
+    actual_paddings = get_token_paddings(min_token_size, max_token_size,
+                                         padding_gap)
+    assert actual_paddings == expected_paddings
+
+
+def test_get_padded_token_len():
+    min_token_size, max_token_size, padding_gap = 16, 512, 64
+    paddings = get_token_paddings(min_token_size, max_token_size, padding_gap)
+    assert get_padded_token_len(paddings, 1) == 16
+    assert get_padded_token_len(paddings, 16) == 16
+    assert get_padded_token_len(paddings, 20) == 32
+    assert get_padded_token_len(paddings, 300) == 320
+    assert get_padded_token_len(paddings, 512) == 512
+
+
+def test_get_req_paddings():
+    assert get_req_paddings(1, 32) == [8, 16, 32]
+    assert get_req_paddings(8, 32) == [8, 16, 32]
+    assert get_req_paddings(8, 36) == [8, 16, 32, 36]
+
+
+def test_latency_tracker(caplog):
+    """Tests the LatencyTracker context manager."""
+    logger_name = "vllm.tpu_inference.runner.utils"
+    logger = logging.getLogger(logger_name)
+
+    original_level = logger.level
+    original_propagate = logger.propagate
+
+    # Create an in-memory stream to capture log output
+    log_capture_string = io.StringIO()
+    # Create a handler that writes to our in-memory stream
+    capture_handler = logging.StreamHandler(log_capture_string)
+
+    try:
+        logger.setLevel(logging.DEBUG)
+        logger.propagate = False
+        logger.addHandler(capture_handler)
+
+        sleep_duration = 0.01
+        with LatencyTracker("test_op") as tracker:
+            time.sleep(sleep_duration)
+
+        elapsed = tracker.end_time - tracker.start_time
+        assert elapsed >= sleep_duration
+        log_contents = log_capture_string.getvalue()
+
+        assert "Latency for 'test_op'" in log_contents
+        assert f"{elapsed:.3f} seconds" in log_contents
+
+    finally:
+        # --- IMPORTANT: Clean up and restore the logger's original state ---
+        logger.setLevel(original_level)
+        logger.propagate = original_propagate
+        logger.removeHandler(capture_handler)
+
+
+# Define a fixture to clear the JAX cache before each test
+@pytest.fixture(autouse=True)
+def clear_jax_cache():
+    jax.clear_caches()
+    yield
+    jax.clear_caches()
+
+
+@pytest.fixture
+def jitted_function():
+    """Defines a jitted function for testing."""
+
+    @jax.jit
+    def my_jitted_func(x):
+        return x * 2
+
+    return my_jitted_func
+
+
+@pytest.fixture
+def jnp_array_input():
+    return jnp.ones((2, 3))
+
+
+@pytest.fixture
+def jnp_array_input_same_shape():
+    return jnp.zeros((2, 3))
+
+
+@pytest.fixture
+def jnp_array_input_new():
+    return jnp.ones((3, 3))
+
+
+def test_forbid_compile_raises_error_on_first_call(jitted_function,
+                                                   jnp_array_input):
+    """Test that ForbidCompile raises an error when a compilation occurs."""
+    with pytest.raises(RuntimeError, match="JAX compilation occurred"):
+        with ForbidCompile():
+            jitted_function(jnp_array_input)
+
+
+def test_forbid_compile_succeeds_on_cached_call(jitted_function,
+                                                jnp_array_input):
+    """Test that ForbidCompile does not raise an error on a cached call."""
+    # Warm up the cache
+    jitted_function(jnp_array_input)
+    with ForbidCompile():
+        jitted_function(jnp_array_input)
+
+
+def test_forbid_compile_restores_original_function():
+    """Test that ForbidCompile restores the original JAX function after exit."""
+    original_func = pxla._cached_lowering_to_hlo
+    with ForbidCompile():
+        pass
+    assert pxla._cached_lowering_to_hlo is original_func
+
+
+def test_forbid_compile_with_exception():
+    """Test that ForbidCompile restores the original function even if an exception occurs."""
+    original_func = pxla._cached_lowering_to_hlo
+    with pytest.raises(ValueError, match="Test exception"):
+        with ForbidCompile():
+            raise ValueError("Test exception")
+    assert pxla._cached_lowering_to_hlo is original_func
+
+
+def test_forbid_compile_raises_on_new_shape(jitted_function, jnp_array_input,
+                                            jnp_array_input_same_shape,
+                                            jnp_array_input_new):
+    """
+    Tests that ForbidCompile raises a RuntimeError when a jitted function
+    is called with an input shape that triggers a new compilation.
+    """
+    # Clear cache for a clean test state.
+    pxla._cached_lowering_to_hlo.cache_clear()
+
+    # Warm up the JIT cache with the SCALAR input.
+    # This causes the first compilation and cache miss.
+    jitted_function(jnp_array_input)
+    misses_after_warmup = pxla._cached_lowering_to_hlo.cache_info().misses
+    assert misses_after_warmup == 1
+
+    # This call uses the same shape/dtype, so it should be a cache HIT.
+    # No RuntimeError expected.
+    with ForbidCompile():
+        jitted_function(jnp_array_input_same_shape)
+    assert pxla._cached_lowering_to_hlo.cache_info(
+    ).misses == misses_after_warmup  # No new misses
+
+    # Now, call with a VECTOR input. This has a different shape,
+    # forcing a NEW compilation (cache MISS).
+    # This *should* raise a RuntimeError within the ForbidCompile context.
+    expected_error_message = "JAX compilation occurred but was forbidden in this context."
+    with pytest.raises(RuntimeError, match=expected_error_message):
+        with ForbidCompile(message=expected_error_message):
+            jitted_function(jnp_array_input_new)
+
+
+class MockInputBatch:
+
+    def __init__(self, req_ids, num_computed_tokens_cpu):
+        self.req_ids = req_ids
+        self.num_computed_tokens_cpu = np.array(num_computed_tokens_cpu)
+
+
+class MockSchedulerOutput:
+
+    def __init__(self, num_scheduled_tokens):
+        self.num_scheduled_tokens = num_scheduled_tokens
+
+
+@pytest.mark.parametrize(
+    "scenario, num_reqs, req_ids, computed, scheduled, expected_prefill, expected_decode",
+    [
+        ("prefill_only", 2, [101, 102], [0, 0], {
+            101: 50,
+            102: 100
+        }, 150, 0),
+        ("decode_only", 3, [201, 202, 203], [10, 20, 5], {
+            201: 1,
+            202: 1,
+            203: 1
+        }, 0, 3),
+        ("mixed_batch", 4, [301, 302, 303, 304], [0, 10, 0, 20], {
+            301: 100,
+            302: 1,
+            303: 50,
+            304: 1
+        }, 150, 2),
+        ("chunked_prefill", 2, [401, 402], [50, 10], {
+            401: 50,
+            402: 1
+        }, 50, 1),
+    ])
+def test_get_batch_composition_stats(scenario, num_reqs, req_ids, computed,
+                                     scheduled, expected_prefill,
+                                     expected_decode):
+    """Tests get_batch_composition_stats for various scenarios."""
+    input_batch = MockInputBatch(req_ids, computed)
+    scheduler_output = MockSchedulerOutput(scheduled)
+    total_tokens = sum(scheduled.values())
+
+    stats = get_batch_composition_stats(
+        input_batch=input_batch,
+        total_num_scheduled_tokens=total_tokens,
+        num_reqs=num_reqs,
+        padded_total_num_scheduled_tokens=total_tokens + 8,
+        scheduler_output=scheduler_output)
+
+    assert stats["num_prefill_tokens"] == expected_prefill
+    assert stats["num_decode_tokens"] == expected_decode
+    assert stats["num_reqs"] == num_reqs
+    assert stats["total_num_scheduled_tokens"] == total_tokens
+
+
+@pytest.mark.parametrize("prefill_tokens, total_tokens, expected_phase", [
+    (90, 100, InferencePhase.PREFILL_HEAVY),
+    (89, 100, InferencePhase.AMBIGUOUS),
+    (15, 100, InferencePhase.DECODE_HEAVY),
+    (50, 100, InferencePhase.BALANCED),
+    (70, 100, InferencePhase.AMBIGUOUS),
+    (30, 100, InferencePhase.AMBIGUOUS),
+    (40, 100, InferencePhase.BALANCED),
+    (50, 100, InferencePhase.BALANCED),
+    (60, 100, InferencePhase.BALANCED),
+    (100, 100, InferencePhase.PREFILL_HEAVY),
+    (20, 100, InferencePhase.DECODE_HEAVY),
+    (21, 100, InferencePhase.AMBIGUOUS),
+    (0, 100, InferencePhase.DECODE_HEAVY),
+])
+def test_determine_phase_from_batch_composition_stats(prefill_tokens,
+                                                      total_tokens,
+                                                      expected_phase):
+    """Tests the phase determination logic based on prefill ratios."""
+    stats = {
+        "num_prefill_tokens": prefill_tokens,
+        "total_num_scheduled_tokens": total_tokens
+    }
+    phase = determine_phase_from_batch_composition_stats(stats)
+    assert phase == expected_phase
+
+
+@pytest.fixture
+def profiler_fixture(tmp_path):
+    """Fixture to set up a PhasedBasedProfiler with mocked dependencies."""
+    target_module = "tpu_inference.runner.utils"
+    with patch(f"{target_module}.jax.profiler.start_trace") as mock_start, \
+            patch(f"{target_module}.jax.profiler.stop_trace") as mock_stop, \
+            patch("builtins.open", mock_open()) as mock_file, \
+            patch(f"{target_module}.datetime") as mock_datetime, \
+            patch(f"{target_module}.InferencePhase", InferencePhase), \
+            patch(f"{target_module}.determine_phase_from_batch_composition_stats") as mock_determine_phase:
+
+        mock_now = MagicMock()
+        mock_now.strftime.return_value = "2024_01_01_12_00_00"
+        mock_datetime.datetime.now.return_value = mock_now
+
+        profiler = PhasedBasedProfiler(profile_dir=str(tmp_path))
+        profiler.num_steps_to_profile_for = PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR
+
+        yield {
+            "profiler": profiler,
+            "mock_start": mock_start,
+            "mock_stop": mock_stop,
+            "mock_file": mock_file,
+            "mock_determine_phase": mock_determine_phase,
+        }
+
+
+def test_phased_profiler_full_cycle(profiler_fixture):
+    """Tests a full start-step-stop profiling cycle for one phase."""
+    profiler = profiler_fixture["profiler"]
+    mock_start = profiler_fixture["mock_start"]
+    mock_stop = profiler_fixture["mock_stop"]
+    mock_file = profiler_fixture["mock_file"]
+    mock_determine_phase = profiler_fixture["mock_determine_phase"]
+
+    stats = {"num_reqs": 2, "total_num_scheduled_tokens": 100}
+
+    # 1. Start profiling on PREFILL_HEAVY phase
+    mock_determine_phase.return_value = InferencePhase.PREFILL_HEAVY
+    profiler.step(stats)
+    mock_start.assert_called_once()
+    assert profiler.profiling_n_steps_left == PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR
+    assert profiler.current_phase == "prefill_heavy"
+    assert profiler.inference_phase_seen[InferencePhase.PREFILL_HEAVY]
+    assert mock_file().write.call_count == 1  # Wrote stats on start
+
+    # 2. Step profiling (N-1 steps)
+    for i in range(PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR - 1):
+        profiler.step(stats)
+        assert profiler.profiling_n_steps_left == PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR - 1 - i
+        mock_start.assert_called_once()  # Not called again
+        mock_stop.assert_not_called()
+
+    # 3. Final step stops profiling
+    profiler.step(stats)
+    mock_stop.assert_called_once()
+    assert profiler.profiling_n_steps_left == 0
+    assert profiler.current_phase == ""
+    assert mock_file(
+    ).write.call_count == PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR + 1
+
+
+def test_phased_profiler_ignores_initial_request(profiler_fixture):
+    """Tests that profiling is not triggered for initial small requests."""
+    profiler = profiler_fixture["profiler"]
+    mock_start = profiler_fixture["mock_start"]
+    mock_determine_phase = profiler_fixture["mock_determine_phase"]
+
+    mock_determine_phase.return_value = InferencePhase.PREFILL_HEAVY
+
+    profiler.step({"num_reqs": 1, "total_num_scheduled_tokens": 1})
+    mock_start.assert_not_called()
+
+    profiler.step({"num_reqs": 1, "total_num_scheduled_tokens": 100})
+    mock_start.assert_not_called()
+
+    profiler.step({"num_reqs": 2, "total_num_scheduled_tokens": 1})
+    mock_start.assert_not_called()
+
+    profiler.step({"num_reqs": 2, "total_num_scheduled_tokens": 2})
+    mock_start.assert_called_once()
+
+
+def test_phased_profiler_handles_all_phases(profiler_fixture):
+    """Tests that the profiler can profile all defined phases sequentially."""
+    profiler = profiler_fixture["profiler"]
+    mock_start = profiler_fixture["mock_start"]
+    mock_stop = profiler_fixture["mock_stop"]
+    mock_determine_phase = profiler_fixture["mock_determine_phase"]
+
+    stats = {"num_reqs": 2, "total_num_scheduled_tokens": 100}
+    phases_to_profile = [
+        InferencePhase.PREFILL_HEAVY, InferencePhase.DECODE_HEAVY,
+        InferencePhase.BALANCED
+    ]
+
+    for i, phase in enumerate(phases_to_profile):
+        # Start profiling for the new phase
+        mock_determine_phase.return_value = phase
+        profiler.step(stats)
+        assert mock_start.call_count == i + 1
+        assert profiler.current_phase == phase.name.lower()
+        assert profiler.inference_phase_seen[phase]
+
+        # Step until profiling stops for this phase
+        for _ in range(PHASED_PROFILER_NUM_STEPS_TO_PROFILE_FOR):
+            profiler.step(stats)
+
+        assert mock_stop.call_count == i + 1
+        assert profiler.current_phase == ""
+
+    # After all phases seen, should not start again
+    mock_determine_phase.return_value = InferencePhase.PREFILL_HEAVY
+    profiler.step(stats)
+    assert mock_start.call_count == len(phases_to_profile)
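The padding tests above fully determine the padding schedule by example. A minimal reconstruction consistent with every expected value in those tests (hypothetical: the shipped helpers in tpu_inference/runner/utils.py may be implemented differently):

```python
import bisect

MIN_NUM_SEQS = 8  # implied by the comment in the first test


def get_padded_num_reqs_with_upper_limit(num_reqs, upper_limit):
    # Round up to the next power of two, floored at MIN_NUM_SEQS and
    # capped at upper_limit.
    padded = max(MIN_NUM_SEQS, 1 << max(num_reqs - 1, 0).bit_length())
    return min(padded, upper_limit)


def get_token_paddings(min_token_size, max_token_size, padding_gap):
    paddings = [min_token_size]
    if padding_gap == 0:
        # Exponential mode: keep doubling until max_token_size is covered.
        while paddings[-1] < max_token_size:
            paddings.append(paddings[-1] * 2)
    else:
        # Bucketed mode: double up to 2 * padding_gap, then step linearly.
        while paddings[-1] < padding_gap * 2:
            paddings.append(paddings[-1] * 2)
        while paddings[-1] < max_token_size:
            paddings.append(paddings[-1] + padding_gap)
    return paddings


def get_padded_token_len(paddings, num_tokens):
    # Smallest padding that fits num_tokens; assumes paddings is sorted
    # ascending and num_tokens <= paddings[-1].
    return paddings[bisect.bisect_left(paddings, num_tokens)]
```

For (16, 512, 64) this reproduces [16, 32, 64, 128, 192, 256, 320, 384, 448, 512], and with padding_gap=0 it degenerates to the power-of-two schedule the exponential cases assert.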
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
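Returning to the ForbidCompile tests in tests/runner/test_utils.py above: they pin down an observable contract (raise RuntimeError on a fresh JAX compilation, pass through cache hits, restore jax's private pxla._cached_lowering_to_hlo on exit even when an exception escapes). A hypothetical sketch consistent with that contract follows; the class name, the message parameter, and the patched attribute come from the tests themselves, while the wrapper body is an assumption rather than the shipped implementation:

```python
from jax._src.interpreters import pxla


class ForbidCompile:
    """Hypothetical sketch: raise if JAX compiles anything in this context."""

    def __init__(self,
                 message="JAX compilation occurred but was forbidden "
                 "in this context."):
        self.message = message

    def __enter__(self):
        self._original = pxla._cached_lowering_to_hlo
        original, message = self._original, self.message

        def guarded(*args, **kwargs):
            # A miss on the lowering cache means a new compilation happened.
            misses_before = original.cache_info().misses
            result = original(*args, **kwargs)
            if original.cache_info().misses > misses_before:
                raise RuntimeError(message)
            return result

        pxla._cached_lowering_to_hlo = guarded
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original even when the body raised.
        pxla._cached_lowering_to_hlo = self._original
        return False
```

This mirrors how the tests observe the guard: they check misses via cache_info(), assert restoration via identity on pxla._cached_lowering_to_hlo, and rely on the default message matching "JAX compilation occurred".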