tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. tests/core/test_dp_scheduler.py +128 -71
  2. tests/e2e/test_data_parallel.py +176 -280
  3. tests/e2e/test_hybrid_kvcache.py +219 -0
  4. tests/e2e/test_speculative_decoding.py +26 -6
  5. tests/layers/jax/test_qwix.py +1 -1
  6. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
  7. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
  8. tests/layers/vllm/test_mxfp4.py +25 -10
  9. tests/layers/vllm/test_unquantized.py +61 -31
  10. tests/layers/vllm/utils.py +19 -4
  11. tests/models/common/test_model_loader.py +2 -2
  12. tests/models/jax/test_qwen2_5_vl.py +10 -11
  13. tests/runner/test_multimodal_manager.py +3 -3
  14. tests/runner/test_tpu_runner.py +67 -8
  15. tests/runner/test_tpu_runner_dp.py +66 -0
  16. tpu_inference/core/sched/dp_scheduler.py +65 -40
  17. tpu_inference/kernels/mla/v1/kernel.py +7 -26
  18. tpu_inference/layers/common/sharding.py +8 -3
  19. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
  20. tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
  21. tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
  22. tpu_inference/layers/jax/sample/sampling.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +51 -47
  24. tpu_inference/layers/vllm/quantization/common.py +14 -13
  25. tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
  26. tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
  27. tpu_inference/layers/vllm/sharding.py +7 -4
  28. tpu_inference/models/common/model_loader.py +11 -14
  29. tpu_inference/models/jax/llama3.py +13 -10
  30. tpu_inference/models/jax/llama_guard_4.py +1 -1
  31. tpu_inference/models/jax/qwen2.py +3 -2
  32. tpu_inference/models/jax/qwen2_5_vl.py +4 -4
  33. tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
  34. tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
  35. tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
  36. tpu_inference/platforms/tpu_platform.py +7 -7
  37. tpu_inference/runner/compilation_manager.py +43 -33
  38. tpu_inference/runner/kv_cache_manager.py +1 -2
  39. tpu_inference/runner/multimodal_manager.py +1 -1
  40. tpu_inference/runner/tpu_runner.py +12 -9
  41. tpu_inference/utils.py +31 -30
  42. tpu_inference/worker/tpu_worker.py +5 -2
  43. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
  44. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
  45. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
  46. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
  47. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
@@ -86,7 +86,9 @@ class TestDPScheduler:
86
86
  assert scheduler.dp_size == 2
87
87
  assert len(scheduler.processes) == 2
88
88
  assert len(scheduler.input_queues) == 2
89
- assert len(scheduler.output_queues) == 2
89
+ # output_queues is a dict with (rank, command) tuple keys
90
+ # 2 ranks × 14 commands (SchedulerCommand enum)
91
+ assert len(scheduler.output_queues) == 28
90
92
  assert scheduler.log_stats is True
91
93
  assert len(scheduler.per_rank_kv_cache_configs) == 2
92
94
 
@@ -112,13 +114,18 @@ class TestDPScheduler:
112
114
  block_size=16,
113
115
  )
114
116
 
115
- # Mock the queues
117
+ # Mock the queues - need to mock the .get() method to return the value
116
118
  scheduler.input_queues = [MagicMock(), MagicMock()]
117
- scheduler.output_queues = [MagicMock(), MagicMock()]
118
119
 
119
- # Mock responses from workers
120
- scheduler.output_queues[0].get = MagicMock(return_value=30)
121
- scheduler.output_queues[1].get = MagicMock(return_value=15)
120
+ mock_queue_0 = MagicMock()
121
+ mock_queue_0.get.return_value = 30
122
+ mock_queue_1 = MagicMock()
123
+ mock_queue_1.get.return_value = 15
124
+
125
+ scheduler.output_queues = {
126
+ (0, "get_token_count"): mock_queue_0,
127
+ (1, "get_token_count"): mock_queue_1,
128
+ }
122
129
 
123
130
  rank_tokens = scheduler._get_rank_token_counts()
124
131
 
@@ -148,27 +155,25 @@ class TestDPScheduler:
148
155
 
149
156
  mock_request = MagicMock(spec=Request)
150
157
 
151
- # Mock the queues
158
+ # Mock the queues with tuple keys (rank, command)
152
159
  scheduler.input_queues = [MagicMock(), MagicMock()]
