tpu-inference 0.11.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tests/__init__.py ADDED
File without changes
tests/core/__init__.py ADDED
File without changes
@@ -0,0 +1,83 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import unittest
3
+ from unittest.mock import MagicMock, PropertyMock
4
+
5
+ from tpu_inference.core import adapters
6
+
7
+
8
class TestVllmConfigAdapter(unittest.TestCase):
    """Checks attribute read-through on adapters.VllmConfigAdapter."""

    def test_config_adapter(self):
        """scheduler_config and cache_config come from the wrapped config."""
        wrapped_config = MagicMock()
        # PropertyMock has to be installed on the mock's type, not the instance.
        wrapped_type = type(wrapped_config)
        wrapped_type.scheduler_config = PropertyMock(return_value="scheduler")
        wrapped_type.cache_config = PropertyMock(return_value="cache")

        config_adapter = adapters.VllmConfigAdapter(wrapped_config)

        self.assertEqual(config_adapter.scheduler_config, "scheduler")
        self.assertEqual(config_adapter.cache_config, "cache")
21
+
22
+
23
class TestVllmSchedulerAdapter(unittest.TestCase):
    """Checks how adapters.VllmSchedulerAdapter forwards calls to its scheduler."""

    def test_add_request(self):
        """add_request hands the request's `vllm_request` to the scheduler."""
        scheduler = MagicMock()
        request = MagicMock()
        request.vllm_request = "vllm_request"

        adapters.VllmSchedulerAdapter(scheduler).add_request(request)

        scheduler.add_request.assert_called_once_with("vllm_request")

    def test_getattr(self):
        """Calling schedule() on the adapter calls schedule() on the scheduler."""
        scheduler = MagicMock()

        adapters.VllmSchedulerAdapter(scheduler).schedule()

        scheduler.schedule.assert_called_once()
38
+
39
+
40
class TestVllmEngineAdapter(unittest.TestCase):
    """Checks attribute and method delegation on adapters.VllmEngineAdapter."""

    def test_engine_adapter(self):
        """The adapter wraps the scheduler and forwards executor/model calls."""
        engine_core = MagicMock()
        engine_core.scheduler = "scheduler"
        # PropertyMock has to be installed on the mock's type, not the instance.
        type(engine_core).model_executor = PropertyMock(
            return_value="executor")

        engine_adapter = adapters.VllmEngineAdapter(engine_core)

        # The scheduler is exposed wrapped; the executor is exposed as-is.
        self.assertIsInstance(engine_adapter.scheduler,
                              adapters.VllmSchedulerAdapter)
        self.assertEqual(engine_adapter.model_executor, "executor")

        # Positional and keyword arguments must both reach the engine core.
        engine_adapter.execute_model_with_error_logging("arg1",
                                                        kwarg1="kwarg1")
        forwarded = engine_core.execute_model_with_error_logging
        forwarded.assert_called_once_with("arg1", kwarg1="kwarg1")

        engine_adapter.shutdown()
        engine_core.shutdown.assert_called_once()
59
+
60
+
61
class TestVllmRequestAdapter(unittest.TestCase):
    """Checks read and write-through behaviour of adapters.VllmRequestAdapter."""

    def test_request_adapter(self):
        """Reads come from the wrapped request; writes land on it."""
        wrapped_request = MagicMock()
        type(wrapped_request).request_id = PropertyMock(return_value="123")

        # Writable fields are plain attributes on the mock so that the
        # adapter's assignments can be observed afterwards.
        wrapped_request.num_computed_tokens = 10
        wrapped_request.status = "COMPLETED"

        request_adapter = adapters.VllmRequestAdapter(wrapped_request)

        expected_reads = (
            ("vllm_request", wrapped_request),
            ("request_id", "123"),
            ("num_computed_tokens", 10),
            ("status", "COMPLETED"),
        )
        for attr_name, expected in expected_reads:
            self.assertEqual(getattr(request_adapter, attr_name), expected)

        request_adapter.num_computed_tokens = 20
        self.assertEqual(wrapped_request.num_computed_tokens, 20)

        request_adapter.status = "RUNNING"
        self.assertEqual(wrapped_request.status, "RUNNING")
