tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tests/__init__.py
ADDED
|
File without changes
|
tests/core/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import unittest
|
|
4
|
+
from unittest.mock import MagicMock, patch
|
|
5
|
+
|
|
6
|
+
from vllm.config import ParallelConfig, VllmConfig
|
|
7
|
+
from vllm.v1.engine import EngineCoreRequest, EngineCoreRequestType
|
|
8
|
+
from vllm.v1.executor.abstract import Executor
|
|
9
|
+
from vllm.v1.request import Request
|
|
10
|
+
|
|
11
|
+
from tpu_inference.core.core_tpu import (DisaggEngineCore,
|
|
12
|
+
DisaggEngineCoreProc,
|
|
13
|
+
_DisaggOrchestrator)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TestDisaggEngineCore(unittest.TestCase):
|
|
17
|
+
|
|
18
|
+
def setUp(self):
|
|
19
|
+
# Patch disagg_utils to control slice configuration.
|
|
20
|
+
self.mock_disagg_utils_patcher = patch(
|
|
21
|
+
'tpu_inference.core.core_tpu.disagg_utils')
|
|
22
|
+
self.mock_disagg_utils = self.mock_disagg_utils_patcher.start()
|
|
23
|
+
self.mock_disagg_utils.get_prefill_slices.return_value = (
|
|
24
|
+
4, ) # One prefill engine
|
|
25
|
+
self.mock_disagg_utils.get_decode_slices.return_value = (
|
|
26
|
+
2, ) # One decode engine
|
|
27
|
+
self.addCleanup(self.mock_disagg_utils_patcher.stop)
|
|
28
|
+
|
|
29
|
+
# Patch the orchestrator to test the adapter in isolation
|
|
30
|
+
self.mock_orchestrator_patcher = patch(
|
|
31
|
+
'tpu_inference.core.core_tpu._DisaggOrchestrator')
|
|
32
|
+
self.mock_orchestrator = self.mock_orchestrator_patcher.start()
|
|
33
|
+
self.addCleanup(self.mock_orchestrator_patcher.stop)
|
|
34
|
+
|
|
35
|
+
# Patch vLLMEngineCore to avoid its complex initialization.
|
|
36
|
+
self.mock_engine_core_patcher = patch(
|
|
37
|
+
'tpu_inference.core.core_tpu.vLLMEngineCore')
|
|
38
|
+
self.mock_vLLMEngineCore = self.mock_engine_core_patcher.start()
|
|
39
|
+
self.addCleanup(self.mock_engine_core_patcher.stop)
|
|
40
|
+
|
|
41
|
+
# Mock jax.devices
|
|
42
|
+
self.mock_jax_devices_patcher = patch('jax.devices',
|
|
43
|
+
return_value=[MagicMock()] * 8)
|
|
44
|
+
self.mock_jax_devices = self.mock_jax_devices_patcher.start()
|
|
45
|
+
self.addCleanup(self.mock_jax_devices_patcher.stop)
|
|
46
|
+
|
|
47
|
+
# VLLM Config
|
|
48
|
+
self.mock_vllm_config = MagicMock(spec=VllmConfig)
|
|
49
|
+
self.mock_vllm_config.parallel_config = MagicMock(spec=ParallelConfig)
|
|
50
|
+
self.mock_vllm_config.device_config = MagicMock()
|
|
51
|
+
self.mock_vllm_config.cache_config = MagicMock()
|
|
52
|
+
self.mock_vllm_config.cache_config.prefix_caching_hash_algo = "builtin"
|
|
53
|
+
self.mock_vllm_config.cache_config.block_size = 5
|
|
54
|
+
self.mock_vllm_config.__post_init__ = MagicMock()
|
|
55
|
+
|
|
56
|
+
def test_initialization(self):
|
|
57
|
+
"""Tests that the adapter initializes the orchestrator correctly."""
|
|
58
|
+
engine = DisaggEngineCore(
|
|
59
|
+
vllm_config=self.mock_vllm_config,
|
|
60
|
+
executor_class=MagicMock(spec=Executor),
|
|
61
|
+
log_stats=False,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
self.mock_orchestrator.assert_called_once()
|
|
65
|
+
args, kwargs = self.mock_orchestrator.call_args
|
|
66
|
+
self.assertIsInstance(kwargs['config'], VllmConfig)
|
|
67
|
+
self.assertEqual(kwargs['config'], self.mock_vllm_config)
|
|
68
|
+
self.assertEqual(kwargs['output_queue'], engine.output_queue)
|
|
69
|
+
self.assertEqual(len(kwargs['prefill_engines']), 1)
|
|
70
|
+
self.assertEqual(len(kwargs['decode_engines']), 1)
|
|
71
|
+
self.assertEqual(kwargs['prefill_slice_sizes'], (4, ))
|
|
72
|
+
self.assertEqual(kwargs['decode_slice_sizes'], (2, ))
|
|
73
|
+
|
|
74
|
+
def test_add_request(self):
|
|
75
|
+
"""Tests that the adapter correctly delegates add_request to the orchestrator."""
|
|
76
|
+
engine = DisaggEngineCore(
|
|
77
|
+
vllm_config=self.mock_vllm_config,
|
|
78
|
+
executor_class=MagicMock(spec=Executor),
|
|
79
|
+
log_stats=False,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
mock_request = MagicMock(spec=Request)
|
|
83
|
+
mock_request.request_id = "test_req"
|
|
84
|
+
mock_request.pooling_params = None
|
|
85
|
+
mock_request.kv_transfer_params = None
|
|
86
|
+
|
|
87
|
+
engine.add_request(mock_request)
|
|
88
|
+
|
|
89
|
+
self.mock_orchestrator.return_value.add_request.assert_called_once()
|
|
90
|
+
# Get the argument passed to add_request
|
|
91
|
+
passed_request = self.mock_orchestrator.return_value.add_request.call_args[
|
|
92
|
+
0][0]
|
|
93
|
+
|
|
94
|
+
# Assert it's the correct type (the Request directly)
|
|
95
|
+
self.assertIsInstance(passed_request, Request)
|
|
96
|
+
self.assertEqual(passed_request.request_id, "test_req")
|
|
97
|
+
|
|
98
|
+
def test_shutdown(self):
|
|
99
|
+
"""Tests that the adapter correctly delegates shutdown to the orchestrator."""
|
|
100
|
+
engine = DisaggEngineCore(
|
|
101
|
+
vllm_config=self.mock_vllm_config,
|
|
102
|
+
executor_class=MagicMock(spec=Executor),
|
|
103
|
+
log_stats=False,
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
engine.shutdown()
|
|
107
|
+
|
|
108
|
+
self.mock_orchestrator.return_value.shutdown.assert_called_once()
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class TestDisaggEngineCoreProc(unittest.TestCase):
|
|
112
|
+
|
|
113
|
+
def setUp(self):
|
|
114
|
+
# Patch disagg_utils to control slice configuration.
|
|
115
|
+
self.mock_disagg_utils_patcher = patch(
|
|
116
|
+
'tpu_inference.core.core_tpu.disagg_utils')
|
|
117
|
+
self.mock_disagg_utils = self.mock_disagg_utils_patcher.start()
|
|
118
|
+
self.mock_disagg_utils.get_prefill_slices.return_value = (
|
|
119
|
+
4, ) # One prefill engine
|
|
120
|
+
self.mock_disagg_utils.get_decode_slices.return_value = (
|
|
121
|
+
2, ) # One decode engine
|
|
122
|
+
self.addCleanup(self.mock_disagg_utils_patcher.stop)
|
|
123
|
+
|
|
124
|
+
# Patch the orchestrator to test the adapter in isolation
|
|
125
|
+
self.mock_orchestrator_patcher = patch(
|
|
126
|
+
'tpu_inference.core.core_tpu._DisaggOrchestrator')
|
|
127
|
+
self.mock_orchestrator = self.mock_orchestrator_patcher.start()
|
|
128
|
+
self.addCleanup(self.mock_orchestrator_patcher.stop)
|
|
129
|
+
|
|
130
|
+
# Patch vLLMEngineCore to avoid its complex initialization.
|
|
131
|
+
self.mock_engine_core_patcher = patch(
|
|
132
|
+
'tpu_inference.core.core_tpu.vLLMEngineCore')
|
|
133
|
+
self.mock_vLLMEngineCore = self.mock_engine_core_patcher.start()
|
|
134
|
+
self.addCleanup(self.mock_engine_core_patcher.stop)
|
|
135
|
+
|
|
136
|
+
# Patch the ZMQ handshake to isolate the test.
|
|
137
|
+
self.mock_handshake_patcher = patch(
|
|
138
|
+
'tpu_inference.core.core_tpu.DisaggEngineCoreProc._perform_handshake'
|
|
139
|
+
)
|
|
140
|
+
self.mock_handshake = self.mock_handshake_patcher.start()
|
|
141
|
+
self.mock_handshake.return_value.__enter__.return_value = MagicMock(
|
|
142
|
+
outputs=["output_addr"], coordinator_output=None)
|
|
143
|
+
self.addCleanup(self.mock_handshake_patcher.stop)
|
|
144
|
+
|
|
145
|
+
# Patch threads to avoid them running in the background.
|
|
146
|
+
def mock_thread_constructor(*args, **kwargs):
|
|
147
|
+
mock_thread = MagicMock()
|
|
148
|
+
|
|
149
|
+
def mock_start():
|
|
150
|
+
# Check if this is the input thread by looking at target and args
|
|
151
|
+
target = kwargs.get('target')
|
|
152
|
+
thread_args = kwargs.get('args', ())
|
|
153
|
+
|
|
154
|
+
# If this is the input thread (process_input_sockets), set the ready_event
|
|
155
|
+
if (target and hasattr(target, '__name__')
|
|
156
|
+
and target.__name__ == 'process_input_sockets'):
|
|
157
|
+
assert len(
|
|
158
|
+
thread_args
|
|
159
|
+
) == 4, "Expected 4 arguments for vllm process_input_sockets function"
|
|
160
|
+
ready_event = thread_args[
|
|
161
|
+
3] # ready_event is the 4th argument
|
|
162
|
+
ready_event.set()
|
|
163
|
+
|
|
164
|
+
mock_thread.start = mock_start
|
|
165
|
+
mock_thread.is_alive.return_value = True
|
|
166
|
+
return mock_thread
|
|
167
|
+
|
|
168
|
+
self.thread_patcher = patch("threading.Thread",
|
|
169
|
+
side_effect=mock_thread_constructor)
|
|
170
|
+
self.mock_thread = self.thread_patcher.start()
|
|
171
|
+
self.addCleanup(self.thread_patcher.stop)
|
|
172
|
+
|
|
173
|
+
# Mock jax.devices
|
|
174
|
+
self.mock_jax_devices_patcher = patch('jax.devices',
|
|
175
|
+
return_value=[MagicMock()] * 8)
|
|
176
|
+
self.mock_jax_devices = self.mock_jax_devices_patcher.start()
|
|
177
|
+
self.addCleanup(self.mock_jax_devices_patcher.stop)
|
|
178
|
+
|
|
179
|
+
# VLLM Config
|
|
180
|
+
self.mock_vllm_config = MagicMock(spec=VllmConfig)
|
|
181
|
+
self.mock_vllm_config.parallel_config = MagicMock(spec=ParallelConfig)
|
|
182
|
+
self.mock_vllm_config.device_config = MagicMock()
|
|
183
|
+
self.mock_vllm_config.cache_config = MagicMock()
|
|
184
|
+
self.mock_vllm_config.cache_config.prefix_caching_hash_algo = "builtin"
|
|
185
|
+
self.mock_vllm_config.cache_config.block_size = 5
|
|
186
|
+
self.mock_vllm_config.__post_init__ = MagicMock()
|
|
187
|
+
|
|
188
|
+
def test_initialization(self):
|
|
189
|
+
"""Tests that the adapter initializes the orchestrator correctly."""
|
|
190
|
+
proc = DisaggEngineCoreProc(
|
|
191
|
+
vllm_config=self.mock_vllm_config,
|
|
192
|
+
local_client=True,
|
|
193
|
+
handshake_address="dummy_addr",
|
|
194
|
+
executor_class=MagicMock(spec=Executor),
|
|
195
|
+
log_stats=False,
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
self.mock_orchestrator.assert_called_once()
|
|
199
|
+
args, kwargs = self.mock_orchestrator.call_args
|
|
200
|
+
self.assertIsInstance(kwargs['config'], VllmConfig)
|
|
201
|
+
self.assertEqual(kwargs['config'], self.mock_vllm_config)
|
|
202
|
+
self.assertEqual(kwargs['output_queue'], proc.output_queue)
|
|
203
|
+
self.assertEqual(len(kwargs['prefill_engines']), 1)
|
|
204
|
+
self.assertEqual(len(kwargs['decode_engines']), 1)
|
|
205
|
+
self.assertEqual(kwargs['prefill_slice_sizes'], (4, ))
|
|
206
|
+
self.assertEqual(kwargs['decode_slice_sizes'], (2, ))
|
|
207
|
+
|
|
208
|
+
def test_add_request(self):
|
|
209
|
+
"""Tests that the adapter correctly delegates add_request to the orchestrator."""
|
|
210
|
+
proc = DisaggEngineCoreProc(
|
|
211
|
+
vllm_config=self.mock_vllm_config,
|
|
212
|
+
local_client=True,
|
|
213
|
+
handshake_address="dummy_addr",
|
|
214
|
+
executor_class=MagicMock(spec=Executor),
|
|
215
|
+
log_stats=False,
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
mock_request = MagicMock(spec=EngineCoreRequest)
|
|
219
|
+
mock_request.request_id = "test_req"
|
|
220
|
+
mock_request.mm_hashes = None
|
|
221
|
+
mock_request.mm_kwargs = []
|
|
222
|
+
mock_request.use_structured_output = False
|
|
223
|
+
mock_request.pooling_params = None
|
|
224
|
+
mock_request.sampling_params.structured_outputs = None
|
|
225
|
+
mock_request.block_hashes = []
|
|
226
|
+
|
|
227
|
+
mock_engine_request, _ = proc.preprocess_add_request(mock_request)
|
|
228
|
+
|
|
229
|
+
proc.add_request(mock_engine_request)
|
|
230
|
+
|
|
231
|
+
self.mock_orchestrator.return_value.add_request.assert_called_once()
|
|
232
|
+
# Get the argument passed to add_request
|
|
233
|
+
passed_request = self.mock_orchestrator.return_value.add_request.call_args[
|
|
234
|
+
0][0]
|
|
235
|
+
|
|
236
|
+
# Assert it's the correct type (the Request directly)
|
|
237
|
+
self.assertIsInstance(passed_request, Request)
|
|
238
|
+
self.assertEqual(passed_request.request_id, "test_req")
|
|
239
|
+
|
|
240
|
+
def test_shutdown(self):
|
|
241
|
+
"""Tests that the adapter correctly delegates shutdown to the orchestrator."""
|
|
242
|
+
proc = DisaggEngineCoreProc(
|
|
243
|
+
vllm_config=self.mock_vllm_config,
|
|
244
|
+
local_client=True,
|
|
245
|
+
handshake_address="dummy_addr",
|
|
246
|
+
executor_class=MagicMock(spec=Executor),
|
|
247
|
+
log_stats=False,
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
proc.shutdown()
|
|
251
|
+
|
|
252
|
+
self.mock_orchestrator.return_value.shutdown.assert_called_once()
|
|
253
|
+
|
|
254
|
+
def test_handle_client_request_add(self):
|
|
255
|
+
"""Tests that the adapter correctly handles an ADD request."""
|
|
256
|
+
proc = DisaggEngineCoreProc(
|
|
257
|
+
vllm_config=self.mock_vllm_config,
|
|
258
|
+
local_client=True,
|
|
259
|
+
handshake_address="dummy_addr",
|
|
260
|
+
executor_class=MagicMock(spec=Executor),
|
|
261
|
+
log_stats=False,
|
|
262
|
+
)
|
|
263
|
+
mock_request = MagicMock(spec=EngineCoreRequest)
|
|
264
|
+
mock_request.request_id = "test_req"
|
|
265
|
+
mock_request.mm_hashes = None
|
|
266
|
+
mock_request.mm_kwargs = []
|
|
267
|
+
mock_request.use_structured_output = False
|
|
268
|
+
mock_request.pooling_params = None
|
|
269
|
+
mock_request.sampling_params.structured_outputs = None
|
|
270
|
+
mock_request.block_hashes = []
|
|
271
|
+
mock_request = proc.preprocess_add_request(mock_request)
|
|
272
|
+
|
|
273
|
+
proc._handle_client_request(EngineCoreRequestType.ADD, mock_request)
|
|
274
|
+
|
|
275
|
+
self.mock_orchestrator.return_value.add_request.assert_called_once()
|
|
276
|
+
|
|
277
|
+
def test_handle_client_request_abort(self):
|
|
278
|
+
"""Tests that the adapter correctly handles an ABORT request."""
|
|
279
|
+
proc = DisaggEngineCoreProc(
|
|
280
|
+
vllm_config=self.mock_vllm_config,
|
|
281
|
+
local_client=True,
|
|
282
|
+
handshake_address="dummy_addr",
|
|
283
|
+
executor_class=MagicMock(spec=Executor),
|
|
284
|
+
log_stats=False,
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
# This is currently a no-op, so we just check that it doesn't crash
|
|
288
|
+
proc._handle_client_request(EngineCoreRequestType.ABORT, "test_req")
|
|
289
|
+
|
|
290
|
+
def test_handle_client_request_utility(self):
|
|
291
|
+
"""Tests that the adapter correctly handles a UTILITY request."""
|
|
292
|
+
proc = DisaggEngineCoreProc(
|
|
293
|
+
vllm_config=self.mock_vllm_config,
|
|
294
|
+
local_client=True,
|
|
295
|
+
handshake_address="dummy_addr",
|
|
296
|
+
executor_class=MagicMock(spec=Executor),
|
|
297
|
+
log_stats=False,
|
|
298
|
+
)
|
|
299
|
+
# Mock a method on the prefill engine instance
|
|
300
|
+
proc._prefill_engines = [MagicMock()]
|
|
301
|
+
proc._prefill_engines[0].list_loras.return_value = {1, 2, 3}
|
|
302
|
+
|
|
303
|
+
utility_request = (0, "call-id-1", "list_loras", ())
|
|
304
|
+
proc._handle_client_request(EngineCoreRequestType.UTILITY,
|
|
305
|
+
utility_request)
|
|
306
|
+
|
|
307
|
+
proc._prefill_engines[0].list_loras.assert_called_once()
|
|
308
|
+
self.assertTrue(proc.output_queue.qsize() > 0)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class TestDisaggOrchestrator(unittest.TestCase):
|
|
312
|
+
|
|
313
|
+
def setUp(self):
|
|
314
|
+
self.mock_config = MagicMock(spec=VllmConfig)
|
|
315
|
+
self.mock_config.scheduler_config = MagicMock()
|
|
316
|
+
self.mock_config.scheduler_config.max_num_seqs = 16
|
|
317
|
+
self.mock_config.cache_config = MagicMock()
|
|
318
|
+
self.mock_config.cache_config.block_size = 5
|
|
319
|
+
|
|
320
|
+
self.mock_output_queue = MagicMock()
|
|
321
|
+
self.mock_prefill_engine = MagicMock()
|
|
322
|
+
self.mock_decode_engine = MagicMock()
|
|
323
|
+
|
|
324
|
+
# The orchestrator accesses the scheduler on the engine.
|
|
325
|
+
self.mock_prefill_engine.scheduler = MagicMock()
|
|
326
|
+
self.mock_decode_engine.scheduler = MagicMock()
|
|
327
|
+
|
|
328
|
+
# The orchestrator accesses the model_executor on the engine.
|
|
329
|
+
self.mock_prefill_engine.model_executor = MagicMock()
|
|
330
|
+
self.mock_decode_engine.model_executor = MagicMock()
|
|
331
|
+
|
|
332
|
+
# Patch threads to avoid them running in the background.
|
|
333
|
+
self.jet_thread_patcher = patch(
|
|
334
|
+
"tpu_inference.core.core_tpu.JetThread", MagicMock)
|
|
335
|
+
self.mock_jet_thread = self.jet_thread_patcher.start()
|
|
336
|
+
self.addCleanup(self.jet_thread_patcher.stop)
|
|
337
|
+
|
|
338
|
+
def test_initialization(self):
|
|
339
|
+
"""Tests that the orchestrator initializes correctly."""
|
|
340
|
+
orchestrator = _DisaggOrchestrator(
|
|
341
|
+
config=self.mock_config,
|
|
342
|
+
output_queue=self.mock_output_queue,
|
|
343
|
+
prefill_engines=[self.mock_prefill_engine],
|
|
344
|
+
decode_engines=[self.mock_decode_engine],
|
|
345
|
+
prefill_slice_sizes=(4, ),
|
|
346
|
+
decode_slice_sizes=(2, ),
|
|
347
|
+
)
|
|
348
|
+
|
|
349
|
+
self.assertEqual(orchestrator._config, self.mock_config)
|
|
350
|
+
self.assertEqual(orchestrator._output_queue, self.mock_output_queue)
|
|
351
|
+
self.assertEqual(len(orchestrator._prefill_engines), 1)
|
|
352
|
+
self.assertEqual(len(orchestrator._decode_engines), 1)
|
|
353
|
+
self.assertEqual(len(orchestrator._all_threads),
|
|
354
|
+
3) # 1 prefill, 1 transfer, 1 decode
|
|
355
|
+
|
|
356
|
+
def test_add_request(self):
|
|
357
|
+
"""Tests that a new request is added to the prefill engine."""
|
|
358
|
+
orchestrator = _DisaggOrchestrator(
|
|
359
|
+
config=self.mock_config,
|
|
360
|
+
output_queue=self.mock_output_queue,
|
|
361
|
+
prefill_engines=[self.mock_prefill_engine],
|
|
362
|
+
decode_engines=[self.mock_decode_engine],
|
|
363
|
+
prefill_slice_sizes=(4, ),
|
|
364
|
+
decode_slice_sizes=(2, ),
|
|
365
|
+
)
|
|
366
|
+
mock_request = MagicMock()
|
|
367
|
+
mock_request.request_id = "test_req"
|
|
368
|
+
|
|
369
|
+
orchestrator.add_request(mock_request)
|
|
370
|
+
|
|
371
|
+
self.assertIn("test_req", orchestrator._requests)
|
|
372
|
+
self.mock_prefill_engine.scheduler.add_request.assert_called_once_with(
|
|
373
|
+
mock_request)
|
|
374
|
+
|
|
375
|
+
def test_prefill_logic(self):
|
|
376
|
+
"""Tests the prefill logic of the orchestrator."""
|
|
377
|
+
orchestrator = _DisaggOrchestrator(
|
|
378
|
+
config=self.mock_config,
|
|
379
|
+
output_queue=self.mock_output_queue,
|
|
380
|
+
prefill_engines=[self.mock_prefill_engine],
|
|
381
|
+
decode_engines=[self.mock_decode_engine],
|
|
382
|
+
prefill_slice_sizes=(4, ),
|
|
383
|
+
decode_slice_sizes=(2, ),
|
|
384
|
+
)
|
|
385
|
+
orchestrator.live = True
|
|
386
|
+
|
|
387
|
+
# Mock scheduler output
|
|
388
|
+
mock_scheduler_output = MagicMock()
|
|
389
|
+
mock_scheduler_output.total_num_scheduled_tokens = 1
|
|
390
|
+
self.mock_prefill_engine.scheduler.schedule.return_value = mock_scheduler_output
|
|
391
|
+
|
|
392
|
+
# Mock model output
|
|
393
|
+
mock_model_output = MagicMock()
|
|
394
|
+
mock_model_output.req_id_to_index = {"test_req": 0}
|
|
395
|
+
mock_model_output.sampled_token_ids = [[1]]
|
|
396
|
+
self.mock_prefill_engine.model_executor.execute_model.return_value = mock_model_output
|
|
397
|
+
|
|
398
|
+
# Mock request
|
|
399
|
+
mock_request = MagicMock()
|
|
400
|
+
orchestrator._requests["test_req"] = mock_request
|
|
401
|
+
|
|
402
|
+
# Mock the side effect of update_from_output to stop the loop
|
|
403
|
+
def stop_loop(*args, **kwargs):
|
|
404
|
+
orchestrator.live = False
|
|
405
|
+
return {}
|
|
406
|
+
|
|
407
|
+
self.mock_prefill_engine.scheduler.update_from_output.side_effect = stop_loop
|
|
408
|
+
|
|
409
|
+
orchestrator._prefill(0)
|
|
410
|
+
|
|
411
|
+
self.mock_prefill_engine.model_executor.execute_model.assert_called_once(
|
|
412
|
+
)
|
|
413
|
+
self.assertTrue(orchestrator._transfer_backlogs[0].qsize() > 0)
|
|
414
|
+
|
|
415
|
+
def test_transfer_logic(self):
|
|
416
|
+
"""Tests the transfer logic of the orchestrator."""
|
|
417
|
+
orchestrator = _DisaggOrchestrator(
|
|
418
|
+
config=self.mock_config,
|
|
419
|
+
output_queue=self.mock_output_queue,
|
|
420
|
+
prefill_engines=[self.mock_prefill_engine],
|
|
421
|
+
decode_engines=[self.mock_decode_engine],
|
|
422
|
+
prefill_slice_sizes=(4, ),
|
|
423
|
+
decode_slice_sizes=(2, ),
|
|
424
|
+
)
|
|
425
|
+
orchestrator.live = True
|
|
426
|
+
|
|
427
|
+
# Mock kv cache map
|
|
428
|
+
mock_kv_cache_map = {"test_req": ([MagicMock()], [])}
|
|
429
|
+
orchestrator._transfer_backlogs[0].put(mock_kv_cache_map)
|
|
430
|
+
orchestrator._transfer_backlogs[0].put(
|
|
431
|
+
None) # Sentinel to stop the loop
|
|
432
|
+
|
|
433
|
+
orchestrator._transfer(0)
|
|
434
|
+
|
|
435
|
+
self.mock_decode_engine.model_executor.driver_worker.model_runner.transfer_kv_cache.assert_called_once(
|
|
436
|
+
)
|
|
437
|
+
self.assertTrue(orchestrator._decode_backlogs[0].qsize() > 0)
|
|
438
|
+
|
|
439
|
+
def test_decode_logic(self):
    """Drive ``_DisaggOrchestrator._decode`` for decode engine 0 once.

    The decode backlog is seeded with one prefill result for ``test_req``
    followed by ``None``, which this test uses as a sentinel to stop the loop.
    The mocked scheduler's ``update_from_output`` also switches
    ``orchestrator.live`` off after the first step. The test then expects
    exactly one model execution and exactly one item pushed onto the output
    queue.
    """
    orchestrator = _DisaggOrchestrator(
        config=self.mock_config,
        output_queue=self.mock_output_queue,
        prefill_engines=[self.mock_prefill_engine],
        decode_engines=[self.mock_decode_engine],
        prefill_slice_sizes=(4, ),
        decode_slice_sizes=(2, ),
    )
    orchestrator.live = True

    # One prefill result for the request, then the None sentinel.
    backlog = orchestrator._decode_backlogs[0]
    backlog.put({
        "req_id": "test_req",
        "cache": [MagicMock()],
        "block_hashes": []
    })
    backlog.put(None)

    # Register the request under the same id used in the prefill result.
    orchestrator._requests["test_req"] = MagicMock(num_computed_tokens=10)

    engine = self.mock_decode_engine
    scheduler = engine.scheduler

    # State consulted by the decode loop condition: an empty scheduler and
    # an empty model-runner input batch.
    scheduler.has_requests.return_value = False
    scheduler.get_request_counts.return_value = (0, 0)
    engine.model_executor.driver_worker.model_runner.input_batch.num_reqs = 0
    scheduler.kv_cache_manager.get_block_ids.return_value = ([20, 21], )

    # Each schedule() call yields an output with one scheduled token; the
    # model output itself is opaque.
    scheduler.schedule.return_value = MagicMock(total_num_scheduled_tokens=1)
    engine.model_executor.execute_model.return_value = MagicMock()

    def _finish_after_one_step(*args, **kwargs):
        # Stands in for update_from_output: turn the loop off and report an
        # output keyed by the request id.
        orchestrator.live = False
        return {"test_req": MagicMock()}

    scheduler.update_from_output.side_effect = _finish_after_one_step

    orchestrator._decode(0)

    engine.model_executor.execute_model.assert_called_once()
    self.mock_output_queue.put_nowait.assert_called_once()
|
|
494
|
+
|
|
495
|
+
def test_shutdown(self):
    """Check that ``shutdown()`` results in exactly one ``shutdown`` call on
    each prefill and decode engine."""
    orchestrator = _DisaggOrchestrator(
        config=self.mock_config,
        output_queue=self.mock_output_queue,
        prefill_engines=[self.mock_prefill_engine],
        decode_engines=[self.mock_decode_engine],
        prefill_slice_sizes=(4, ),
        decode_slice_sizes=(2, ),
    )

    orchestrator.shutdown()

    # Prefill engine first, then decode engine, as in the original checks.
    for engine in (self.mock_prefill_engine, self.mock_decode_engine):
        engine.shutdown.assert_called_once()
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
import unittest
|
|
3
|
+
from unittest.mock import MagicMock, patch
|
|
4
|
+
|
|
5
|
+
from vllm.config import ModelConfig, VllmConfig
|
|
6
|
+
|
|
7
|
+
from tpu_inference.core.disagg_executor import DisaggExecutor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DisaggExecutorTest(unittest.TestCase):
    """Tests for ``DisaggExecutor`` with its ``collective_rpc`` patched out."""

    def setUp(self):
        """Build a mocked ``VllmConfig`` and a ``DisaggExecutor`` from it."""
        vllm_config = MagicMock(spec=VllmConfig)
        # model_config is the only real config object; the other sub-configs
        # are either plain MagicMocks or explicitly None.
        vllm_config.model_config = ModelConfig(tokenizer_mode="auto",
                                               trust_remote_code=False,
                                               seed=0,
                                               dtype='bfloat16')
        for mocked_field in ("cache_config", "scheduler_config",
                             "load_config", "parallel_config",
                             "device_config", "observability_config"):
            setattr(vllm_config, mocked_field, MagicMock())
        for disabled_field in ("lora_config", "speculative_config",
                               "prompt_adapter_config"):
            setattr(vllm_config, disabled_field, None)
        self.mock_vllm_config = vllm_config

        # Swap collective_rpc for a mock so no real RPCs are issued; the
        # patch is undone by the registered cleanup.
        self.patcher = patch(
            "tpu_inference.core.disagg_executor.DisaggExecutor.collective_rpc")
        self.mock_collective_rpc = self.patcher.start()
        self.addCleanup(self.patcher.stop)

        self.executor = DisaggExecutor(vllm_config=self.mock_vllm_config)

    def test_init_with_devices(self):
        """Check the first three RPCs recorded after ``_init_executor``.

        NOTE(review): setUp already constructs the executor. If that
        constructor itself runs ``_init_executor``, the first three recorded
        calls come from setUp rather than from the explicit call below.
        Consider ``self.mock_collective_rpc.reset_mock()`` first; confirm
        against DisaggExecutor.
        """
        self.executor._init_executor()

        self.mock_collective_rpc.assert_called()
        recorded = self.mock_collective_rpc.call_args_list
        # Each call's first positional argument is compared to the expected
        # RPC method name, in order.
        for position, method_name in enumerate(
                ("init_worker", "init_device", "load_model")):
            self.assertEqual(recorded[position].args[0], method_name)

    def test_check_health(self):
        """``check_health`` should complete without raising."""
        self.executor.check_health()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
|
|
3
|
+
from tpu_inference.core.disagg_utils import _parse_slices
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DisaggUtilsTest(unittest.TestCase):
    """Tests for ``_parse_slices``, the comma-separated slice-string parser."""

    def test_parse_slices_valid_cases(self):
        """Well-formed slice strings parse into the expected tuples."""
        cases = (
            # A single slice, written as "AxB" or as a bare integer.
            ("2x2", ((2, 2), )),
            ("2", (2, )),
            # Several comma-separated slices, mixing both forms.
            ("2x2,2x1,3,2x4", ((2, 2), (2, 1), 3, (2, 4))),
            # Assorted dimensions, including multi-digit ones.
            ("1x1,10x10,5x3", ((1, 1), (10, 10), (5, 3))),
            # The empty string yields an empty tuple.
            ("", ()),
        )
        for spec, expected in cases:
            self.assertEqual(_parse_slices(spec), expected)

    def test_parse_slices_with_whitespace(self):
        """Whitespace around or inside slice definitions is tolerated."""
        cases = (
            (" 2x2 ", ((2, 2), )),
            (" 2x2 , 2x1 , 2x4 ", ((2, 2), (2, 1), (2, 4))),
            # The current implementation also accepts spaces inside a single
            # slice definition.
            ("2 x 2", ((2, 2), )),
            (" 10 x 10 ", ((10, 10), )),
        )
        for spec, expected in cases:
            self.assertEqual(_parse_slices(spec), expected)

    def test_parse_slices_invalid_cases(self):
        """Malformed slice strings must raise ``ValueError``."""
        malformed = (
            "2*2",  # wrong separator
            "2x",  # incomplete
            "axb",  # not integers
            "2x2x2",  # too many dimensions
            "2x2,3*3",  # partially malformed
            ",2x2",  # leading comma
            "2x2,",  # trailing comma
            "2x2,,2x1",  # empty slice in middle
        )
        for invalid_str in malformed:
            with self.subTest(invalid_str=invalid_str), \
                    self.assertRaises(ValueError):
                _parse_slices(invalid_str)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|