153
- scheduler.output_queues = [MagicMock(), MagicMock()]
154
-
155
- # Track call counts for proper sequencing
156
- call_sequence = [100, 50, ([], 10), ([], 25)]
157
-
158
- # Both queues use the same sequence
159
- for q in scheduler.output_queues:
160
- q.get = MagicMock(
161
- side_effect=lambda timeout=None: call_sequence[len([
162
- c for c in scheduler.output_queues if c.get.called
163
- ])])
164
-
165
- # Simpler mock setup
166
- responses_0 = [100, ([], 10)]
167
- responses_1 = [50, ([], 25)]
168
- scheduler.output_queues[0].get = MagicMock(
169
- side_effect=responses_0)
170
- scheduler.output_queues[1].get = MagicMock(
171
- side_effect=responses_1)
160
+
161
+ # Create proper mocks for queue.get() calls
162
+ mock_queue_get_token_0 = MagicMock()
163
+ mock_queue_get_token_0.get.return_value = 100
164
+ mock_queue_get_token_1 = MagicMock()
165
+ mock_queue_get_token_1.get.return_value = 50
166
+ mock_queue_computed_0 = MagicMock()
167
+ mock_queue_computed_0.get.return_value = ([], 10)
168
+ mock_queue_computed_1 = MagicMock()
169
+ mock_queue_computed_1.get.return_value = ([], 25)
170
+
171
+ scheduler.output_queues = {
172
+ (0, "get_token_count"): mock_queue_get_token_0,
173
+ (1, "get_token_count"): mock_queue_get_token_1,
174
+ (0, "get_computed_blocks"): mock_queue_computed_0,
175
+ (1, "get_computed_blocks"): mock_queue_computed_1,
176
+ }
172
177
 
173
178
  rank = scheduler._find_best_rank_for_request(mock_request)
174
179
 
@@ -192,15 +197,25 @@ class TestDPScheduler:
192
197
 
193
198
  mock_request = MagicMock(spec=Request)
194
199
 
195
- # Mock the queues
200
+ # Mock the queues with tuple keys (rank, command)
196
201
  scheduler.input_queues = [MagicMock(), MagicMock()]
197
- scheduler.output_queues = [MagicMock(), MagicMock()]
198
202
 
199
- # No cache hits - both return 0
200
- scheduler.output_queues[0].get = MagicMock(
201
- side_effect=[100, ([], 0)])
202
- scheduler.output_queues[1].get = MagicMock(
203
- side_effect=[50, ([], 0)])
203
+ # Create proper mocks for queue.get() calls
204
+ mock_queue_get_token_0 = MagicMock()
205
+ mock_queue_get_token_0.get.return_value = 100
206
+ mock_queue_get_token_1 = MagicMock()
207
+ mock_queue_get_token_1.get.return_value = 50
208
+ mock_queue_computed_0 = MagicMock()
209
+ mock_queue_computed_0.get.return_value = ([], 0)
210
+ mock_queue_computed_1 = MagicMock()
211
+ mock_queue_computed_1.get.return_value = ([], 0)
212
+
213
+ scheduler.output_queues = {
214
+ (0, "get_token_count"): mock_queue_get_token_0,
215
+ (1, "get_token_count"): mock_queue_get_token_1,
216
+ (0, "get_computed_blocks"): mock_queue_computed_0,
217
+ (1, "get_computed_blocks"): mock_queue_computed_1,
218
+ }
204
219
 
205
220
  rank = scheduler._find_best_rank_for_request(mock_request)
206
221
 
@@ -225,11 +240,12 @@ class TestDPScheduler:
225
240
  mock_request = MagicMock(spec=Request)
226
241
  mock_request.request_id = "req1"
227
242
 
228
- # Mock the queues
243
+ # Mock the queues with tuple keys
229
244
  scheduler.input_queues = [MagicMock(), MagicMock()]
230
- scheduler.output_queues = [MagicMock(), MagicMock()]
231
- scheduler.output_queues[0].get = MagicMock()
232
- scheduler.output_queues[1].get = MagicMock()
245
+ scheduler.output_queues = {
246
+ (0, "add_request"): MagicMock(),
247
+ (1, "add_request"): MagicMock(),
248
+ }
233
249
 
234
250
  # Mock _find_best_rank_for_request to return rank 1
235
251
  scheduler._find_best_rank_for_request = MagicMock(
@@ -245,7 +261,8 @@ class TestDPScheduler:
245
261
  (SchedulerCommand.ADD_REQUEST, mock_request))
246
262
 
247
263
  # Verify we waited for completion
