tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tests/__init__.py ADDED
File without changes
tests/core/__init__.py ADDED
File without changes
tests/core/test_core_tpu.py ADDED
@@ -0,0 +1,513 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import unittest
4
+ from unittest.mock import MagicMock, patch
5
+
6
+ from vllm.config import ParallelConfig, VllmConfig
7
+ from vllm.v1.engine import EngineCoreRequest, EngineCoreRequestType
8
+ from vllm.v1.executor.abstract import Executor
9
+ from vllm.v1.request import Request
10
+
11
+ from tpu_inference.core.core_tpu import (DisaggEngineCore,
12
+ DisaggEngineCoreProc,
13
+ _DisaggOrchestrator)
14
+
15
+
16
class TestDisaggEngineCore(unittest.TestCase):
    """Tests for the in-process ``DisaggEngineCore`` adapter.

    ``_DisaggOrchestrator``, vLLM's ``EngineCore`` and ``jax.devices`` are all
    patched out, so only the adapter's wiring and delegation is exercised.
    """

    def _start_patch(self, patcher):
        """Start ``patcher``, register ``patcher.stop`` as cleanup, return the mock."""
        mocked = patcher.start()
        self.addCleanup(patcher.stop)
        return mocked

    def setUp(self):
        # A single prefill slice (size 4) and a single decode slice (size 2),
        # i.e. exactly one engine of each kind.
        self.mock_disagg_utils = self._start_patch(
            patch('tpu_inference.core.core_tpu.disagg_utils'))
        self.mock_disagg_utils.get_prefill_slices.return_value = (4, )
        self.mock_disagg_utils.get_decode_slices.return_value = (2, )

        # Replace the orchestrator so the adapter is tested in isolation.
        self.mock_orchestrator = self._start_patch(
            patch('tpu_inference.core.core_tpu._DisaggOrchestrator'))

        # Skip vLLMEngineCore's heavyweight initialization.
        self.mock_vLLMEngineCore = self._start_patch(
            patch('tpu_inference.core.core_tpu.vLLMEngineCore'))

        # Report eight mock devices.
        self.mock_jax_devices = self._start_patch(
            patch('jax.devices', return_value=[MagicMock()] * 8))

        # Minimal VllmConfig stand-in.
        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.parallel_config = MagicMock(spec=ParallelConfig)
        self.mock_vllm_config.device_config = MagicMock()
        cache_config = MagicMock()
        cache_config.prefix_caching_hash_algo = "builtin"
        cache_config.block_size = 5
        self.mock_vllm_config.cache_config = cache_config
        self.mock_vllm_config.__post_init__ = MagicMock()

    def _build_engine(self):
        """Instantiate the adapter under test against the mocked config."""
        return DisaggEngineCore(
            vllm_config=self.mock_vllm_config,
            executor_class=MagicMock(spec=Executor),
            log_stats=False,
        )

    def test_initialization(self):
        """The orchestrator receives the config, output queue and engines."""
        engine = self._build_engine()

        self.mock_orchestrator.assert_called_once()
        init_kwargs = self.mock_orchestrator.call_args.kwargs
        self.assertIsInstance(init_kwargs['config'], VllmConfig)
        self.assertEqual(init_kwargs['config'], self.mock_vllm_config)
        self.assertEqual(init_kwargs['output_queue'], engine.output_queue)
        self.assertEqual(len(init_kwargs['prefill_engines']), 1)
        self.assertEqual(len(init_kwargs['decode_engines']), 1)
        self.assertEqual(init_kwargs['prefill_slice_sizes'], (4, ))
        self.assertEqual(init_kwargs['decode_slice_sizes'], (2, ))

    def test_add_request(self):
        """``add_request`` forwards the ``Request`` object to the orchestrator."""
        engine = self._build_engine()

        request = MagicMock(spec=Request)
        request.request_id = "test_req"
        request.pooling_params = None
        request.kv_transfer_params = None

        engine.add_request(request)

        orchestrator_add = self.mock_orchestrator.return_value.add_request
        orchestrator_add.assert_called_once()
        # First positional argument handed to the orchestrator.
        forwarded = orchestrator_add.call_args.args[0]
        self.assertIsInstance(forwarded, Request)
        self.assertEqual(forwarded.request_id, "test_req")

    def test_shutdown(self):
        """``shutdown`` is delegated to the orchestrator."""
        engine = self._build_engine()

        engine.shutdown()

        self.mock_orchestrator.return_value.shutdown.assert_called_once()
