tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. tests/core/test_dp_scheduler.py +128 -71
  2. tests/e2e/test_data_parallel.py +176 -280
  3. tests/e2e/test_hybrid_kvcache.py +219 -0
  4. tests/e2e/test_speculative_decoding.py +26 -6
  5. tests/layers/jax/test_qwix.py +1 -1
  6. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
  7. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
  8. tests/layers/vllm/test_mxfp4.py +25 -10
  9. tests/layers/vllm/test_unquantized.py +61 -31
  10. tests/layers/vllm/utils.py +19 -4
  11. tests/models/common/test_model_loader.py +2 -2
  12. tests/models/jax/test_qwen2_5_vl.py +10 -11
  13. tests/runner/test_multimodal_manager.py +3 -3
  14. tests/runner/test_tpu_runner.py +67 -8
  15. tests/runner/test_tpu_runner_dp.py +66 -0
  16. tpu_inference/core/sched/dp_scheduler.py +65 -40
  17. tpu_inference/kernels/mla/v1/kernel.py +7 -26
  18. tpu_inference/layers/common/sharding.py +8 -3
  19. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
  20. tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
  21. tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
  22. tpu_inference/layers/jax/sample/sampling.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +51 -47
  24. tpu_inference/layers/vllm/quantization/common.py +14 -13
  25. tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
  26. tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
  27. tpu_inference/layers/vllm/sharding.py +7 -4
  28. tpu_inference/models/common/model_loader.py +11 -14
  29. tpu_inference/models/jax/llama3.py +13 -10
  30. tpu_inference/models/jax/llama_guard_4.py +1 -1
  31. tpu_inference/models/jax/qwen2.py +3 -2
  32. tpu_inference/models/jax/qwen2_5_vl.py +4 -4
  33. tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
  34. tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
  35. tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
  36. tpu_inference/platforms/tpu_platform.py +7 -7
  37. tpu_inference/runner/compilation_manager.py +43 -33
  38. tpu_inference/runner/kv_cache_manager.py +1 -2
  39. tpu_inference/runner/multimodal_manager.py +1 -1
  40. tpu_inference/runner/tpu_runner.py +12 -9
  41. tpu_inference/utils.py +31 -30
  42. tpu_inference/worker/tpu_worker.py +5 -2
  43. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
  44. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
  45. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
  46. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
  47. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
@@ -118,12 +118,16 @@ def test_loading_model(model, mesh):
118
118
 
119
119
  @pytest.mark.parametrize("model", MODELS)
120
120
  @pytest.mark.parametrize("bias", [False, True])
121
- @pytest.mark.parametrize("mesh", [
122
- test_utils.get_spmd_mesh(1),
123
- test_utils.get_spmd_mesh(jax.local_device_count())
124
- ])
121
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
125
122
  @pytest.mark.parametrize("enable_sp", [False, True])
126
- def test_row_parallel_linear(model, bias, mesh, enable_sp):
123
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
124
+ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
125
+ enable_attn_dp):
126
+ # Skip if enable_attn_dp is True but we don't have enough devices
127
+ if enable_attn_dp and num_devices < 2:
128
+ pytest.skip("enable_attn_dp requires at least 2 devices")
129
+
130
+ mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
127
131
  dtype = torch.bfloat16
128
132
 
129
133
  engine_args = EngineArgs(
@@ -191,12 +195,16 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp):
191
195
 
192
196
  @pytest.mark.parametrize("model", MODELS)
193
197
  @pytest.mark.parametrize("bias", [False, True])
194
- @pytest.mark.parametrize("mesh", [
195
- test_utils.get_spmd_mesh(1),
196
- test_utils.get_spmd_mesh(jax.local_device_count())
197
- ])
198
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
198
199
  @pytest.mark.parametrize("enable_sp", [False, True])