@@ -0,0 +1,523 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import unittest
4
+ from unittest.mock import MagicMock, patch
5
+
6
+ from vllm.config import ParallelConfig, VllmConfig
7
+ from vllm.v1.engine import EngineCoreRequest, EngineCoreRequestType
8
+ from vllm.v1.executor.abstract import Executor
9
+ from vllm.v1.request import Request
10
+
11
+ from tpu_inference.core.adapters import (VllmConfigAdapter, VllmEngineAdapter,
12
+ VllmRequestAdapter)
13
+ from tpu_inference.core.core_tpu import (DisaggEngineCore,
14
+ DisaggEngineCoreProc,
15
+ _DisaggOrchestrator)
16
+ from tpu_inference.interfaces.config import IConfig
17
+ from tpu_inference.interfaces.engine import IEngineCore
18
+
19
+
20
class TestDisaggEngineCore(unittest.TestCase):
    """Tests DisaggEngineCore with its collaborators in core_tpu patched out."""

    def _start_patch(self, *patch_args, **patch_kwargs):
        """Start a patch, schedule its teardown, and return (patcher, mock)."""
        patcher = patch(*patch_args, **patch_kwargs)
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return patcher, started

    def _build_engine(self):
        """Create the DisaggEngineCore under test from the mocked config."""
        return DisaggEngineCore(
            vllm_config=self.mock_vllm_config,
            executor_class=MagicMock(spec=Executor),
            log_stats=False,
        )

    def setUp(self):
        # Patch disagg_utils to control slice configuration.
        self.mock_disagg_utils_patcher, self.mock_disagg_utils = (
            self._start_patch('tpu_inference.core.core_tpu.disagg_utils'))
        # One prefill engine.
        self.mock_disagg_utils.get_prefill_slices.return_value = (4, )
        # One decode engine.
        self.mock_disagg_utils.get_decode_slices.return_value = (2, )

        # Patch the orchestrator to test the adapter in isolation.
        self.mock_orchestrator_patcher, self.mock_orchestrator = (
            self._start_patch(
                'tpu_inference.core.core_tpu._DisaggOrchestrator'))

        # Patch vLLMEngineCore to avoid its complex initialization.
        self.mock_engine_core_patcher, self.mock_vLLMEngineCore = (
            self._start_patch('tpu_inference.core.core_tpu.vLLMEngineCore'))

        # Mock jax.devices.
        self.mock_jax_devices_patcher, self.mock_jax_devices = (
            self._start_patch('jax.devices', return_value=[MagicMock()] * 8))

        # VLLM config.
        config = MagicMock(spec=VllmConfig)
        config.parallel_config = MagicMock(spec=ParallelConfig)
        config.device_config = MagicMock()
        config.cache_config = MagicMock()
        config.cache_config.prefix_caching_hash_algo = "builtin"
        config.cache_config.block_size = 5
        config.__post_init__ = MagicMock()
        self.mock_vllm_config = config

    def test_initialization(self):
        """Tests that the adapter initializes the orchestrator correctly."""
        engine = self._build_engine()

        self.mock_orchestrator.assert_called_once()
        _, ctor_kwargs = self.mock_orchestrator.call_args

        self.assertIsInstance(ctor_kwargs['config'], VllmConfigAdapter)
        self.assertEqual(ctor_kwargs['config'].vllm_config,
                         self.mock_vllm_config)
        self.assertEqual(ctor_kwargs['output_queue'], engine.output_queue)

        prefill_engines = ctor_kwargs['prefill_engines']
        self.assertEqual(len(prefill_engines), 1)
        self.assertIsInstance(prefill_engines[0], VllmEngineAdapter)

        decode_engines = ctor_kwargs['decode_engines']
        self.assertEqual(len(decode_engines), 1)
        self.assertIsInstance(decode_engines[0], VllmEngineAdapter)

        self.assertEqual(ctor_kwargs['prefill_slice_sizes'], (4, ))
        self.assertEqual(ctor_kwargs['decode_slice_sizes'], (2, ))

    def test_add_request(self):
        """Tests that the adapter correctly delegates add_request to the orchestrator."""
        engine = self._build_engine()

        vllm_request = MagicMock(spec=Request)
        vllm_request.request_id = "test_req"
        vllm_request.pooling_params = None
        vllm_request.kv_transfer_params = None

        engine.add_request(vllm_request)

        orchestrator_add = self.mock_orchestrator.return_value.add_request
        orchestrator_add.assert_called_once()
        # First positional argument handed to the orchestrator.
        forwarded = orchestrator_add.call_args[0][0]

        # It must be an adapter wrapping a vLLM Request with the same id.
        self.assertIsInstance(forwarded, VllmRequestAdapter)
        self.assertIsInstance(forwarded.vllm_request, Request)
        self.assertEqual(forwarded.request_id, "test_req")

    def test_shutdown(self):
        """Tests that the adapter correctly delegates shutdown to the orchestrator."""
        self._build_engine().shutdown()

        self.mock_orchestrator.return_value.shutdown.assert_called_once()