109
+
110
+
111
class TestDisaggEngineCoreProc(unittest.TestCase):
    """Tests for the ``DisaggEngineCoreProc`` adapter.

    Besides the orchestrator, vLLM's ``EngineCore`` and ``jax.devices``, the
    ZMQ handshake and ``threading.Thread`` are patched so that constructing
    the adapter opens no sockets and runs no real threads.
    """

    def _start_patch(self, patcher):
        """Start ``patcher``, register ``patcher.stop`` as cleanup, return the mock."""
        mocked = patcher.start()
        self.addCleanup(patcher.stop)
        return mocked

    @staticmethod
    def _fake_thread(*args, **kwargs):
        """Stand-in for ``threading.Thread(...)`` whose ``start`` never runs
        the target.

        When the thread's ``target`` is named ``process_input_sockets``,
        ``start`` sets the ready event (the 4th entry of ``args=``) so that
        code waiting on that event can proceed.
        """
        fake = MagicMock()
        target = kwargs.get('target')
        thread_args = kwargs.get('args', ())

        def start():
            if getattr(target, '__name__', None) == 'process_input_sockets':
                assert len(thread_args) == 4, (
                    "Expected 4 arguments for vllm process_input_sockets "
                    "function")
                thread_args[3].set()  # ready_event is the 4th argument

        fake.start = start
        fake.is_alive.return_value = True
        return fake

    def setUp(self):
        # One prefill slice (size 4) and one decode slice (size 2), i.e. one
        # engine of each kind.
        self.mock_disagg_utils = self._start_patch(
            patch('tpu_inference.core.core_tpu.disagg_utils'))
        self.mock_disagg_utils.get_prefill_slices.return_value = (4, )
        self.mock_disagg_utils.get_decode_slices.return_value = (2, )

        # Replace the orchestrator so the adapter is tested in isolation.
        self.mock_orchestrator = self._start_patch(
            patch('tpu_inference.core.core_tpu._DisaggOrchestrator'))

        # Skip vLLMEngineCore's heavyweight initialization.
        self.mock_vLLMEngineCore = self._start_patch(
            patch('tpu_inference.core.core_tpu.vLLMEngineCore'))

        # Patch the ZMQ handshake; its context manager yields fake addresses.
        self.mock_handshake = self._start_patch(
            patch('tpu_inference.core.core_tpu.DisaggEngineCoreProc.'
                  '_perform_handshake'))
        self.mock_handshake.return_value.__enter__.return_value = MagicMock(
            outputs=["output_addr"], coordinator_output=None)

        # Replace threading.Thread so nothing runs in the background.
        self.mock_thread = self._start_patch(
            patch("threading.Thread", side_effect=self._fake_thread))

        # Report eight mock devices.
        self.mock_jax_devices = self._start_patch(
            patch('jax.devices', return_value=[MagicMock()] * 8))

        # Minimal VllmConfig stand-in.
        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.parallel_config = MagicMock(spec=ParallelConfig)
        self.mock_vllm_config.device_config = MagicMock()
        self.mock_vllm_config.cache_config = MagicMock()
        self.mock_vllm_config.cache_config.prefix_caching_hash_algo = "builtin"
        self.mock_vllm_config.cache_config.block_size = 5
        self.mock_vllm_config.__post_init__ = MagicMock()

    def _build_proc(self):
        """Construct the adapter under test against the mocked config."""
        return DisaggEngineCoreProc(
            vllm_config=self.mock_vllm_config,
            local_client=True,
            handshake_address="dummy_addr",
            executor_class=MagicMock(spec=Executor),
            log_stats=False,
        )

    @staticmethod
    def _make_engine_core_request():
        """Return an ``EngineCoreRequest`` mock accepted by ``preprocess_add_request``."""
        request = MagicMock(spec=EngineCoreRequest)
        request.request_id = "test_req"
        request.mm_hashes = None
        request.mm_kwargs = []
        request.use_structured_output = False
        request.pooling_params = None
        request.sampling_params.structured_outputs = None
        request.block_hashes = []
        return request

    def test_initialization(self):
        """The orchestrator receives the config, output queue and engines."""
        proc = self._build_proc()

        self.mock_orchestrator.assert_called_once()
        init_kwargs = self.mock_orchestrator.call_args.kwargs
        self.assertIsInstance(init_kwargs['config'], VllmConfig)
        self.assertEqual(init_kwargs['config'], self.mock_vllm_config)
        self.assertEqual(init_kwargs['output_queue'], proc.output_queue)
        self.assertEqual(len(init_kwargs['prefill_engines']), 1)
        self.assertEqual(len(init_kwargs['decode_engines']), 1)
        self.assertEqual(init_kwargs['prefill_slice_sizes'], (4, ))
        self.assertEqual(init_kwargs['decode_slice_sizes'], (2, ))

    def test_add_request(self):
        """``add_request`` forwards the preprocessed ``Request`` to the orchestrator."""
        proc = self._build_proc()

        # preprocess_add_request returns a 2-tuple; only the Request is needed.
        engine_request, _ = proc.preprocess_add_request(
            self._make_engine_core_request())

        proc.add_request(engine_request)

        orchestrator_add = self.mock_orchestrator.return_value.add_request
        orchestrator_add.assert_called_once()
        forwarded = orchestrator_add.call_args.args[0]
        self.assertIsInstance(forwarded, Request)
        self.assertEqual(forwarded.request_id, "test_req")

    def test_shutdown(self):
        """``shutdown`` is delegated to the orchestrator."""
        proc = self._build_proc()

        proc.shutdown()

        self.mock_orchestrator.return_value.shutdown.assert_called_once()

    def test_handle_client_request_add(self):
        """An ADD client request reaches the orchestrator's ``add_request``."""
        proc = self._build_proc()
        # The ADD payload is the full 2-tuple returned by
        # preprocess_add_request, not just the Request.
        add_payload = proc.preprocess_add_request(
            self._make_engine_core_request())

        proc._handle_client_request(EngineCoreRequestType.ADD, add_payload)

        self.mock_orchestrator.return_value.add_request.assert_called_once()

    def test_handle_client_request_abort(self):
        """An ABORT client request is handled without raising."""
        proc = self._build_proc()

        # ABORT is currently a no-op, so only check that it doesn't crash.
        proc._handle_client_request(EngineCoreRequestType.ABORT, "test_req")

    def test_handle_client_request_utility(self):
        """A UTILITY request invokes the method on the prefill engine and
        enqueues a result."""
        proc = self._build_proc()
        # Stub the prefill engine so the utility call has something to hit.
        proc._prefill_engines = [MagicMock()]
        proc._prefill_engines[0].list_loras.return_value = {1, 2, 3}

        utility_request = (0, "call-id-1", "list_loras", ())
        proc._handle_client_request(EngineCoreRequestType.UTILITY,
                                    utility_request)

        proc._prefill_engines[0].list_loras.assert_called_once()
        # assertGreater reports the actual size on failure.
        self.assertGreater(proc.output_queue.qsize(), 0)