199
- def test_column_parallel_linear(model, bias, mesh, enable_sp):
200
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
201
+ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
202
+ enable_attn_dp):
203
+ # Skip if enable_attn_dp is True but we don't have enough devices
204
+ if enable_attn_dp and num_devices < 2:
205
+ pytest.skip("enable_attn_dp requires at least 2 devices")
206
+
207
+ mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
200
208
  dtype = torch.bfloat16
201
209
 
202
210
  engine_args = EngineArgs(
@@ -263,13 +271,17 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp):
263
271
 
264
272
  @pytest.mark.parametrize("model", MODELS)
265
273
  @pytest.mark.parametrize("bias", [False, True])
266
- @pytest.mark.parametrize("mesh", [
267
- test_utils.get_spmd_mesh(1),
268
- test_utils.get_spmd_mesh(jax.local_device_count())
269
- ])
274
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
270
275
  @pytest.mark.parametrize("enable_sp", [False, True])
271
276
  @pytest.mark.parametrize("fuse_matmuls", [False, True])
272
- def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
277
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
278
+ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
279
+ enable_attn_dp):
280
+ # Skip if enable_attn_dp is True but we don't have enough devices
281
+ if enable_attn_dp and num_devices < 2:
282
+ pytest.skip("enable_attn_dp requires at least 2 devices")
283
+
284
+ mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
273
285
  dtype = torch.bfloat16
274
286
 
275
287
  engine_args = EngineArgs(
@@ -341,14 +353,17 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
341
353
 
342
354
  @pytest.mark.parametrize("model", MODELS)
343
355
  @pytest.mark.parametrize("bias", [False, True])
344
- @pytest.mark.parametrize("mesh", [
345
- test_utils.get_spmd_mesh(1),
346
- test_utils.get_spmd_mesh(jax.local_device_count())
347
- ])
356
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
348
357
  @pytest.mark.parametrize("fuse_matmuls", [False, True])
349
358
  @pytest.mark.parametrize("enable_sp", [False, True])
350
- def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
351
- enable_sp):
359
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
360
+ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
361
+ enable_sp, enable_attn_dp):
362
+ # Skip if enable_attn_dp is True but we don't have enough devices
363
+ if enable_attn_dp and num_devices < 2:
364
+ pytest.skip("enable_attn_dp requires at least 2 devices")
365
+
366
+ mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
352
367
  dtype = torch.bfloat16
353
368
 
354
369
  engine_args = EngineArgs(
@@ -418,10 +433,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
418
433
 
419
434
 
420
435
  @pytest.mark.parametrize("use_ep", [True, False])
421
- @pytest.mark.parametrize("mesh", [
422
- test_utils.get_spmd_mesh(1),
423
- test_utils.get_spmd_mesh(jax.local_device_count())
424
- ])
436
+ @pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
425
437
  @pytest.mark.parametrize("num_tokens", [8])
426
438
  @pytest.mark.parametrize("intermediate_size", [1024, 2048])
427
439
  @pytest.mark.parametrize("hidden_size", [128, 512])
@@ -429,8 +441,15 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
429
441
  @pytest.mark.parametrize("topk", [2])
430
442
  @pytest.mark.parametrize("has_bias", [False, True])
431
443
  @pytest.mark.parametrize("activation", ["silu", "swigluoai"])
432
- def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
433
- num_experts, topk, has_bias, activation):
444
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
445
+ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
446
+ hidden_size, num_experts, topk, has_bias, activation,
447
+ enable_attn_dp):
448
+ # Skip if enable_attn_dp is True but we don't have enough devices
449
+ if enable_attn_dp and num_devices < 2:
450
+ pytest.skip("enable_attn_dp requires at least 2 devices")
451
+
452
+ mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
434
453
 
435
454
  torch.manual_seed(42)
436
455
  dtype = torch.bfloat16
@@ -502,16 +521,27 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
502
521
  rtol=1e-1)
503
522
 
504
523
 