116
+
117
+
118
class TestDisaggEngineCoreProc(unittest.TestCase):
    """Tests DisaggEngineCoreProc with its collaborators replaced by mocks.

    The orchestrator, vLLMEngineCore, the handshake, threading.Thread and
    jax.devices are all patched in setUp, so no real background threads
    are started by these tests.
    """

    def _start_patch(self, *patch_args, **patch_kwargs):
        """Start a patch, schedule its teardown, and return (patcher, mock)."""
        patcher = patch(*patch_args, **patch_kwargs)
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return patcher, started

    def _make_proc(self):
        """Create the DisaggEngineCoreProc under test from the mocked config."""
        return DisaggEngineCoreProc(
            vllm_config=self.mock_vllm_config,
            local_client=True,
            handshake_address="dummy_addr",
            executor_class=MagicMock(spec=Executor),
            log_stats=False,
        )

    @staticmethod
    def _make_core_request():
        """Return a mock EngineCoreRequest with the fields these tests set."""
        mock_request = MagicMock(spec=EngineCoreRequest)
        mock_request.request_id = "test_req"
        mock_request.mm_hashes = None
        mock_request.mm_kwargs = []
        mock_request.use_structured_output = False
        mock_request.pooling_params = None
        mock_request.sampling_params.structured_outputs = None
        mock_request.block_hashes = []
        return mock_request

    def setUp(self):
        # Patch disagg_utils to control slice configuration.
        self.mock_disagg_utils_patcher, self.mock_disagg_utils = (
            self._start_patch('tpu_inference.core.core_tpu.disagg_utils'))
        # One prefill engine.
        self.mock_disagg_utils.get_prefill_slices.return_value = (4, )
        # One decode engine.
        self.mock_disagg_utils.get_decode_slices.return_value = (2, )

        # Patch the orchestrator to test the adapter in isolation.
        self.mock_orchestrator_patcher, self.mock_orchestrator = (
            self._start_patch(
                'tpu_inference.core.core_tpu._DisaggOrchestrator'))

        # Patch vLLMEngineCore to avoid its complex initialization.
        self.mock_engine_core_patcher, self.mock_vLLMEngineCore = (
            self._start_patch('tpu_inference.core.core_tpu.vLLMEngineCore'))

        # Patch the ZMQ handshake to isolate the test.
        self.mock_handshake_patcher, self.mock_handshake = self._start_patch(
            'tpu_inference.core.core_tpu.DisaggEngineCoreProc._perform_handshake'
        )
        self.mock_handshake.return_value.__enter__.return_value = MagicMock(
            outputs=["output_addr"], coordinator_output=None)

        # Patch threads to avoid them running in the background.
        def mock_thread_constructor(*args, **kwargs):
            mock_thread = MagicMock()

            def mock_start():
                # Check if this is the input thread by looking at target and args
                target = kwargs.get('target')
                thread_args = kwargs.get('args', ())

                # If this is the input thread (process_input_sockets), set the
                # ready_event so code waiting on it does not block.
                if (target and hasattr(target, '__name__')
                        and target.__name__ == 'process_input_sockets'):
                    assert len(
                        thread_args
                    ) == 4, "Expected 4 arguments for vllm process_input_sockets function"
                    ready_event = thread_args[
                        3]  # ready_event is the 4th argument
                    ready_event.set()

            mock_thread.start = mock_start
            mock_thread.is_alive.return_value = True
            return mock_thread

        self.thread_patcher, self.mock_thread = self._start_patch(
            "threading.Thread", side_effect=mock_thread_constructor)

        # Mock jax.devices.
        self.mock_jax_devices_patcher, self.mock_jax_devices = (
            self._start_patch('jax.devices', return_value=[MagicMock()] * 8))

        # VLLM config.
        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.parallel_config = MagicMock(spec=ParallelConfig)
        self.mock_vllm_config.device_config = MagicMock()
        self.mock_vllm_config.cache_config = MagicMock()
        self.mock_vllm_config.cache_config.prefix_caching_hash_algo = "builtin"
        self.mock_vllm_config.cache_config.block_size = 5
        self.mock_vllm_config.__post_init__ = MagicMock()

    def test_initialization(self):
        """Tests that the adapter initializes the orchestrator correctly."""
        proc = self._make_proc()

        self.mock_orchestrator.assert_called_once()
        _, kwargs = self.mock_orchestrator.call_args
        self.assertIsInstance(kwargs['config'], VllmConfigAdapter)
        self.assertEqual(kwargs['config'].vllm_config, self.mock_vllm_config)
        self.assertEqual(kwargs['output_queue'], proc.output_queue)
        self.assertEqual(len(kwargs['prefill_engines']), 1)
        self.assertIsInstance(kwargs['prefill_engines'][0], VllmEngineAdapter)
        self.assertEqual(len(kwargs['decode_engines']), 1)
        self.assertIsInstance(kwargs['decode_engines'][0], VllmEngineAdapter)
        self.assertEqual(kwargs['prefill_slice_sizes'], (4, ))
        self.assertEqual(kwargs['decode_slice_sizes'], (2, ))

    def test_add_request(self):
        """Tests that the adapter correctly delegates add_request to the orchestrator."""
        proc = self._make_proc()

        # preprocess_add_request returns a pair; only the first item is the
        # request that add_request takes.
        engine_request, _ = proc.preprocess_add_request(
            self._make_core_request())

        proc.add_request(engine_request)

        orchestrator_add = self.mock_orchestrator.return_value.add_request
        orchestrator_add.assert_called_once()
        # Get the argument passed to add_request
        passed_request_adapter = orchestrator_add.call_args[0][0]

        # Assert it's the correct type and wraps the correct underlying request
        self.assertIsInstance(passed_request_adapter, VllmRequestAdapter)
        self.assertIsInstance(passed_request_adapter.vllm_request, Request)
        self.assertEqual(passed_request_adapter.request_id, "test_req")

    def test_shutdown(self):
        """Tests that the adapter correctly delegates shutdown to the orchestrator."""
        proc = self._make_proc()

        proc.shutdown()

        self.mock_orchestrator.return_value.shutdown.assert_called_once()

    def test_handle_client_request_add(self):
        """Tests that the adapter correctly handles an ADD request."""
        proc = self._make_proc()
        # The whole return value of preprocess_add_request is the ADD payload,
        # so it gets its own name instead of reusing the request's.
        add_payload = proc.preprocess_add_request(self._make_core_request())

        proc._handle_client_request(EngineCoreRequestType.ADD, add_payload)

        self.mock_orchestrator.return_value.add_request.assert_called_once()

    def test_handle_client_request_abort(self):
        """Tests that the adapter correctly handles an ABORT request."""
        proc = self._make_proc()

        # This is currently a no-op, so we just check that it doesn't crash
        proc._handle_client_request(EngineCoreRequestType.ABORT, "test_req")

    def test_handle_client_request_utility(self):
        """Tests that the adapter correctly handles a UTILITY request."""
        proc = self._make_proc()
        # Mock a method on the prefill engine instance
        proc._prefill_engines = [MagicMock()]
        proc._prefill_engines[0].list_loras.return_value = {1, 2, 3}

        utility_request = (0, "call-id-1", "list_loras", ())
        proc._handle_client_request(EngineCoreRequestType.UTILITY,
                                    utility_request)

        proc._prefill_engines[0].list_loras.assert_called_once()
        # assertGreater reports the actual queue size if this fails.
        self.assertGreater(proc.output_queue.qsize(), 0)