309
+
310
+
311
class TestDisaggOrchestrator(unittest.TestCase):
    """Tests for ``_DisaggOrchestrator``.

    ``JetThread`` is replaced by ``MagicMock`` so no worker thread runs.
    Each stage loop (``_prefill``, ``_transfer``, ``_decode``) is instead
    driven directly on the test thread.
    """

    def setUp(self):
        self.mock_config = MagicMock(spec=VllmConfig)
        self.mock_config.scheduler_config = MagicMock()
        self.mock_config.scheduler_config.max_num_seqs = 16
        self.mock_config.cache_config = MagicMock()
        self.mock_config.cache_config.block_size = 5

        self.mock_output_queue = MagicMock()
        self.mock_prefill_engine = MagicMock()
        self.mock_decode_engine = MagicMock()

        # The orchestrator reaches into each engine's scheduler and
        # model_executor, so give both engines explicit mocks for them.
        for engine in (self.mock_prefill_engine, self.mock_decode_engine):
            engine.scheduler = MagicMock()
            engine.model_executor = MagicMock()

        # Patch threads to avoid them running in the background.
        self.jet_thread_patcher = patch(
            "tpu_inference.core.core_tpu.JetThread", MagicMock)
        self.mock_jet_thread = self.jet_thread_patcher.start()
        self.addCleanup(self.jet_thread_patcher.stop)

    def _make_orchestrator(self):
        """Build an orchestrator with one prefill (4) and one decode (2) engine."""
        return _DisaggOrchestrator(
            config=self.mock_config,
            output_queue=self.mock_output_queue,
            prefill_engines=[self.mock_prefill_engine],
            decode_engines=[self.mock_decode_engine],
            prefill_slice_sizes=(4, ),
            decode_slice_sizes=(2, ),
        )

    def test_initialization(self):
        """Tests that the orchestrator initializes correctly."""
        orchestrator = self._make_orchestrator()

        self.assertEqual(orchestrator._config, self.mock_config)
        self.assertEqual(orchestrator._output_queue, self.mock_output_queue)
        self.assertEqual(len(orchestrator._prefill_engines), 1)
        self.assertEqual(len(orchestrator._decode_engines), 1)
        # 1 prefill, 1 transfer, 1 decode
        self.assertEqual(len(orchestrator._all_threads), 3)

    def test_add_request(self):
        """Tests that a new request is added to the prefill engine."""
        orchestrator = self._make_orchestrator()
        mock_request = MagicMock()
        mock_request.request_id = "test_req"

        orchestrator.add_request(mock_request)

        self.assertIn("test_req", orchestrator._requests)
        self.mock_prefill_engine.scheduler.add_request.assert_called_once_with(
            mock_request)

    def test_prefill_logic(self):
        """One prefill iteration executes the model and queues a transfer."""
        orchestrator = self._make_orchestrator()
        orchestrator.live = True

        # Mock scheduler output
        mock_scheduler_output = MagicMock()
        mock_scheduler_output.total_num_scheduled_tokens = 1
        self.mock_prefill_engine.scheduler.schedule.return_value = (
            mock_scheduler_output)

        # Mock model output
        mock_model_output = MagicMock()
        mock_model_output.req_id_to_index = {"test_req": 0}
        mock_model_output.sampled_token_ids = [[1]]
        self.mock_prefill_engine.model_executor.execute_model.return_value = (
            mock_model_output)

        orchestrator._requests["test_req"] = MagicMock()

        # Flip `live` off from update_from_output so the loop exits after
        # a single iteration.
        def stop_loop(*args, **kwargs):
            orchestrator.live = False
            return {}

        self.mock_prefill_engine.scheduler.update_from_output.side_effect = (
            stop_loop)

        orchestrator._prefill(0)

        self.mock_prefill_engine.model_executor.execute_model.assert_called_once(
        )
        self.assertGreater(orchestrator._transfer_backlogs[0].qsize(), 0)

    def test_transfer_logic(self):
        """A queued KV-cache map is transferred and handed on to decode."""
        orchestrator = self._make_orchestrator()
        orchestrator.live = True

        # One KV-cache map to transfer, followed by a None sentinel that
        # stops the loop.
        transfer_backlog = orchestrator._transfer_backlogs[0]
        transfer_backlog.put({"test_req": ([MagicMock()], [])})
        transfer_backlog.put(None)

        orchestrator._transfer(0)

        decode_runner = (
            self.mock_decode_engine.model_executor.driver_worker.model_runner)
        decode_runner.transfer_kv_cache.assert_called_once()
        self.assertGreater(orchestrator._decode_backlogs[0].qsize(), 0)

    def test_decode_logic(self):
        """One decode iteration executes the model and emits one output."""
        orchestrator = self._make_orchestrator()
        orchestrator.live = True

        # A prefill result, followed by a None sentinel that stops the loop.
        decode_backlog = orchestrator._decode_backlogs[0]
        decode_backlog.put({
            "req_id": "test_req",
            "cache": [MagicMock()],
            "block_hashes": []
        })
        decode_backlog.put(None)

        mock_request = MagicMock()
        mock_request.num_computed_tokens = 10
        orchestrator._requests["test_req"] = mock_request

        # Scheduler and model-runner state consulted by the loop condition.
        decode_scheduler = self.mock_decode_engine.scheduler
        decode_runner = (
            self.mock_decode_engine.model_executor.driver_worker.model_runner)
        decode_scheduler.has_requests.return_value = False
        decode_scheduler.get_request_counts.return_value = (0, 0)
        decode_runner.input_batch.num_reqs = 0
        decode_scheduler.kv_cache_manager.get_block_ids.return_value = (
            [20, 21], )

        # Mock scheduler output
        mock_scheduler_output = MagicMock()
        mock_scheduler_output.total_num_scheduled_tokens = 1
        decode_scheduler.schedule.return_value = mock_scheduler_output

        # Mock model output
        self.mock_decode_engine.model_executor.execute_model.return_value = (
            MagicMock())

        # Flip `live` off from update_from_output so the loop exits after
        # a single iteration.
        def stop_loop(*args, **kwargs):
            orchestrator.live = False
            return {"test_req": MagicMock()}

        decode_scheduler.update_from_output.side_effect = stop_loop

        orchestrator._decode(0)

        self.mock_decode_engine.model_executor.execute_model.assert_called_once(
        )
        self.mock_output_queue.put_nowait.assert_called_once()

    def test_shutdown(self):
        """Tests that the orchestrator correctly shuts down its engines."""
        orchestrator = self._make_orchestrator()

        orchestrator.shutdown()

        self.mock_prefill_engine.shutdown.assert_called_once()
        self.mock_decode_engine.shutdown.assert_called_once()


