tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/core/test_dp_scheduler.py +128 -71
- tests/e2e/test_data_parallel.py +176 -280
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_speculative_decoding.py +26 -6
- tests/layers/jax/test_qwix.py +1 -1
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
- tests/layers/vllm/test_mxfp4.py +25 -10
- tests/layers/vllm/test_unquantized.py +61 -31
- tests/layers/vllm/utils.py +19 -4
- tests/models/common/test_model_loader.py +2 -2
- tests/models/jax/test_qwen2_5_vl.py +10 -11
- tests/runner/test_multimodal_manager.py +3 -3
- tests/runner/test_tpu_runner.py +67 -8
- tests/runner/test_tpu_runner_dp.py +66 -0
- tpu_inference/core/sched/dp_scheduler.py +65 -40
- tpu_inference/kernels/mla/v1/kernel.py +7 -26
- tpu_inference/layers/common/sharding.py +8 -3
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
- tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
- tpu_inference/layers/jax/sample/sampling.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +51 -47
- tpu_inference/layers/vllm/quantization/common.py +14 -13
- tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
- tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
- tpu_inference/layers/vllm/sharding.py +7 -4
- tpu_inference/models/common/model_loader.py +11 -14
- tpu_inference/models/jax/llama3.py +13 -10
- tpu_inference/models/jax/llama_guard_4.py +1 -1
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -4
- tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
- tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
- tpu_inference/platforms/tpu_platform.py +7 -7
- tpu_inference/runner/compilation_manager.py +43 -33
- tpu_inference/runner/kv_cache_manager.py +1 -2
- tpu_inference/runner/multimodal_manager.py +1 -1
- tpu_inference/runner/tpu_runner.py +12 -9
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/tpu_worker.py +5 -2
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_unquantized.py CHANGED

@@ -118,12 +118,16 @@ def test_loading_model(model, mesh):
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("bias", [False, True])
-@pytest.mark.parametrize("mesh", [
-    test_utils.get_spmd_mesh(1),
-    test_utils.get_spmd_mesh(jax.local_device_count())
-])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
 @pytest.mark.parametrize("enable_sp", [False, True])
-def test_row_parallel_linear(model, bias, mesh, enable_sp):
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_row_parallel_linear(model, bias, num_devices, enable_sp,
+                             enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
     dtype = torch.bfloat16
 
     engine_args = EngineArgs(
@@ -191,12 +195,16 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp):
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("bias", [False, True])
-@pytest.mark.parametrize("mesh", [
-    test_utils.get_spmd_mesh(1),
-    test_utils.get_spmd_mesh(jax.local_device_count())
-])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
 @pytest.mark.parametrize("enable_sp", [False, True])
-def test_column_parallel_linear(model, bias, mesh, enable_sp):
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_column_parallel_linear(model, bias, num_devices, enable_sp,
+                                enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
     dtype = torch.bfloat16
 
     engine_args = EngineArgs(
@@ -263,13 +271,17 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp):
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("bias", [False, True])
-@pytest.mark.parametrize("mesh", [
-    test_utils.get_spmd_mesh(1),
-    test_utils.get_spmd_mesh(jax.local_device_count())
-])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
 @pytest.mark.parametrize("enable_sp", [False, True])
 @pytest.mark.parametrize("fuse_matmuls", [False, True])
-def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
+                             enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
     dtype = torch.bfloat16
 
     engine_args = EngineArgs(
@@ -341,14 +353,17 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("bias", [False, True])
-@pytest.mark.parametrize("mesh", [
-    test_utils.get_spmd_mesh(1),
-    test_utils.get_spmd_mesh(jax.local_device_count())
-])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
 @pytest.mark.parametrize("fuse_matmuls", [False, True])
 @pytest.mark.parametrize("enable_sp", [False, True])
-def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
-                                       enable_sp):
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
+                                       enable_sp, enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
     dtype = torch.bfloat16
 
     engine_args = EngineArgs(
@@ -418,10 +433,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
 
 
 @pytest.mark.parametrize("use_ep", [True, False])
-@pytest.mark.parametrize("mesh", [
-    test_utils.get_spmd_mesh(1),
-    test_utils.get_spmd_mesh(jax.local_device_count())
-])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
 @pytest.mark.parametrize("num_tokens", [8])
 @pytest.mark.parametrize("intermediate_size", [1024, 2048])
 @pytest.mark.parametrize("hidden_size", [128, 512])
@@ -429,8 +441,15 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
 @pytest.mark.parametrize("topk", [2])
 @pytest.mark.parametrize("has_bias", [False, True])
 @pytest.mark.parametrize("activation", ["silu", "swigluoai"])
-def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
-                   num_experts, topk, has_bias, activation):
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
+                   hidden_size, num_experts, topk, has_bias, activation,
+                   enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
 
     torch.manual_seed(42)
     dtype = torch.bfloat16
@@ -502,16 +521,27 @@ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
                    rtol=1e-1)
 
 
-@pytest.mark.parametrize("mesh",
-                         [test_utils.get_spmd_mesh(jax.local_device_count())])
+@pytest.mark.parametrize("num_devices", [jax.local_device_count()])
 @pytest.mark.parametrize("num_tokens", [128, 512])
 @pytest.mark.parametrize("intermediate_size", [512])
 @pytest.mark.parametrize("hidden_size", [512])
 @pytest.mark.parametrize("num_experts", [32])
 @pytest.mark.parametrize("topk", [8])
 @pytest.mark.parametrize("has_bias", [False, True])
-def test_fused_moe_use_kernel(mesh, num_tokens, intermediate_size,
-                              hidden_size, num_experts, topk, has_bias):
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
+                              hidden_size, num_experts, topk, has_bias,
+                              enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    # Skip attn_dp tests for fused_moe_use_kernel since the kernel only supports 2D mesh
+    if enable_attn_dp:
+        pytest.skip(
+            "fused_moe kernel does not support attn_dp (requires 2D mesh)")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
 
     # TODO(Qiliang Cui): Remove when issue is resolved.
     if not jtu.is_device_tpu_at_least(version=7):
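Every hunk above is the same mechanical refactor: the mesh is no longer constructed inside `@pytest.mark.parametrize` at collection time; the tests take plain `num_devices` and `enable_attn_dp` parameters and build the mesh in the test body. A minimal sketch of the pattern, assuming the new `get_spmd_mesh` from the utils.py diff below is importable from a repo checkout (the stated motivation, deferring device access until skip guards can run, is inferred rather than documented in this diff):

```python
import math

import jax
import pytest

from tests.layers.vllm.utils import get_spmd_mesh  # new signature, see below


@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
@pytest.mark.parametrize("enable_attn_dp", [False, True])
def test_mesh_pattern(num_devices, enable_attn_dp):
    # The skip guard can only run here, in the test body; a mesh built in a
    # decorator would already have been constructed during collection.
    if enable_attn_dp and num_devices < 2:
        pytest.skip("enable_attn_dp requires at least 2 devices")

    mesh = get_spmd_mesh(num_devices, enable_attn_dp)
    # Whatever its rank, the mesh spans exactly num_devices devices.
    assert math.prod(mesh.shape.values()) == num_devices
```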
tests/layers/vllm/utils.py CHANGED

@@ -16,12 +16,27 @@ import jax
 import torch
 import torch.nn.functional as F
 
+from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
+                                                  MESH_AXIS_NAMES_2D)
 
-
-def get_spmd_mesh(num_devices: int = 1):
+
+def get_spmd_mesh(num_devices: int = 1, enable_attn_dp: bool = False):
     devices = sorted(jax.devices(), key=lambda d: d.id)[0:num_devices]
-    mesh_shape = (1, len(devices))
-    return jax.make_mesh(mesh_shape, ('data', 'model'), devices=devices)
+
+    if enable_attn_dp:
+        if num_devices < 2:
+            raise ValueError(
+                f"enable_attn_dp requires at least 2 devices, got {num_devices}"
+            )
+        axis_names = MESH_AXIS_NAMES
+        attn_dp_size = 2
+        model_size = num_devices // attn_dp_size
+        mesh_shape = (1, attn_dp_size, 1, model_size)
+        return jax.make_mesh(mesh_shape, axis_names, devices=devices)
+    else:
+        axis_names = MESH_AXIS_NAMES_2D
+        mesh_shape = (1, len(devices))
+        return jax.make_mesh(mesh_shape, axis_names, devices=devices)
 
 
 def find_all_layer_type(module: torch.nn.Module, layer_type: torch.nn.Module):
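A quick sketch of what the two branches return, assuming 8 local devices and that `MESH_AXIS_NAMES_2D` is the 2D subset `('data', 'model')` of the 4D axis names `('data', 'attn_dp', 'expert', 'model')` that appear in the test_tpu_runner.py hunks below; the constants themselves are not shown in this diff:

```python
from tests.layers.vllm.utils import get_spmd_mesh  # import path assumed

# enable_attn_dp=True: 4D mesh with the attn_dp axis fixed at 2, so with
# 8 devices the model axis gets 8 // 2 = 4 of them -> shape (1, 2, 1, 4).
mesh_4d = get_spmd_mesh(num_devices=8, enable_attn_dp=True)
print(dict(mesh_4d.shape))  # {'data': 1, 'attn_dp': 2, 'expert': 1, 'model': 4}

# Default path: 2D mesh, every device on the model axis -> shape (1, 8).
mesh_2d = get_spmd_mesh(num_devices=8)
print(dict(mesh_2d.shape))  # {'data': 1, 'model': 8} under the assumption above

# A single device cannot host an attn_dp axis of size 2.
try:
    get_spmd_mesh(num_devices=1, enable_attn_dp=True)
except ValueError as err:
    print(err)  # enable_attn_dp requires at least 2 devices, got 1
```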
tests/models/common/test_model_loader.py CHANGED

@@ -218,9 +218,9 @@ def test_register_model_vllm_wrapper_methods():
     with pytest.raises(NotImplementedError, match="JAX model"):
         instance.forward(input_ids=None, positions=None)
 
-    # `
+    # `embed_input_ids` should be unimplemented.
     with pytest.raises(NotImplementedError, match="JAX model"):
-        instance.
+        instance.embed_input_ids(input_ids=None, positions=None)
 
     # `load_weights` should be a no-op that returns None.
     assert instance.load_weights() is None
tests/models/jax/test_qwen2_5_vl.py CHANGED

@@ -491,8 +491,7 @@ class TestQwen2_5_VLForConditionalGeneration:
         assert embeddings[1].shape == (tokens_per_image, vc.out_hidden_size)
         assert model.visual.call_count == 2
 
-    def
-            self, model: Qwen2_5_VLForConditionalGeneration):
+    def test_embed_multimodal(self, model: Qwen2_5_VLForConditionalGeneration):
         grid_thw = ((2, 28, 28), )
         vc = model.config.vision_config
         patch_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
@@ -503,20 +502,20 @@ class TestQwen2_5_VLForConditionalGeneration:
         with patch.object(model,
                           '_process_image_input',
                           return_value=(mock_vision_output, )) as mock_process:
-            mm_embeds = model.
-
+            mm_embeds = model.embed_multimodal(grid_thw,
+                                               pixel_values=pixel_values)
             mock_process.assert_called_once()
             assert isinstance(mm_embeds, tuple)
             assert len(mm_embeds) == 1
             assert mm_embeds[0].shape == (tokens_per_image, vc.out_hidden_size)
 
-            mm_embeds_none = model.
+            mm_embeds_none = model.embed_multimodal(grid_thw)
             assert len(mm_embeds_none) == 0
 
     @patch('tpu_inference.models.jax.qwen2_5_vl.merge_multimodal_embeddings')
-    def
-
-
+    def test_embed_input_ids(self, mock_merge_embeddings: MagicMock,
+                             model: Qwen2_5_VLForConditionalGeneration,
+                             rng: PRNGKey):
         input_ids = jax.random.randint(rng, (1, 10), 0,
                                        model.config.vocab_size)
         mock_text_embeds = jnp.ones((1, 10, model.config.hidden_size))
@@ -524,12 +523,12 @@ class TestQwen2_5_VLForConditionalGeneration:
         model.language_model.model.embed = MagicMock(
             return_value=mock_text_embeds)
 
-        embeds = model.
+        embeds = model.embed_input_ids(input_ids, None)
         np.testing.assert_array_equal(embeds, mock_text_embeds)
         mock_merge_embeddings.assert_not_called()
 
         empty_mm = jnp.ones((0, model.config.hidden_size), )
-        embeds_empty_mm = model.
+        embeds_empty_mm = model.embed_input_ids(input_ids, empty_mm)
         np.testing.assert_array_equal(embeds_empty_mm, mock_text_embeds)
         mock_merge_embeddings.assert_not_called()
 
@@ -537,7 +536,7 @@ class TestQwen2_5_VLForConditionalGeneration:
         mock_merged = jnp.ones((1, 15, model.config.hidden_size))
         mock_merge_embeddings.return_value = mock_merged
 
-        embeds_mm = model.
+        embeds_mm = model.embed_input_ids(input_ids, mm_embeds)
         np.testing.assert_array_equal(embeds_mm, mock_merged)
         mock_merge_embeddings.assert_called_once_with(
             input_ids, mock_text_embeds, mm_embeds,
tests/runner/test_multimodal_manager.py CHANGED

@@ -88,7 +88,7 @@ class TestMultiModalManager:
         # 1. ===== Setup =====
         self.runner.is_multimodal_model = True
         self.mock_get_mm_embed_fn = MagicMock()
-        self.runner.
+        self.runner.embed_multimodal_fn = self.mock_get_mm_embed_fn
 
         self.runner.state = MagicMock()
         # Mock scheduler output
@@ -139,7 +139,7 @@ class TestMultiModalManager:
         np.testing.assert_array_equal(np.asarray(cached_embedding),
                                       np.asarray(dummy_embedding))
 
-        # Check if
+        # Check if embed_multimodal_fn was called with correct args
         self.mock_get_mm_embed_fn.assert_called_once()
         call_args = self.mock_get_mm_embed_fn.call_args
 
@@ -169,7 +169,7 @@ class TestMultiModalManager:
         # 1. ===== Setup =====
         self.runner.is_multimodal_model = True
         self.mock_get_mm_embed_fn = MagicMock()
-        self.runner.
+        self.runner.embed_multimodal_fn = self.mock_get_mm_embed_fn
 
         self.runner.state = MagicMock()
         # Mock scheduler output for two requests
tests/runner/test_tpu_runner.py CHANGED

@@ -88,7 +88,7 @@ class TestTPUJaxRunner:
 
         # Mock the embedding function
         self.mock_get_input_embed_fn = MagicMock()
-        self.runner.
+        self.runner.embed_input_ids_fn = self.mock_get_input_embed_fn
         self.mock_get_input_embed_fn.return_value = dummy_final_embeds
         self.runner.state = MagicMock()
 
@@ -116,6 +116,65 @@ class TestTPUJaxRunner:
                                       np.asarray(dummy_input_ids))
         self.mock_get_input_embed_fn.assert_not_called()
 
+    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
+    def test_prepare_inputs_hybrid_kvcache(self, mock_sampling_metadata):
+        # create hybrid kv cache config
+        # 20 layers, 10 full attn + 10 sw attn
+        self._create_mock_hybrid_kv_cache_config()
+
+        # Mock scheduler output.
+        scheduler_output = MagicMock()
+        scheduler_output.total_num_scheduled_tokens = 10
+        scheduler_output.num_scheduled_tokens = {'req1': 10}
+        scheduler_output.scheduled_spec_decode_tokens = {}
+        scheduler_output.grammar_bitmask = None
+
+        # Mock input_batch
+        self.runner.input_batch = MagicMock()
+        self.runner.input_batch.num_reqs = 1
+        self.runner.input_batch.req_ids = ['req1']
+        self.runner.input_batch.req_id_to_index = {'req1': 0}
+        self.runner.input_batch.num_computed_tokens_cpu = np.array([10])
+        self.runner.input_batch.token_ids_cpu = np.random.randint(
+            0, 1000, (8, 64), dtype=np.int32)
+
+        # Mock block tables
+        # there will be 2 block tables since there are 2 kv cache groups
+        mock_block_table = MagicMock()
+        mock_block_table.get_cpu_tensor.return_value = np.zeros(
+            self.runner.block_tables_cpu[0].shape)
+        self.runner.input_batch.block_table = [
+            mock_block_table, mock_block_table
+        ]
+        self.runner.block_tables_cpu = [
+            np.zeros(self.runner.block_tables_cpu[0].shape, dtype=np.int32),
+            np.zeros(self.runner.block_tables_cpu[0].shape, dtype=np.int32)
+        ]
+
+        mock_sampling_instance = MagicMock()
+        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
+
+        output = self.runner._prepare_inputs_non_dp(scheduler_output)
+        assert len(output) == 8
+        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = output
+        # assert it will create attention metadata for each layer.
+        assert isinstance(attention_metadata, dict)
+        assert len(attention_metadata) == 20
+
+    def _create_mock_hybrid_kv_cache_config(self):
+        mock_kv_cache_config = MagicMock()
+        mock_kv_cache_group1 = MagicMock()
+        mock_kv_cache_group1.layer_names = [f'layer.{i}' for i in range(10)]
+        mock_kv_cache_group2 = MagicMock()
+        mock_kv_cache_group2.layer_names = [
+            f'layer.{i}' for i in range(10, 20)
+        ]
+        mock_kv_cache_config.kv_cache_groups = [
+            mock_kv_cache_group1, mock_kv_cache_group2
+        ]
+        self.runner.kv_cache_config = mock_kv_cache_config
+        self.runner.use_hybrid_kvcache = True
+
 
 
 class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
 
@@ -126,7 +185,7 @@ class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
         device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, -1)
         self.mock_mesh = jax.make_mesh(device_array.shape,
                                        ('data', 'attn_dp', 'expert', 'model'))
-        # Setup the runner with the model_config.is_multimodal_model set to True but get_model returning None for
+        # Setup the runner with the model_config.is_multimodal_model set to True but get_model returning None for embed_multimodal_fn and embed_input_ids_fn.
         with patch('jax.devices', return_value=self.mock_devices), \
              patch('jax.make_mesh', return_value=self.mock_mesh), \
              patch('jax.random.key', return_value=self.mock_rng_key), \
@@ -172,8 +231,8 @@ class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
     def _model_get_model(self):
         mock_multimodal_fns = {
             "precompile_vision_encoder_fn": None,
-            "
-            "
+            "embed_multimodal_fn": None,
+            "embed_input_ids_fn": None,
             "get_mrope_input_positions_fn": None
         }
         return (
@@ -190,13 +249,13 @@ class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
         # Precondition: make sure the model_config claims the model supports MM.
         assert self.runner.model_config.is_multimodal_model
 
-        # Precondition: load the model and returns
-        assert self.runner.
+        # Precondition: load the model and returns embed_multimodal_fn as None.
+        assert self.runner.embed_multimodal_fn is None
 
         assert not self.runner.is_multimodal_model
 
-        self.runner.
+        self.runner.embed_input_ids_fn = MagicMock()
         dummy_input_ids = jnp.array([1, 2, 3])
         dummy_mm_embeds = jnp.ones((10, 128))
         _ = self.runner._get_input_ids_embeds(dummy_input_ids, dummy_mm_embeds)
-        self.runner.
+        self.runner.embed_input_ids_fn.assert_not_called()
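Both new hybrid-KV-cache tests (the one above and its DP counterpart in test_tpu_runner_dp.py below) pin down the same contract: input preparation returns attention metadata as a dict keyed by layer name, with one entry per layer and metadata built per KV cache group. A self-contained sketch of that fan-out using hypothetical stand-ins (`build_per_layer_metadata` is illustrative, not the runner's real code):

```python
from unittest.mock import MagicMock


def build_per_layer_metadata(kv_cache_groups):
    """Hypothetical helper: one metadata object per KV cache group, shared
    by every layer in that group (a stand-in for what the runner's
    _prepare_inputs_* paths are expected to produce)."""
    attention_metadata = {}
    for group in kv_cache_groups:
        group_metadata = MagicMock()  # stands in for real attention metadata
        for layer_name in group.layer_names:
            attention_metadata[layer_name] = group_metadata
    return attention_metadata


# Mirror _create_mock_hybrid_kv_cache_config: 10 full-attention layers and
# 10 sliding-window layers split across two KV cache groups.
group1, group2 = MagicMock(), MagicMock()
group1.layer_names = [f'layer.{i}' for i in range(10)]
group2.layer_names = [f'layer.{i}' for i in range(10, 20)]

metadata = build_per_layer_metadata([group1, group2])
assert len(metadata) == 20
assert metadata['layer.0'] is metadata['layer.9']       # same group
assert metadata['layer.0'] is not metadata['layer.19']  # different group
```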
tests/runner/test_tpu_runner_dp.py CHANGED

@@ -97,6 +97,20 @@ class TestTPUJaxRunnerDPInputsLightweight:
         mock_output.grammar_bitmask = None
         return mock_output
 
+    def _create_mock_hybrid_kv_cache_config(self):
+        mock_kv_cache_config = MagicMock()
+        mock_kv_cache_group1 = MagicMock()
+        mock_kv_cache_group1.layer_names = [f'layer.{i}' for i in range(10)]
+        mock_kv_cache_group2 = MagicMock()
+        mock_kv_cache_group2.layer_names = [
+            f'layer.{i}' for i in range(10, 20)
+        ]
+        mock_kv_cache_config.kv_cache_groups = [
+            mock_kv_cache_group1, mock_kv_cache_group2
+        ]
+        self.runner.kv_cache_config = mock_kv_cache_config
+        self.runner.use_hybrid_kvcache = True
+
     @patch('tpu_inference.runner.tpu_runner.NamedSharding')
     @patch('tpu_inference.runner.tpu_runner.runner_utils')
     @patch('tpu_inference.runner.tpu_runner.device_array',
@@ -146,6 +160,58 @@ class TestTPUJaxRunnerDPInputsLightweight:
         with pytest.raises(AssertionError):
             self.runner._prepare_inputs_dp(scheduler_output)
 
+    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
+    @patch('tpu_inference.runner.tpu_runner.runner_utils')
+    @patch('tpu_inference.runner.tpu_runner.device_array',
+           side_effect=lambda mesh, tensors, **kwargs: tensors)
+    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
+    def test_prepare_inputs_dp_hybrid_kvcache(self, mock_sampling_metadata,
+                                              mock_device_array,
+                                              mock_runner_utils,
+                                              mock_named_sharding):
+        """Test basic functionality of _prepare_inputs_dp."""
+        # Mock utility functions
+        mock_runner_utils.get_padded_token_len.return_value = 16
+        mock_sampling_metadata.from_input_batch.return_value = MagicMock()
+        mock_named_sharding.return_value = MagicMock()
+
+        # Create test data - only use req1 and req2 to match num_reqs=2
+        num_scheduled_tokens = {"req1": 5, "req2": 3}
+        assigned_dp_ranks = {"req1": 0, "req2": 1}
+        scheduler_output = self._create_mock_scheduler_output(
+            num_scheduled_tokens, assigned_dp_ranks)
+
+        # Create hybrid kv cache config with 10 full attn layers, 10 sw attn layers
+        self._create_mock_hybrid_kv_cache_config()
+
+        # update input_batch's block_table
+        mock_block_table = MagicMock()
+        mock_block_table.get_cpu_tensor.return_value = np.arange(32).reshape(
+            4, 8)
+        self.runner.input_batch.block_table = [
+            mock_block_table, mock_block_table
+        ]
+
+        # update model runner's block_tables_cpu:
+        self.runner.block_tables_cpu = [
+            np.zeros((8, 8), dtype=np.int32),
+            np.zeros((8, 8), dtype=np.int32)
+        ]
+
+        # Execute the method
+        result = self.runner._prepare_inputs_dp(scheduler_output)
+
+        # Basic assertions
+        assert len(result) == 8
+        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+
+        # Verify utility functions were called
+        mock_runner_utils.get_padded_token_len.assert_called()
+
+        # Verify there's attention_metadata for each layer
+        assert isinstance(attention_metadata, dict)
+        assert len(attention_metadata) == 20
+
     def test_prepare_dp_input_metadata(self):
         num_scheduled_tokens = {"req1": 10, "req2": 5, "req3": 8, "req4": 3}
         assigned_dp_ranks = {"req1": 0, "req2": 0, "req3": 1, "req4": 1}