319
+
320
+
321
class TestDisaggOrchestrator(unittest.TestCase):
    """Tests _DisaggOrchestrator against mocked engines with JetThread patched."""

    def _make_orchestrator(self):
        """Create an orchestrator with one mocked prefill and one decode engine."""
        return _DisaggOrchestrator(
            config=self.mock_config,
            output_queue=self.mock_output_queue,
            prefill_engines=[self.mock_prefill_engine],
            decode_engines=[self.mock_decode_engine],
            prefill_slice_sizes=(4, ),
            decode_slice_sizes=(2, ),
        )

    def setUp(self):
        self.mock_config = MagicMock(spec=IConfig)
        self.mock_config.scheduler_config = MagicMock()
        self.mock_config.scheduler_config.max_num_seqs = 16
        self.mock_config.cache_config = MagicMock()
        self.mock_config.cache_config.block_size = 5

        self.mock_output_queue = MagicMock()
        self.mock_prefill_engine = MagicMock(spec=IEngineCore)
        self.mock_decode_engine = MagicMock(spec=IEngineCore)

        # The orchestrator accesses the scheduler on the engine.
        self.mock_prefill_engine.scheduler = MagicMock()
        self.mock_decode_engine.scheduler = MagicMock()

        # The orchestrator accesses the model_executor on the engine.
        self.mock_prefill_engine.model_executor = MagicMock()
        self.mock_decode_engine.model_executor = MagicMock()

        # Patch threads to avoid them running in the background.
        self.jet_thread_patcher = patch(
            "tpu_inference.core.core_tpu.JetThread", MagicMock)
        self.mock_jet_thread = self.jet_thread_patcher.start()
        self.addCleanup(self.jet_thread_patcher.stop)

    def test_initialization(self):
        """Tests that the orchestrator initializes correctly."""
        orchestrator = self._make_orchestrator()

        self.assertEqual(orchestrator._config, self.mock_config)
        self.assertEqual(orchestrator._output_queue, self.mock_output_queue)
        self.assertEqual(len(orchestrator._prefill_engines), 1)
        self.assertEqual(len(orchestrator._decode_engines), 1)
        self.assertEqual(len(orchestrator._all_threads),
                         3)  # 1 prefill, 1 transfer, 1 decode

    def test_add_request(self):
        """Tests that a new request is added to the prefill engine."""
        orchestrator = self._make_orchestrator()
        mock_request = MagicMock()
        mock_request.vllm_request.request_id = "test_req"

        orchestrator.add_request(mock_request)

        self.assertIn("test_req", orchestrator._requests)
        self.mock_prefill_engine.scheduler.add_request.assert_called_once_with(
            mock_request)

    def test_prefill_logic(self):
        """Tests the prefill logic of the orchestrator."""
        orchestrator = self._make_orchestrator()
        orchestrator.live = True

        # Mock scheduler output
        mock_scheduler_output = MagicMock()
        mock_scheduler_output.total_num_scheduled_tokens = 1
        self.mock_prefill_engine.scheduler.schedule.return_value = mock_scheduler_output

        # Mock model output
        mock_model_output = MagicMock()
        mock_model_output.req_id_to_index = {"test_req": 0}
        mock_model_output.sampled_token_ids = [[1]]
        self.mock_prefill_engine.execute_model_with_error_logging.return_value = mock_model_output

        # Mock request
        mock_request = MagicMock()
        orchestrator._requests["test_req"] = mock_request

        # Mock the side effect of update_from_output to stop the loop
        def stop_loop(*args, **kwargs):
            orchestrator.live = False
            return {}

        self.mock_prefill_engine.scheduler.update_from_output.side_effect = stop_loop

        orchestrator._prefill(0)

        self.mock_prefill_engine.execute_model_with_error_logging.assert_called_once(
        )
        # assertGreater reports the actual queue size if this fails.
        self.assertGreater(orchestrator._transfer_backlogs[0].qsize(), 0)

    def test_transfer_logic(self):
        """Tests the transfer logic of the orchestrator."""
        orchestrator = self._make_orchestrator()
        orchestrator.live = True

        # Mock kv cache map
        mock_kv_cache_map = {"test_req": ([MagicMock()], [])}
        orchestrator._transfer_backlogs[0].put(mock_kv_cache_map)
        orchestrator._transfer_backlogs[0].put(
            None)  # Sentinel to stop the loop

        orchestrator._transfer(0)

        self.mock_decode_engine.model_executor.driver_worker.model_runner.transfer_kv_cache.assert_called_once(
        )
        # assertGreater reports the actual queue size if this fails.
        self.assertGreater(orchestrator._decode_backlogs[0].qsize(), 0)

    def test_decode_logic(self):
        """Tests the decode logic of the orchestrator."""
        orchestrator = self._make_orchestrator()
        orchestrator.live = True

        # Mock prefill output
        mock_prefill_output = {
            "req_id": "test_req",
            "cache": [MagicMock()],
            "block_hashes": []
        }
        orchestrator._decode_backlogs[0].put(mock_prefill_output)
        orchestrator._decode_backlogs[0].put(None)  # Sentinel to stop the loop

        # Mock request
        mock_request = MagicMock()
        mock_request.vllm_request.num_computed_tokens = 10
        orchestrator._requests["test_req"] = mock_request

        # Mock scheduler and model runner states for the loop condition
        decode_scheduler = self.mock_decode_engine.scheduler
        decode_scheduler.has_requests.return_value = False
        decode_scheduler.get_request_counts.return_value = (0, 0)
        self.mock_decode_engine.model_executor.driver_worker.model_runner.input_batch.num_reqs = 0
        decode_scheduler.kv_cache_manager.get_block_ids.return_value = (
            [20, 21], )

        # Mock scheduler output
        mock_scheduler_output = MagicMock()
        mock_scheduler_output.total_num_scheduled_tokens = 1
        decode_scheduler.schedule.return_value = mock_scheduler_output

        # Mock model output
        mock_model_output = MagicMock()
        self.mock_decode_engine.execute_model_with_error_logging.return_value = mock_model_output

        # Mock the side effect of update_from_output to stop the loop
        def stop_loop(*args, **kwargs):
            orchestrator.live = False
            return {"test_req": MagicMock()}

        decode_scheduler.update_from_output.side_effect = stop_loop

        orchestrator._decode(0)

        self.mock_decode_engine.execute_model_with_error_logging.assert_called_once(
        )
        self.mock_output_queue.put_nowait.assert_called_once()

    def test_shutdown(self):
        """Tests that the orchestrator correctly shuts down its engines."""
        orchestrator = self._make_orchestrator()

        orchestrator.shutdown()

        self.mock_prefill_engine.shutdown.assert_called_once()
        self.mock_decode_engine.shutdown.assert_called_once()