if __name__ == '__main__':
    unittest.main()
tests/core/test_disagg_executor.py ADDED
@@ -0,0 +1,60 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import unittest
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ from vllm.config import ModelConfig, VllmConfig
6
+
7
+ from tpu_inference.core.disagg_executor import DisaggExecutor
8
+
9
+
10
class DisaggExecutorTest(unittest.TestCase):
    """Tests for ``DisaggExecutor`` with ``collective_rpc`` patched out."""

    def setUp(self):
        """Set up the test environment by mocking dependencies."""
        # Mock configurations
        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.model_config = ModelConfig(
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype='bfloat16')
        self.mock_vllm_config.cache_config = MagicMock()
        self.mock_vllm_config.scheduler_config = MagicMock()
        self.mock_vllm_config.load_config = MagicMock()
        self.mock_vllm_config.lora_config = None
        self.mock_vllm_config.parallel_config = MagicMock()
        self.mock_vllm_config.device_config = MagicMock()
        self.mock_vllm_config.speculative_config = None
        self.mock_vllm_config.prompt_adapter_config = None
        self.mock_vllm_config.observability_config = MagicMock()

        # Patch the collective_rpc method to avoid actual RPC calls
        self.patcher = patch(
            "tpu_inference.core.disagg_executor.DisaggExecutor.collective_rpc")
        self.mock_collective_rpc = self.patcher.start()
        self.addCleanup(self.patcher.stop)

        # Create a DisaggExecutor instance with the mock config
        self.executor = DisaggExecutor(vllm_config=self.mock_vllm_config)

    def test_init_with_devices(self):
        """``_init_executor`` issues init_worker, init_device, load_model in order."""
        # Constructing the executor in setUp may already have issued RPCs
        # (e.g. if the base-class constructor runs _init_executor). Clear
        # them so the assertions below only see calls made by this explicit
        # call, rather than passing on the constructor's calls.
        self.mock_collective_rpc.reset_mock()

        self.executor._init_executor()

        self.mock_collective_rpc.assert_called()
        # The RPC method name is the first positional argument of each call.
        first_methods = [
            rpc_call.args[0]
            for rpc_call in self.mock_collective_rpc.call_args_list[:3]
        ]
        self.assertEqual(first_methods,
                         ["init_worker", "init_device", "load_model"])

    def test_check_health(self):
        """Test check_health."""
        # Call check_health (it should always pass)
        self.executor.check_health()