505
- @pytest.mark.parametrize("mesh",
506
- [test_utils.get_spmd_mesh(jax.local_device_count())])
524
+ @pytest.mark.parametrize("num_devices", [jax.local_device_count()])
507
525
  @pytest.mark.parametrize("num_tokens", [128, 512])
508
526
  @pytest.mark.parametrize("intermediate_size", [512])
509
527
  @pytest.mark.parametrize("hidden_size", [512])
510
528
  @pytest.mark.parametrize("num_experts", [32])
511
529
  @pytest.mark.parametrize("topk", [8])
512
530
  @pytest.mark.parametrize("has_bias", [False, True])
513
- def test_fused_moe_use_kernel(mesh, num_tokens, intermediate_size, hidden_size,
514
- num_experts, topk, has_bias):
531
+ @pytest.mark.parametrize("enable_attn_dp", [False, True])
532
+ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
533
+ hidden_size, num_experts, topk, has_bias,
534
+ enable_attn_dp):
535
+ # Skip if enable_attn_dp is True but we don't have enough devices
536
+ if enable_attn_dp and num_devices < 2:
537
+ pytest.skip("enable_attn_dp requires at least 2 devices")
538
+
539
+ # Skip attn_dp tests for fused_moe_use_kernel since the kernel only supports 2D mesh
540
+ if enable_attn_dp:
541
+ pytest.skip(
542
+ "fused_moe kernel does not support attn_dp (requires 2D mesh)")
543
+
544
+ mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
515
545
 
516
546
  # TODO(Qiliang Cui): Remove when issue is resolved.
517
547
  if not jtu.is_device_tpu_at_least(version=7):
@@ -16,12 +16,27 @@ import jax
16
16
  import torch
17
17
  import torch.nn.functional as F
18
18
 
19
+ from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
20
+ MESH_AXIS_NAMES_2D)
19
21
 
20
- def get_spmd_mesh(num_devices: int = 1):
21
- axis_names = ("data", "model")
22
+
23
+ def get_spmd_mesh(num_devices: int = 1, enable_attn_dp: bool = False):
22
24
  devices = sorted(jax.devices(), key=lambda d: d.id)[0:num_devices]
23
- mesh_shape = (1, len(devices))
24
- return jax.make_mesh(mesh_shape, axis_names, devices=devices)
25
+
26
+ if enable_attn_dp:
27
+ if num_devices < 2:
28
+ raise ValueError(
29
+ f"enable_attn_dp requires at least 2 devices, got {num_devices}"
30
+ )
31
+ axis_names = MESH_AXIS_NAMES
32
+ attn_dp_size = 2
33
+ model_size = num_devices // attn_dp_size
34
+ mesh_shape = (1, attn_dp_size, 1, model_size)
35
+ return jax.make_mesh(mesh_shape, axis_names, devices=devices)
36
+ else:
37
+ axis_names = MESH_AXIS_NAMES_2D
38
+ mesh_shape = (1, len(devices))
39
+ return jax.make_mesh(mesh_shape, axis_names, devices=devices)
25
40
 
26
41
 
27
42
  def find_all_layer_type(module: torch.nn.Module, layer_type: torch.nn.Module):
@@ -218,9 +218,9 @@ def test_register_model_vllm_wrapper_methods():
218
218
  with pytest.raises(NotImplementedError, match="JAX model"):
219
219
  instance.forward(input_ids=None, positions=None)
220
220
 
221
- # `get_input_embeddings` should be unimplemented.
221
+ # `embed_input_ids` should be unimplemented.
222
222
  with pytest.raises(NotImplementedError, match="JAX model"):
223
- instance.get_input_embeddings(input_ids=None, positions=None)
223
+ instance.embed_input_ids(input_ids=None, positions=None)
224
224
 
225
225
  # `load_weights` should be a no-op that returns None.
226
226
  assert instance.load_weights() is None
@@ -491,8 +491,7 @@ class TestQwen2_5_VLForConditionalGeneration:
491
491
  assert embeddings[1].shape == (tokens_per_image, vc.out_hidden_size)