520
+
521
+
522
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,60 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import unittest
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ from vllm.config import ModelConfig, VllmConfig
6
+
7
+ from tpu_inference.core.disagg_executor import DisaggExecutor
8
+
9
+
10
class DisaggExecutorTest(unittest.TestCase):
    """Unit tests for ``DisaggExecutor`` with ``collective_rpc`` patched out."""

    def setUp(self):
        """Build a mocked ``VllmConfig`` and an executor that issues no real RPCs."""
        vllm_config = MagicMock(spec=VllmConfig)

        # The model section is a real ModelConfig; the other sections are
        # either plain mocks or explicitly disabled with None.
        vllm_config.model_config = ModelConfig(tokenizer_mode="auto",
                                               trust_remote_code=False,
                                               seed=0,
                                               dtype='bfloat16')
        mocked_sections = (
            "cache_config",
            "scheduler_config",
            "load_config",
            "parallel_config",
            "device_config",
            "observability_config",
        )
        for section in mocked_sections:
            setattr(vllm_config, section, MagicMock())
        disabled_sections = (
            "lora_config",
            "speculative_config",
            "prompt_adapter_config",
        )
        for section in disabled_sections:
            setattr(vllm_config, section, None)
        self.mock_vllm_config = vllm_config

        # Replace collective_rpc on the class before the executor is built so
        # that no actual RPC call is made; the patch is undone after each test.
        rpc_target = (
            "tpu_inference.core.disagg_executor.DisaggExecutor.collective_rpc")
        self.patcher = patch(rpc_target)
        self.mock_collective_rpc = self.patcher.start()
        self.addCleanup(self.patcher.stop)

        # Executor under test, constructed from the mocked configuration.
        self.executor = DisaggExecutor(vllm_config=self.mock_vllm_config)

    def test_init_with_devices(self):
        """Check the order of the RPC method names issued by ``_init_executor``.

        The first positional argument of the first three recorded
        ``collective_rpc`` calls must be ``init_worker``, ``init_device`` and
        ``load_model``, in that order.
        """
        self.executor._init_executor()

        self.mock_collective_rpc.assert_called()
        # NOTE: the patch is started before the executor is constructed in
        # setUp, so this list also contains any calls recorded at that point.
        recorded = self.mock_collective_rpc.call_args_list

        expected_methods = ("init_worker", "init_device", "load_model")
        for position, method_name in enumerate(expected_methods):
            positional_args = recorded[position][0]
            self.assertEqual(positional_args[0], method_name)

    def test_check_health(self):
        """Call ``check_health`` and expect it to return without raising."""
        self.executor.check_health()
57
+
58
+
59
# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()