248
- scheduler.output_queues[1].get.assert_called_once()
264
+ scheduler.output_queues[(
265
+ 1, "add_request")].get.assert_called_once()
249
266
 
250
267
  def test_schedule_sends_commands_and_combines_output(
251
268
  self, mock_vllm_config, mock_kv_cache_config,
@@ -262,9 +279,8 @@ class TestDPScheduler:
262
279
  block_size=16,
263
280
  )
264
281
 
265
- # Mock the queues
282
+ # Mock the queues with tuple keys
266
283
  scheduler.input_queues = [MagicMock(), MagicMock()]
267
- scheduler.output_queues = [MagicMock(), MagicMock()]
268
284
 
269
285
  # Create mock scheduler outputs
270
286
  mock_output_0 = MagicMock(spec=SchedulerOutput)
@@ -303,11 +319,16 @@ class TestDPScheduler:
303
319
  mock_output_1.scheduled_encoder_inputs = {}
304
320
  mock_output_1.num_common_prefix_blocks = []
305
321
 
306
- # Setup mock queue responses
307
- scheduler.output_queues[0].get = MagicMock(
308
- return_value=mock_output_0)
309
- scheduler.output_queues[1].get = MagicMock(
310
- return_value=mock_output_1)
322
+ # Setup mock queue responses with tuple keys - need to mock .get()
323
+ mock_queue_0 = MagicMock()
324
+ mock_queue_0.get.return_value = mock_output_0
325
+ mock_queue_1 = MagicMock()
326
+ mock_queue_1.get.return_value = mock_output_1
327
+
328
+ scheduler.output_queues = {
329
+ (0, "schedule"): mock_queue_0,
330
+ (1, "schedule"): mock_queue_1,
331
+ }
311
332
 
312
333
  # Setup assigned ranks
313
334
  scheduler.assigned_dp_rank = {"req1": 0, "req2": 1}
@@ -391,9 +412,10 @@ class TestDPScheduler:
391
412
  )
392
413
 
393
414
  scheduler.input_queues = [MagicMock(), MagicMock()]
394
- scheduler.output_queues = [MagicMock(), MagicMock()]
395
- scheduler.output_queues[0].get = MagicMock()
396
- scheduler.output_queues[1].get = MagicMock()
415
+ scheduler.output_queues = {
416
+ (0, "finish_requests"): MagicMock(),
417
+ (1, "finish_requests"): MagicMock(),
418
+ }
397
419
 
398
420
  scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
399
421
 
@@ -421,10 +443,17 @@ class TestDPScheduler:
421
443
  )
422
444
 
423
445
  scheduler.input_queues = [MagicMock(), MagicMock()]
424
- scheduler.output_queues = [MagicMock(), MagicMock()]
425
446
 
426
- scheduler.output_queues[0].get = MagicMock(return_value=5)
427
- scheduler.output_queues[1].get = MagicMock(return_value=3)
447
+ # Create proper mocks for queue.get() calls
448
+ mock_queue_0 = MagicMock()
449
+ mock_queue_0.get.return_value = 5
450
+ mock_queue_1 = MagicMock()
451
+ mock_queue_1.get.return_value = 3
452
+
453
+ scheduler.output_queues = {
454
+ (0, "get_num_unfinished_requests"): mock_queue_0,
455
+ (1, "get_num_unfinished_requests"): mock_queue_1,
456
+ }
428
457
 
429
458
  total = scheduler.get_num_unfinished_requests()
430
459
 
@@ -452,10 +481,17 @@ class TestDPScheduler:
452
481
  )
453
482
 
454
483
  scheduler.input_queues = [MagicMock(), MagicMock()]
455
- scheduler.output_queues = [MagicMock(), MagicMock()]
456
484
 
457
- scheduler.output_queues[0].get = MagicMock(return_value=False)
458
- scheduler.output_queues[1].get = MagicMock(return_value=True)
485
+ # Create proper mocks for queue.get() calls
486
+ mock_queue_0 = MagicMock()
487
+ mock_queue_0.get.return_value = False
488
+ mock_queue_1 = MagicMock()
489
+ mock_queue_1.get.return_value = True
490
+
491
+ scheduler.output_queues = {
492
+ (0, "has_finished_requests"): mock_queue_0,
493
+ (1, "has_finished_requests"): mock_queue_1,
494
+ }
459
495
 
460
496
  result = scheduler.has_finished_requests()
461
497
 
@@ -482,10 +518,17 @@ class TestDPScheduler:
482
518
  )
