tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/core/test_dp_scheduler.py +128 -71
- tests/e2e/test_data_parallel.py +176 -280
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_speculative_decoding.py +26 -6
- tests/layers/jax/test_qwix.py +1 -1
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
- tests/layers/vllm/test_mxfp4.py +25 -10
- tests/layers/vllm/test_unquantized.py +61 -31
- tests/layers/vllm/utils.py +19 -4
- tests/models/common/test_model_loader.py +2 -2
- tests/models/jax/test_qwen2_5_vl.py +10 -11
- tests/runner/test_multimodal_manager.py +3 -3
- tests/runner/test_tpu_runner.py +67 -8
- tests/runner/test_tpu_runner_dp.py +66 -0
- tpu_inference/core/sched/dp_scheduler.py +65 -40
- tpu_inference/kernels/mla/v1/kernel.py +7 -26
- tpu_inference/layers/common/sharding.py +8 -3
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
- tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
- tpu_inference/layers/jax/sample/sampling.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +51 -47
- tpu_inference/layers/vllm/quantization/common.py +14 -13
- tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
- tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
- tpu_inference/layers/vllm/sharding.py +7 -4
- tpu_inference/models/common/model_loader.py +11 -14
- tpu_inference/models/jax/llama3.py +13 -10
- tpu_inference/models/jax/llama_guard_4.py +1 -1
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -4
- tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
- tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
- tpu_inference/platforms/tpu_platform.py +7 -7
- tpu_inference/runner/compilation_manager.py +43 -33
- tpu_inference/runner/kv_cache_manager.py +1 -2
- tpu_inference/runner/multimodal_manager.py +1 -1
- tpu_inference/runner/tpu_runner.py +12 -9
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/tpu_worker.py +5 -2
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
tests/core/test_dp_scheduler.py
CHANGED
|
@@ -86,7 +86,9 @@ class TestDPScheduler:
|
|
|
86
86
|
assert scheduler.dp_size == 2
|
|
87
87
|
assert len(scheduler.processes) == 2
|
|
88
88
|
assert len(scheduler.input_queues) == 2
|
|
89
|
-
|
|
89
|
+
# output_queues is a dict with (rank, command) tuple keys
|
|
90
|
+
# 2 ranks × 14 commands (SchedulerCommand enum)
|
|
91
|
+
assert len(scheduler.output_queues) == 28
|
|
90
92
|
assert scheduler.log_stats is True
|
|
91
93
|
assert len(scheduler.per_rank_kv_cache_configs) == 2
|
|
92
94
|
|
|
@@ -112,13 +114,18 @@ class TestDPScheduler:
|
|
|
112
114
|
block_size=16,
|
|
113
115
|
)
|
|
114
116
|
|
|
115
|
-
# Mock the queues
|
|
117
|
+
# Mock the queues - need to mock the .get() method to return the value
|
|
116
118
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
117
|
-
scheduler.output_queues = [MagicMock(), MagicMock()]
|
|
118
119
|
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
120
|
+
mock_queue_0 = MagicMock()
|
|
121
|
+
mock_queue_0.get.return_value = 30
|
|
122
|
+
mock_queue_1 = MagicMock()
|
|
123
|
+
mock_queue_1.get.return_value = 15
|
|
124
|
+
|
|
125
|
+
scheduler.output_queues = {
|
|
126
|
+
(0, "get_token_count"): mock_queue_0,
|
|
127
|
+
(1, "get_token_count"): mock_queue_1,
|
|
128
|
+
}
|
|
122
129
|
|
|
123
130
|
rank_tokens = scheduler._get_rank_token_counts()
|
|
124
131
|
|
|
@@ -148,27 +155,25 @@ class TestDPScheduler:
|
|
|
148
155
|
|
|
149
156
|
mock_request = MagicMock(spec=Request)
|
|
150
157
|
|
|
151
|
-
# Mock the queues
|
|
158
|
+
# Mock the queues with tuple keys (rank, command)
|
|
152
159
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
scheduler.output_queues[1].get = MagicMock(
|
|
171
|
-
side_effect=responses_1)
|
|
160
|
+
|
|
161
|
+
# Create proper mocks for queue.get() calls
|
|
162
|
+
mock_queue_get_token_0 = MagicMock()
|
|
163
|
+
mock_queue_get_token_0.get.return_value = 100
|
|
164
|
+
mock_queue_get_token_1 = MagicMock()
|
|
165
|
+
mock_queue_get_token_1.get.return_value = 50
|
|
166
|
+
mock_queue_computed_0 = MagicMock()
|
|
167
|
+
mock_queue_computed_0.get.return_value = ([], 10)
|
|
168
|
+
mock_queue_computed_1 = MagicMock()
|
|
169
|
+
mock_queue_computed_1.get.return_value = ([], 25)
|
|
170
|
+
|
|
171
|
+
scheduler.output_queues = {
|
|
172
|
+
(0, "get_token_count"): mock_queue_get_token_0,
|
|
173
|
+
(1, "get_token_count"): mock_queue_get_token_1,
|
|
174
|
+
(0, "get_computed_blocks"): mock_queue_computed_0,
|
|
175
|
+
(1, "get_computed_blocks"): mock_queue_computed_1,
|
|
176
|
+
}
|
|
172
177
|
|
|
173
178
|
rank = scheduler._find_best_rank_for_request(mock_request)
|
|
174
179
|
|
|
@@ -192,15 +197,25 @@ class TestDPScheduler:
|
|
|
192
197
|
|
|
193
198
|
mock_request = MagicMock(spec=Request)
|
|
194
199
|
|
|
195
|
-
# Mock the queues
|
|
200
|
+
# Mock the queues with tuple keys (rank, command)
|
|
196
201
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
197
|
-
scheduler.output_queues = [MagicMock(), MagicMock()]
|
|
198
202
|
|
|
199
|
-
#
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
203
|
+
# Create proper mocks for queue.get() calls
|
|
204
|
+
mock_queue_get_token_0 = MagicMock()
|
|
205
|
+
mock_queue_get_token_0.get.return_value = 100
|
|
206
|
+
mock_queue_get_token_1 = MagicMock()
|
|
207
|
+
mock_queue_get_token_1.get.return_value = 50
|
|
208
|
+
mock_queue_computed_0 = MagicMock()
|
|
209
|
+
mock_queue_computed_0.get.return_value = ([], 0)
|
|
210
|
+
mock_queue_computed_1 = MagicMock()
|
|
211
|
+
mock_queue_computed_1.get.return_value = ([], 0)
|
|
212
|
+
|
|
213
|
+
scheduler.output_queues = {
|
|
214
|
+
(0, "get_token_count"): mock_queue_get_token_0,
|
|
215
|
+
(1, "get_token_count"): mock_queue_get_token_1,
|
|
216
|
+
(0, "get_computed_blocks"): mock_queue_computed_0,
|
|
217
|
+
(1, "get_computed_blocks"): mock_queue_computed_1,
|
|
218
|
+
}
|
|
204
219
|
|
|
205
220
|
rank = scheduler._find_best_rank_for_request(mock_request)
|
|
206
221
|
|
|
@@ -225,11 +240,12 @@ class TestDPScheduler:
|
|
|
225
240
|
mock_request = MagicMock(spec=Request)
|
|
226
241
|
mock_request.request_id = "req1"
|
|
227
242
|
|
|
228
|
-
# Mock the queues
|
|
243
|
+
# Mock the queues with tuple keys
|
|
229
244
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
230
|
-
scheduler.output_queues =
|
|
231
|
-
|
|
232
|
-
|
|
245
|
+
scheduler.output_queues = {
|
|
246
|
+
(0, "add_request"): MagicMock(),
|
|
247
|
+
(1, "add_request"): MagicMock(),
|
|
248
|
+
}
|
|
233
249
|
|
|
234
250
|
# Mock _find_best_rank_for_request to return rank 1
|
|
235
251
|
scheduler._find_best_rank_for_request = MagicMock(
|
|
@@ -245,7 +261,8 @@ class TestDPScheduler:
|
|
|
245
261
|
(SchedulerCommand.ADD_REQUEST, mock_request))
|
|
246
262
|
|
|
247
263
|
# Verify we waited for completion
|
|
248
|
-
scheduler.output_queues[
|
|
264
|
+
scheduler.output_queues[(
|
|
265
|
+
1, "add_request")].get.assert_called_once()
|
|
249
266
|
|
|
250
267
|
def test_schedule_sends_commands_and_combines_output(
|
|
251
268
|
self, mock_vllm_config, mock_kv_cache_config,
|
|
@@ -262,9 +279,8 @@ class TestDPScheduler:
|
|
|
262
279
|
block_size=16,
|
|
263
280
|
)
|
|
264
281
|
|
|
265
|
-
# Mock the queues
|
|
282
|
+
# Mock the queues with tuple keys
|
|
266
283
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
267
|
-
scheduler.output_queues = [MagicMock(), MagicMock()]
|
|
268
284
|
|
|
269
285
|
# Create mock scheduler outputs
|
|
270
286
|
mock_output_0 = MagicMock(spec=SchedulerOutput)
|
|
@@ -303,11 +319,16 @@ class TestDPScheduler:
|
|
|
303
319
|
mock_output_1.scheduled_encoder_inputs = {}
|
|
304
320
|
mock_output_1.num_common_prefix_blocks = []
|
|
305
321
|
|
|
306
|
-
# Setup mock queue responses
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
322
|
+
# Setup mock queue responses with tuple keys - need to mock .get()
|
|
323
|
+
mock_queue_0 = MagicMock()
|
|
324
|
+
mock_queue_0.get.return_value = mock_output_0
|
|
325
|
+
mock_queue_1 = MagicMock()
|
|
326
|
+
mock_queue_1.get.return_value = mock_output_1
|
|
327
|
+
|
|
328
|
+
scheduler.output_queues = {
|
|
329
|
+
(0, "schedule"): mock_queue_0,
|
|
330
|
+
(1, "schedule"): mock_queue_1,
|
|
331
|
+
}
|
|
311
332
|
|
|
312
333
|
# Setup assigned ranks
|
|
313
334
|
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}
|
|
@@ -391,9 +412,10 @@ class TestDPScheduler:
|
|
|
391
412
|
)
|
|
392
413
|
|
|
393
414
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
394
|
-
scheduler.output_queues =
|
|
395
|
-
|
|
396
|
-
|
|
415
|
+
scheduler.output_queues = {
|
|
416
|
+
(0, "finish_requests"): MagicMock(),
|
|
417
|
+
(1, "finish_requests"): MagicMock(),
|
|
418
|
+
}
|
|
397
419
|
|
|
398
420
|
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
|
|
399
421
|
|
|
@@ -421,10 +443,17 @@ class TestDPScheduler:
|
|
|
421
443
|
)
|
|
422
444
|
|
|
423
445
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
424
|
-
scheduler.output_queues = [MagicMock(), MagicMock()]
|
|
425
446
|
|
|
426
|
-
|
|
427
|
-
|
|
447
|
+
# Create proper mocks for queue.get() calls
|
|
448
|
+
mock_queue_0 = MagicMock()
|
|
449
|
+
mock_queue_0.get.return_value = 5
|
|
450
|
+
mock_queue_1 = MagicMock()
|
|
451
|
+
mock_queue_1.get.return_value = 3
|
|
452
|
+
|
|
453
|
+
scheduler.output_queues = {
|
|
454
|
+
(0, "get_num_unfinished_requests"): mock_queue_0,
|
|
455
|
+
(1, "get_num_unfinished_requests"): mock_queue_1,
|
|
456
|
+
}
|
|
428
457
|
|
|
429
458
|
total = scheduler.get_num_unfinished_requests()
|
|
430
459
|
|
|
@@ -452,10 +481,17 @@ class TestDPScheduler:
|
|
|
452
481
|
)
|
|
453
482
|
|
|
454
483
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
455
|
-
scheduler.output_queues = [MagicMock(), MagicMock()]
|
|
456
484
|
|
|
457
|
-
|
|
458
|
-
|
|
485
|
+
# Create proper mocks for queue.get() calls
|
|
486
|
+
mock_queue_0 = MagicMock()
|
|
487
|
+
mock_queue_0.get.return_value = False
|
|
488
|
+
mock_queue_1 = MagicMock()
|
|
489
|
+
mock_queue_1.get.return_value = True
|
|
490
|
+
|
|
491
|
+
scheduler.output_queues = {
|
|
492
|
+
(0, "has_finished_requests"): mock_queue_0,
|
|
493
|
+
(1, "has_finished_requests"): mock_queue_1,
|
|
494
|
+
}
|
|
459
495
|
|
|
460
496
|
result = scheduler.has_finished_requests()
|
|
461
497
|
|
|
@@ -482,10 +518,17 @@ class TestDPScheduler:
|
|
|
482
518
|
)
|
|
483
519
|
|
|
484
520
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
485
|
-
scheduler.output_queues = [MagicMock(), MagicMock()]
|
|
486
521
|
|
|
487
|
-
|
|
488
|
-
|
|
522
|
+
# Create proper mocks for queue.get() calls
|
|
523
|
+
mock_queue_0 = MagicMock()
|
|
524
|
+
mock_queue_0.get.return_value = (2, 1)
|
|
525
|
+
mock_queue_1 = MagicMock()
|
|
526
|
+
mock_queue_1.get.return_value = (1, 3)
|
|
527
|
+
|
|
528
|
+
scheduler.output_queues = {
|
|
529
|
+
(0, "get_request_counts"): mock_queue_0,
|
|
530
|
+
(1, "get_request_counts"): mock_queue_1,
|
|
531
|
+
}
|
|
489
532
|
|
|
490
533
|
running, waiting = scheduler.get_request_counts()
|
|
491
534
|
|
|
@@ -513,10 +556,17 @@ class TestDPScheduler:
|
|
|
513
556
|
)
|
|
514
557
|
|
|
515
558
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
516
|
-
scheduler.output_queues = [MagicMock(), MagicMock()]
|
|
517
559
|
|
|
518
|
-
|
|
519
|
-
|
|
560
|
+
# Create proper mocks for queue.get() calls
|
|
561
|
+
mock_queue_0 = MagicMock()
|
|
562
|
+
mock_queue_0.get.return_value = True
|
|
563
|
+
mock_queue_1 = MagicMock()
|
|
564
|
+
mock_queue_1.get.return_value = True
|
|
565
|
+
|
|
566
|
+
scheduler.output_queues = {
|
|
567
|
+
(0, "reset_prefix_cache"): mock_queue_0,
|
|
568
|
+
(1, "reset_prefix_cache"): mock_queue_1,
|
|
569
|
+
}
|
|
520
570
|
|
|
521
571
|
result = scheduler.reset_prefix_cache()
|
|
522
572
|
|
|
@@ -545,7 +595,6 @@ class TestDPScheduler:
|
|
|
545
595
|
)
|
|
546
596
|
|
|
547
597
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
548
|
-
scheduler.output_queues = [MagicMock(), MagicMock()]
|
|
549
598
|
|
|
550
599
|
# Create mock stats
|
|
551
600
|
stats_0 = SchedulerStats(
|
|
@@ -580,10 +629,16 @@ class TestDPScheduler:
|
|
|
580
629
|
kv_connector_stats=None,
|
|
581
630
|
)
|
|
582
631
|
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
632
|
+
# Create proper mocks for queue.get() calls
|
|
633
|
+
mock_queue_0 = MagicMock()
|
|
634
|
+
mock_queue_0.get.return_value = stats_0
|
|
635
|
+
mock_queue_1 = MagicMock()
|
|
636
|
+
mock_queue_1.get.return_value = stats_1
|
|
637
|
+
|
|
638
|
+
scheduler.output_queues = {
|
|
639
|
+
(0, "make_stats"): mock_queue_0,
|
|
640
|
+
(1, "make_stats"): mock_queue_1,
|
|
641
|
+
}
|
|
587
642
|
|
|
588
643
|
combined_stats = scheduler.make_stats()
|
|
589
644
|
|
|
@@ -632,9 +687,10 @@ class TestDPScheduler:
|
|
|
632
687
|
)
|
|
633
688
|
|
|
634
689
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
635
|
-
scheduler.output_queues =
|
|
636
|
-
|
|
637
|
-
|
|
690
|
+
scheduler.output_queues = {
|
|
691
|
+
(0, "update_draft_token_ids"): MagicMock(),
|
|
692
|
+
(1, "update_draft_token_ids"): MagicMock(),
|
|
693
|
+
}
|
|
638
694
|
|
|
639
695
|
scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
|
|
640
696
|
|
|
@@ -667,9 +723,10 @@ class TestDPScheduler:
|
|
|
667
723
|
)
|
|
668
724
|
|
|
669
725
|
scheduler.input_queues = [MagicMock(), MagicMock()]
|
|
670
|
-
scheduler.output_queues =
|
|
671
|
-
|
|
672
|
-
|
|
726
|
+
scheduler.output_queues = {
|
|
727
|
+
(0, "shutdown"): MagicMock(),
|
|
728
|
+
(1, "shutdown"): MagicMock(),
|
|
729
|
+
}
|
|
673
730
|
|
|
674
731
|
mock_process_0 = MagicMock()
|
|
675
732
|
mock_process_1 = MagicMock()
|