492
492
  assert model.visual.call_count == 2
493
493
 
494
- def test_get_multimodal_embeddings(
495
- self, model: Qwen2_5_VLForConditionalGeneration):
494
+ def test_embed_multimodal(self, model: Qwen2_5_VLForConditionalGeneration):
496
495
  grid_thw = ((2, 28, 28), )
497
496
  vc = model.config.vision_config
498
497
  patch_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
@@ -503,20 +502,20 @@ class TestQwen2_5_VLForConditionalGeneration:
503
502
  with patch.object(model,
504
503
  '_process_image_input',
505
504
  return_value=(mock_vision_output, )) as mock_process:
506
- mm_embeds = model.get_multimodal_embeddings(
507
- grid_thw, pixel_values=pixel_values)
505
+ mm_embeds = model.embed_multimodal(grid_thw,
506
+ pixel_values=pixel_values)
508
507
  mock_process.assert_called_once()
509
508
  assert isinstance(mm_embeds, tuple)
510
509
  assert len(mm_embeds) == 1
511
510
  assert mm_embeds[0].shape == (tokens_per_image, vc.out_hidden_size)
512
511
 
513
- mm_embeds_none = model.get_multimodal_embeddings(grid_thw)
512
+ mm_embeds_none = model.embed_multimodal(grid_thw)
514
513
  assert len(mm_embeds_none) == 0
515
514
 
516
515
  @patch('tpu_inference.models.jax.qwen2_5_vl.merge_multimodal_embeddings')
517
- def test_get_input_embeddings(self, mock_merge_embeddings: MagicMock,
518
- model: Qwen2_5_VLForConditionalGeneration,
519
- rng: PRNGKey):
516
+ def test_embed_input_ids(self, mock_merge_embeddings: MagicMock,
517
+ model: Qwen2_5_VLForConditionalGeneration,
518
+ rng: PRNGKey):
520
519
  input_ids = jax.random.randint(rng, (1, 10), 0,
521
520
  model.config.vocab_size)
522
521
  mock_text_embeds = jnp.ones((1, 10, model.config.hidden_size))
@@ -524,12 +523,12 @@ class TestQwen2_5_VLForConditionalGeneration:
524
523
  model.language_model.model.embed = MagicMock(
525
524
  return_value=mock_text_embeds)
526
525
 
527
- embeds = model.get_input_embeddings(input_ids, None)
526
+ embeds = model.embed_input_ids(input_ids, None)
528
527
  np.testing.assert_array_equal(embeds, mock_text_embeds)
529
528
  mock_merge_embeddings.assert_not_called()
530
529
 
531
530
  empty_mm = jnp.ones((0, model.config.hidden_size), )
532
- embeds_empty_mm = model.get_input_embeddings(input_ids, empty_mm)
531
+ embeds_empty_mm = model.embed_input_ids(input_ids, empty_mm)
533
532
  np.testing.assert_array_equal(embeds_empty_mm, mock_text_embeds)
534
533
  mock_merge_embeddings.assert_not_called()
535
534
 
@@ -537,7 +536,7 @@ class TestQwen2_5_VLForConditionalGeneration:
537
536
  mock_merged = jnp.ones((1, 15, model.config.hidden_size))
538
537
  mock_merge_embeddings.return_value = mock_merged
539
538
 
540
- embeds_mm = model.get_input_embeddings(input_ids, mm_embeds)
539
+ embeds_mm = model.embed_input_ids(input_ids, mm_embeds)
541
540
  np.testing.assert_array_equal(embeds_mm, mock_merged)