483
519
 
484
520
  scheduler.input_queues = [MagicMock(), MagicMock()]
485
- scheduler.output_queues = [MagicMock(), MagicMock()]
486
521
 
487
- scheduler.output_queues[0].get = MagicMock(return_value=(2, 1))
488
- scheduler.output_queues[1].get = MagicMock(return_value=(1, 3))
522
+ # Create proper mocks for queue.get() calls
523
+ mock_queue_0 = MagicMock()
524
+ mock_queue_0.get.return_value = (2, 1)
525
+ mock_queue_1 = MagicMock()
526
+ mock_queue_1.get.return_value = (1, 3)
527
+
528
+ scheduler.output_queues = {
529
+ (0, "get_request_counts"): mock_queue_0,
530
+ (1, "get_request_counts"): mock_queue_1,
531
+ }
489
532
 
490
533
  running, waiting = scheduler.get_request_counts()
491
534
 
@@ -513,10 +556,17 @@ class TestDPScheduler:
513
556
  )
514
557
 
515
558
  scheduler.input_queues = [MagicMock(), MagicMock()]
516
- scheduler.output_queues = [MagicMock(), MagicMock()]
517
559
 
518
- scheduler.output_queues[0].get = MagicMock(return_value=True)
519
- scheduler.output_queues[1].get = MagicMock(return_value=True)
560
+ # Create proper mocks for queue.get() calls
561
+ mock_queue_0 = MagicMock()
562
+ mock_queue_0.get.return_value = True
563
+ mock_queue_1 = MagicMock()
564
+ mock_queue_1.get.return_value = True
565
+
566
+ scheduler.output_queues = {
567
+ (0, "reset_prefix_cache"): mock_queue_0,
568
+ (1, "reset_prefix_cache"): mock_queue_1,
569
+ }
520
570
 
521
571
  result = scheduler.reset_prefix_cache()
522
572
 
@@ -545,7 +595,6 @@ class TestDPScheduler:
545
595
  )
546
596
 
547
597
  scheduler.input_queues = [MagicMock(), MagicMock()]
548
- scheduler.output_queues = [MagicMock(), MagicMock()]
549
598
 
550
599
  # Create mock stats
551
600
  stats_0 = SchedulerStats(
@@ -580,10 +629,16 @@ class TestDPScheduler:
580
629
  kv_connector_stats=None,
581
630
  )
582
631
 
583
- scheduler.output_queues[0].get = MagicMock(
584
- return_value=stats_0)
585
- scheduler.output_queues[1].get = MagicMock(
586
- return_value=stats_1)
632
+ # Create proper mocks for queue.get() calls
633
+ mock_queue_0 = MagicMock()
634
+ mock_queue_0.get.return_value = stats_0
635
+ mock_queue_1 = MagicMock()
636
+ mock_queue_1.get.return_value = stats_1
637
+
638
+ scheduler.output_queues = {
639
+ (0, "make_stats"): mock_queue_0,
640
+ (1, "make_stats"): mock_queue_1,
641
+ }
587
642
 
588
643
  combined_stats = scheduler.make_stats()
589
644
 
@@ -632,9 +687,10 @@ class TestDPScheduler:
632
687
  )
633
688
 
634
689
  scheduler.input_queues = [MagicMock(), MagicMock()]
635
- scheduler.output_queues = [MagicMock(), MagicMock()]
636
- scheduler.output_queues[0].get = MagicMock()
637
- scheduler.output_queues[1].get = MagicMock()
690
+ scheduler.output_queues = {
691
+ (0, "update_draft_token_ids"): MagicMock(),
692
+ (1, "update_draft_token_ids"): MagicMock(),
693
+ }
638
694
 
639
695
  scheduler.assigned_dp_rank = {"req1": 0, "req2": 1, "req3": 0}
640
696
 
@@ -667,9 +723,10 @@ class TestDPScheduler:
667
723
  )
668
724
 
669
725
  scheduler.input_queues = [MagicMock(), MagicMock()]
670
- scheduler.output_queues = [MagicMock(), MagicMock()]
671
- scheduler.output_queues[0].get = MagicMock()
672
- scheduler.output_queues[1].get = MagicMock()
726
+ scheduler.output_queues = {
727
+ (0, "shutdown"): MagicMock(),
728
+ (1, "shutdown"): MagicMock(),
729
+ }
673
730
 
674
731
  mock_process_0 = MagicMock()
675
732
  mock_process_1 = MagicMock()