if __name__ == '__main__':
    unittest.main()
tests/core/test_disagg_utils.py ADDED
@@ -0,0 +1,53 @@
1
+ import unittest
2
+
3
+ from tpu_inference.core.disagg_utils import _parse_slices
4
+
5
+
6
class DisaggUtilsTest(unittest.TestCase):
    """Tests for ``_parse_slices``, which parses comma-separated slice specs.

    Per the expectations below, an ``AxB`` entry parses to the tuple
    ``(A, B)``, a bare ``N`` to the int ``N``, and ``""`` to ``()``.
    """

    def _assert_parses(self, cases):
        """Check each ``(spec, expected)`` pair as its own subtest so one
        failure does not hide the others."""
        for spec, expected in cases:
            with self.subTest(spec=spec):
                self.assertEqual(_parse_slices(spec), expected)

    def test_parse_slices_valid_cases(self):
        """Tests valid slice strings."""
        self._assert_parses([
            # Single slice
            ("2x2", ((2, 2), )),
            ("2", (2, )),
            # Multiple slices, mixing 2-D and 1-D entries
            ("2x2,2x1,3,2x4", ((2, 2), (2, 1), 3, (2, 4))),
            # Various dimensions, including multi-digit ones
            ("1x1,10x10,5x3", ((1, 1), (10, 10), (5, 3))),
            # Empty string
            ("", ()),
        ])

    def test_parse_slices_with_whitespace(self):
        """Tests valid slice strings with extra whitespace."""
        self._assert_parses([
            (" 2x2 ", ((2, 2), )),
            (" 2x2 , 2x1 , 2x4 ", ((2, 2), (2, 1), (2, 4))),
            # The current implementation allows spaces inside the slice
            # definition.
            ("2 x 2", ((2, 2), )),
            (" 10 x 10 ", ((10, 10), )),
        ])

    def test_parse_slices_invalid_cases(self):
        """Tests malformed slice strings that should raise ValueError."""
        invalid_strings = [
            "2*2",  # wrong separator
            "2x",  # incomplete
            "axb",  # not integers
            "2x2x2",  # too many dimensions
            "2x2,3*3",  # partially malformed
            ",2x2",  # leading comma
            "2x2,",  # trailing comma
            "2x2,,2x1",  # empty slice in middle
        ]
        for invalid_str in invalid_strings:
            with self.subTest(invalid_str=invalid_str):
                with self.assertRaises(ValueError):
                    _parse_slices(invalid_str)


if __name__ == '__main__':
    unittest.main()