542
541
  mock_merge_embeddings.assert_called_once_with(
543
542
  input_ids, mock_text_embeds, mm_embeds,
@@ -88,7 +88,7 @@ class TestMultiModalManager:
88
88
  # 1. ===== Setup =====
89
89
  self.runner.is_multimodal_model = True
90
90
  self.mock_get_mm_embed_fn = MagicMock()
91
- self.runner.get_multimodal_embeddings_fn = self.mock_get_mm_embed_fn
91
+ self.runner.embed_multimodal_fn = self.mock_get_mm_embed_fn
92
92
 
93
93
  self.runner.state = MagicMock()
94
94
  # Mock scheduler output
@@ -139,7 +139,7 @@ class TestMultiModalManager:
139
139
  np.testing.assert_array_equal(np.asarray(cached_embedding),
140
140
  np.asarray(dummy_embedding))
141
141
 
142
- # Check if get_multimodal_embeddings_fn was called with correct args
142
+ # Check if embed_multimodal_fn was called with correct args
143
143
  self.mock_get_mm_embed_fn.assert_called_once()
144
144
  call_args = self.mock_get_mm_embed_fn.call_args
145
145
 
@@ -169,7 +169,7 @@ class TestMultiModalManager:
169
169
  # 1. ===== Setup =====
170
170
  self.runner.is_multimodal_model = True
171
171
  self.mock_get_mm_embed_fn = MagicMock()
172
- self.runner.get_multimodal_embeddings_fn = self.mock_get_mm_embed_fn
172
+ self.runner.embed_multimodal_fn = self.mock_get_mm_embed_fn
173
173
 
174
174
  self.runner.state = MagicMock()
175
175
  # Mock scheduler output for two requests
@@ -88,7 +88,7 @@ class TestTPUJaxRunner:
88
88
 
89
89
  # Mock the embedding function
90
90
  self.mock_get_input_embed_fn = MagicMock()
91
- self.runner.get_input_embeddings_fn = self.mock_get_input_embed_fn
91
+ self.runner.embed_input_ids_fn = self.mock_get_input_embed_fn
92
92
  self.mock_get_input_embed_fn.return_value = dummy_final_embeds
93
93
  self.runner.state = MagicMock()
94
94
 
@@ -116,6 +116,65 @@ class TestTPUJaxRunner:
116
116
  np.asarray(dummy_input_ids))
117
117
  self.mock_get_input_embed_fn.assert_not_called()
118
118
 
119
+ @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
120
+ def test_prepare_inputs_hybrid_kvcache(self, mock_sampling_metadata):
121
+ # create hybrid kv cache config
122
+ # 20 layers, 10 full attn + 10 sw attn
123
+ self._create_mock_hybrid_kv_cache_config()
124
+
125
+ # Mock scheduler output.
126
+ scheduler_output = MagicMock()
127
+ scheduler_output.total_num_scheduled_tokens = 10
128
+ scheduler_output.num_scheduled_tokens = {'req1': 10}
129
+ scheduler_output.scheduled_spec_decode_tokens = {}
130
+ scheduler_output.grammar_bitmask = None
131
+
132
+ # Mock input_batch
133
+ self.runner.input_batch = MagicMock()
134
+ self.runner.input_batch.num_reqs = 1
135
+ self.runner.input_batch.req_ids = ['req1']
136
+ self.runner.input_batch.req_id_to_index = {'req1': 0}
137
+ self.runner.input_batch.num_computed_tokens_cpu = np.array([10])
138
+ self.runner.input_batch.token_ids_cpu = np.random.randint(
139
+ 0, 1000, (8, 64), dtype=np.int32)
140
+
141
+ # Mock block tables
142
+ # there will be 2 block tables since there are 2 kv cache groups
143
+ mock_block_table = MagicMock()
144
+ mock_block_table.get_cpu_tensor.return_value = np.zeros(
145
+ self.runner.block_tables_cpu[0].shape)
146
+ self.runner.input_batch.block_table = [
147
+ mock_block_table, mock_block_table
148
+ ]
149
+ self.runner.block_tables_cpu = [
150
+ np.zeros(self.runner.block_tables_cpu[0].shape, dtype=np.int32),
151
+ np.zeros(self.runner.block_tables_cpu[0].shape, dtype=np.int32)
152
+ ]
153
+
154
+ mock_sampling_instance = MagicMock()
155
+ mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
156
+
157
+ output = self.runner._prepare_inputs_non_dp(scheduler_output)
158
+ assert len(output) == 8
159
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = output
160
+ # assert it will create attention metadata for each layer.
161
+ assert isinstance(attention_metadata, dict)
162
+ assert len(attention_metadata) == 20
163
+
164
+ def _create_mock_hybrid_kv_cache_config(self):
165
+ mock_kv_cache_config = MagicMock()
166
+ mock_kv_cache_group1 = MagicMock()
167
+ mock_kv_cache_group1.layer_names = [f'layer.{i}' for i in range(10)]
168
+ mock_kv_cache_group2 = MagicMock()
169
+ mock_kv_cache_group2.layer_names = [
170
+ f'layer.{i}' for i in range(10, 20)
171
+ ]
172
+ mock_kv_cache_config.kv_cache_groups = [
173
+ mock_kv_cache_group1, mock_kv_cache_group2
174
+ ]
175
+ self.runner.kv_cache_config = mock_kv_cache_config
176
+ self.runner.use_hybrid_kvcache = True
177
+
119
178
 
120
179
  class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
121
180
 
@@ -126,7 +185,7 @@ class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
126
185
  device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, -1)
127
186
  self.mock_mesh = jax.make_mesh(device_array.shape,
128
187
  ('data', 'attn_dp', 'expert', 'model'))
129
- # Setup the runner with the model_config.is_multimodal_model set to True but get_model returning None for get_multimodal_embeddings_fn and get_input_embeddings_fn.
188
+ # Setup the runner with the model_config.is_multimodal_model set to True but get_model returning None for embed_multimodal_fn and embed_input_ids_fn.
130
189
  with patch('jax.devices', return_value=self.mock_devices), \
131
190
  patch('jax.make_mesh', return_value=self.mock_mesh), \
132
191
  patch('jax.random.key', return_value=self.mock_rng_key), \
@@ -172,8 +231,8 @@ class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
172
231
  def _model_get_model(self):
173
232
  mock_multimodal_fns = {
174
233
  "precompile_vision_encoder_fn": None,
175
- "get_multimodal_embeddings_fn": None,
176
- "get_input_embeddings_fn": None,
234
+ "embed_multimodal_fn": None,
235
+ "embed_input_ids_fn": None,
177
236
  "get_mrope_input_positions_fn": None
178
237
  }
179
238
  return (
@@ -190,13 +249,13 @@ class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
190
249
  # Precondition: make sure the model_config claims the model supports MM.
191
250
  assert self.runner.model_config.is_multimodal_model
192
251
 
193
- # Precondition: load the model and returns get_multimodal_embeddings_fn as None.
194
- assert self.runner.get_multimodal_embeddings_fn is None
252
+ # Precondition: load the model and returns embed_multimodal_fn as None.
253
+ assert self.runner.embed_multimodal_fn is None
195
254
 
196
255
  assert not self.runner.is_multimodal_model
197
256
 
198
- self.runner.get_input_embeddings_fn = MagicMock()
257
+ self.runner.embed_input_ids_fn = MagicMock()
199
258
  dummy_input_ids = jnp.array([1, 2, 3])
200
259
  dummy_mm_embeds = jnp.ones((10, 128))
201
260
  _ = self.runner._get_input_ids_embeds(dummy_input_ids, dummy_mm_embeds)
202
- self.runner.get_input_embeddings_fn.assert_not_called()
261
+ self.runner.embed_input_ids_fn.assert_not_called()
@@ -97,6 +97,20 @@ class TestTPUJaxRunnerDPInputsLightweight:
97
97
  mock_output.grammar_bitmask = None
98
98
  return mock_output
99
99
 
100
+ def _create_mock_hybrid_kv_cache_config(self):
101
+ mock_kv_cache_config = MagicMock()
102
+ mock_kv_cache_group1 = MagicMock()
103
+ mock_kv_cache_group1.layer_names = [f'layer.{i}' for i in range(10)]
104
+ mock_kv_cache_group2 = MagicMock()
105
+ mock_kv_cache_group2.layer_names = [
106
+ f'layer.{i}' for i in range(10, 20)
107
+ ]
108
+ mock_kv_cache_config.kv_cache_groups = [
109
+ mock_kv_cache_group1, mock_kv_cache_group2
110
+ ]
111
+ self.runner.kv_cache_config = mock_kv_cache_config
112
+ self.runner.use_hybrid_kvcache = True
113
+
100
114
  @patch('tpu_inference.runner.tpu_runner.NamedSharding')
101
115
  @patch('tpu_inference.runner.tpu_runner.runner_utils')
102
116
  @patch('tpu_inference.runner.tpu_runner.device_array',
@@ -146,6 +160,58 @@ class TestTPUJaxRunnerDPInputsLightweight:
146
160
  with pytest.raises(AssertionError):
147
161
  self.runner._prepare_inputs_dp(scheduler_output)
148
162
 
163
+ @patch('tpu_inference.runner.tpu_runner.NamedSharding')
164
+ @patch('tpu_inference.runner.tpu_runner.runner_utils')
165
+ @patch('tpu_inference.runner.tpu_runner.device_array',
166
+ side_effect=lambda mesh, tensors, **kwargs: tensors)
167
+ @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
168
+ def test_prepare_inputs_dp_hybrid_kvcache(self, mock_sampling_metadata,
169
+ mock_device_array,
170
+ mock_runner_utils,
171
+ mock_named_sharding):
172
+ """Test basic functionality of _prepare_inputs_dp."""
173
+ # Mock utility functions
174
+ mock_runner_utils.get_padded_token_len.return_value = 16
175
+ mock_sampling_metadata.from_input_batch.return_value = MagicMock()
176
+ mock_named_sharding.return_value = MagicMock()
177
+
178
+ # Create test data - only use req1 and req2 to match num_reqs=2
179
+ num_scheduled_tokens = {"req1": 5, "req2": 3}
180
+ assigned_dp_ranks = {"req1": 0, "req2": 1}
181
+ scheduler_output = self._create_mock_scheduler_output(
182
+ num_scheduled_tokens, assigned_dp_ranks)
183
+
184
+ # Create hybrid kv cache config with 10 full attn layers, 10 sw attn layers
185
+ self._create_mock_hybrid_kv_cache_config()
186
+
187
+ # update input_batch's block_table
188
+ mock_block_table = MagicMock()
189
+ mock_block_table.get_cpu_tensor.return_value = np.arange(32).reshape(
190
+ 4, 8)
191
+ self.runner.input_batch.block_table = [
192
+ mock_block_table, mock_block_table
193
+ ]
194
+
195
+ # update model runner's block_tables_cpu:
196
+ self.runner.block_tables_cpu = [
197
+ np.zeros((8, 8), dtype=np.int32),
198
+ np.zeros((8, 8), dtype=np.int32)
199
+ ]
200
+
201
+ # Execute the method
202
+ result = self.runner._prepare_inputs_dp(scheduler_output)
203
+
204
+ # Basic assertions
205
+ assert len(result) == 8
206
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
207
+
208
+ # Verify utility functions were called
209
+ mock_runner_utils.get_padded_token_len.assert_called()
210
+
211
+ # Verify there's attention_metadata for each layer
212
+ assert isinstance(attention_metadata, dict)
213
+ assert len(attention_metadata) == 20
214
+
149
215
  def test_prepare_dp_input_metadata(self):
150
216
  num_scheduled_tokens = {"req1": 10, "req2": 5, "req3": 8, "req4": 3}
151
217
  assigned_dp_ranks = {"req1": 0, "req2": 0, "req3": 1, "req4": 1}