tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0

tests/models/jax/test_qwen3.py
ADDED
@@ -0,0 +1,169 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from vllm.config import ModelConfig
+
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
+from tpu_inference.runner.kv_cache import create_kv_caches
+
+
+class MockVllmConfig:
+
+    def __init__(self, model: str, kv_cache_dtype: str):
+        self.model_config = ModelConfig(model)
+        self.model_config.dtype = jnp.bfloat16
+        self.load_config = MagicMock()
+        self.load_config.download_dir = None
+        self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
+
+
+@pytest.fixture(scope="module")
+def mesh():
+    """
+    Creates a mesh with 1 device.
+    """
+    if not jax.devices():
+        pytest.skip("No JAX devices available for mesh creation.")
+
+    devices = np.array(jax.local_devices()[:1])
+    num_devices = len(devices)
+    assert num_devices == 1
+    device_mesh = devices.reshape((num_devices, 1, 1, 1))
+
+    with Mesh(device_mesh,
+              axis_names=('data', 'attn_dp', 'expert', 'model')) as m:
+        yield m
+
+
+@pytest.fixture
+def mock_model_inputs():
+    num_tokens = 8
+    num_reqs = 1
+    max_num_blocks_per_req = 4
+    input_ids = jnp.ones((num_tokens, ), dtype=jnp.int32)
+    positions = jnp.ones((num_tokens, ), dtype=jnp.int32)
+    block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req),
+                             dtype=jnp.int32).reshape(-1)
+    seq_lens = jnp.ones((num_reqs, ), dtype=jnp.int32)
+    query_start_loc = jnp.ones((num_reqs + 1, ), dtype=jnp.int32)
+    request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32)
+
+    attention_metadata = AttentionMetadata(
+        input_positions=positions,
+        block_tables=block_tables,
+        seq_lens=seq_lens,
+        query_start_loc=query_start_loc,
+        request_distribution=request_distribution,
+    )
+    indices_do_sample = jnp.ones((num_reqs, ), dtype=jnp.int32)
+
+    return (input_ids, attention_metadata, indices_do_sample)
+
+
+@pytest.fixture
+def rng() -> PRNGKey:
+    """Provides a reusable JAX PRNGKey."""
+    return jax.random.PRNGKey(42)
+
+
+class TestQwen3ForCausalLM:
+
+    @pytest.mark.parametrize("mock_vllm_config", [
+        MockVllmConfig("Qwen/Qwen3-0.6B", "auto"),
+        MockVllmConfig("Qwen/Qwen3-0.6B", "fp8")
+    ])
+    def test_qwen3_600M(self, mock_vllm_config, rng, mesh, mock_model_inputs):
+        """Tests model init and model forward for the 0.6B model variant."""
+
+        # Test model init
+        model = Qwen3ForCausalLM(mock_vllm_config, rng, mesh)
+
+        model_config = mock_vllm_config.model_config
+        hf_config = model_config.hf_config
+
+        assert model.mesh.shape == {
+            "data": 1,
+            "attn_dp": 1,
+            "expert": 1,
+            "model": 1
+        }
+
+        layers = model.model.layers
+        assert len(layers) == hf_config.num_hidden_layers
+        assert isinstance(model.rng, nnx.Rngs)
+        assert model.model.lm_head == model.model.embed.embedding
+
+        attn = layers[0].self_attn
+        hidden_size = hf_config.hidden_size
+        num_heads = hf_config.num_attention_heads
+        num_kv_heads = hf_config.num_key_value_heads
+        rope_theta = hf_config.rope_theta
+        original_head_dim = hf_config.head_dim
+        head_dim = 128
+        intermediate_size = hf_config.intermediate_size
+
+        assert attn.hidden_size == hidden_size
+        assert attn.num_heads == num_heads
+        assert attn.num_kv_heads == num_kv_heads
+        assert attn.rope_theta == rope_theta
+        assert attn.head_dim_original == original_head_dim
+        assert attn.head_dim == head_dim
+        assert attn.q_proj.kernel.shape == (hidden_size, num_heads, head_dim)
+        assert attn.k_proj.kernel.shape == (hidden_size, num_kv_heads,
+                                            head_dim)
+        assert attn.v_proj.kernel.shape == (hidden_size, num_kv_heads,
+                                            head_dim)
+        assert attn.o_proj.kernel.shape == (num_heads, head_dim, hidden_size)
+
+        mlp = layers[0].mlp
+        assert mlp.gate_proj.kernel.shape == (hidden_size, intermediate_size)
+        assert mlp.up_proj.kernel.shape == (hidden_size, intermediate_size)
+        assert mlp.down_proj.kernel.shape == (intermediate_size, hidden_size)
+
+        # Test model load
+        model.load_weights(rng)
+
+        # Test model forward
+        kv_caches = create_kv_caches(
+            num_blocks=4,
+            block_size=32,
+            num_kv_heads=num_kv_heads,
+            head_size=head_dim,
+            mesh=mesh,
+            layer_names=["layer"] * hf_config.num_hidden_layers,
+            cache_dtype=jnp.float8_e4m3fn
+            if mock_vllm_config.cache_config.cache_dtype == "fp8" else
+            jnp.bfloat16)
+        # 1 seq with 16 tokens
+        input_ids, attention_metadata, indices_do_sample = mock_model_inputs
+        kv_caches, hidden_states, aux_hidden_states = model(
+            kv_caches, input_ids, attention_metadata)
+        assert hidden_states.shape == (8, hidden_size)
+        assert len(aux_hidden_states) == 0
+
+        hidden_states = hidden_states[indices_do_sample]
+        assert hidden_states.shape == (1, hidden_size)
+
+        logits = model.compute_logits(hidden_states)
+        assert logits.shape == (1, hf_config.vocab_size)
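
The cache_dtype conditional handed to create_kv_caches above carries the fp8 parametrization of this test. A minimal standalone sketch of that mapping (kv_cache_jnp_dtype is a hypothetical name used only for illustration, not an API of tpu_inference; the two dtypes come straight from the hunk above):

import jax.numpy as jnp

def kv_cache_jnp_dtype(cache_dtype: str):
    # Mirrors the inline conditional in test_qwen3_600M: "fp8" selects a
    # float8_e4m3fn KV cache, anything else falls back to bfloat16.
    return jnp.float8_e4m3fn if cache_dtype == "fp8" else jnp.bfloat16

assert kv_cache_jnp_dtype("fp8") == jnp.float8_e4m3fn
assert kv_cache_jnp_dtype("auto") == jnp.bfloat16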

tests/models/jax/test_weight_loading.py
ADDED
@@ -0,0 +1,180 @@
+# SPDX-License-Identifier: Apache-2.0
+# Test for LoRA weight loading API
+
+import os
+import tempfile
+from dataclasses import dataclass
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax import nnx
+from jax._src import test_util as jtu
+from jax.sharding import Mesh
+from safetensors.numpy import save_file
+
+from tpu_inference.models.jax.utils.weight_utils import (
+    MetadataMap, load_hf_weights, transfer_state_with_mappings)
+
+# ----- nnx.Module Wrappers -----
+
+
+class SourceLayer(nnx.Module):
+
+    def __init__(self, rngs):
+        self.kernel = nnx.Param(jax.random.normal(rngs(), (4, 4)))
+        self.bias = nnx.Param(jax.random.normal(rngs(), (4, )))
+
+
+class SourceModel(nnx.Module):
+
+    def __init__(self, rngs):
+        self.src_lm_head = nnx.Param(jax.random.normal(rngs(), (2, 4)))
+        self.layers = {0: SourceLayer(rngs)}
+
+
+class TargetLinear(nnx.Module):
+
+    def __init__(self, rngs):
+        self.kernel = nnx.Param(jnp.zeros((4, 4)))
+        self.bias = nnx.Param(jnp.zeros((4, )))
+
+
+class TargetBlock(nnx.Module):
+
+    def __init__(self, rngs):
+        self.mlp = {"up_proj": TargetLinear(rngs)}
+
+
+class TargetModel(nnx.Module):
+
+    def __init__(self, rngs):
+        self.tgt_lm_head = nnx.Param(jnp.zeros((2, 4)))
+        self.model = {"layers": {0: TargetBlock(rngs)}}
+
+
+# ----- Test -----
+class WeightTransfer(jtu.JaxTestCase):
+
+    def test_transfer_state(self):
+        rng = nnx.Rngs(0)
+        src_model = SourceModel(rng)
+        tgt_model = TargetModel(rng)
+
+        # Get split states
+        _, src_state = nnx.split(src_model)
+        _, tgt_state = nnx.split(tgt_model)
+
+        # Overwrite known values
+        src_state["layers"][0]["kernel"].value = jnp.ones((4, 4)) * 42.0
+        src_state["layers"][0]["bias"].value = jnp.ones((4, )) * 7.0
+        src_state["src_lm_head"].value = jnp.ones((2, 4)) * 6.0
+        # Mapping for both kernel and bias
+        mappings = {
+            "layers.*.kernel": ("model.layers.*.mlp.up_proj.kernel", (None, )),
+            "layers.*.bias": ("model.layers.*.mlp.up_proj.bias", (None, )),
+            "src_lm_head": ("tgt_lm_head", (None, None)),
+        }
+
+        # Transfer
+        new_tgt_state = transfer_state_with_mappings(src_state, tgt_state,
+                                                     mappings)
+
+        # Assert correctness
+        assert jnp.allclose(
+            new_tgt_state["model"]["layers"][0]["mlp"]["up_proj"]
+            ["kernel"].value, 42.0)
+        assert jnp.allclose(
+            new_tgt_state["model"]["layers"][0]["mlp"]["up_proj"]
+            ["bias"].value, 7.0)
+        assert jnp.allclose(new_tgt_state["tgt_lm_head"].value, 6.0)
+
+
+# ----- Mocks for dtype test -----
+
+
+class DtypeTestModel(nnx.Module):
+
+    def __init__(self, dtype: jnp.dtype, rngs: nnx.Rngs):
+        self.weight_to_cast = nnx.Param(jnp.zeros((2, 2), dtype=dtype))
+        self.weight_to_keep = nnx.Param(jnp.zeros((2, 2), dtype=dtype))
+
+
+@dataclass
+class MockModelConfig:
+    model: str
+    dtype: jnp.dtype
+    hf_config: Any = None
+
+    def get_vocab_size(self):
+        return 1
+
+    def get_hidden_size(self):
+        return 1
+
+    def get_head_size(self):
+        return 1
+
+    is_multimodal_model: bool = False
+
+
+@dataclass
+class MockLoadConfig:
+    download_dir: str
+
+
+@dataclass
+class MockVllmConfig:
+    model_config: MockModelConfig
+    load_config: MockLoadConfig
+    speculative_config: Any = None
+
+
+class WeightLoadingDtypeTest(jtu.JaxTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.tempdir = tempfile.TemporaryDirectory()
+        self.addCleanup(self.tempdir.cleanup)
+
+        # Create dummy safetensors file
+        tensors = {
+            "weight_to_cast.weight": np.ones((2, 2), dtype=np.float32),
+            "weight_to_keep.weight": np.ones((2, 2), dtype=np.float32),
+        }
+        self.safetensors_path = os.path.join(self.tempdir.name,
+                                             "model.safetensors")
+        save_file(tensors, self.safetensors_path)
+
+    def test_keep_original_dtype(self):
+        rng = nnx.Rngs(0)
+        model_dtype = jnp.bfloat16
+        model = DtypeTestModel(dtype=model_dtype, rngs=rng)
+
+        mock_model_config = MockModelConfig(model=self.tempdir.name,
+                                            dtype=model_dtype)
+        mock_load_config = MockLoadConfig(download_dir=self.tempdir.name)
+        vllm_config = MockVllmConfig(model_config=mock_model_config,
+                                     load_config=mock_load_config)
+
+        mesh = Mesh(jax.devices(), ("model", ))
+
+        name_map = {
+            "weight_to_cast": "weight_to_cast",
+            "weight_to_keep": "weight_to_keep",
+        }
+        metadata_map = MetadataMap(name_map=name_map)
+
+        keep_original_dtype_keys_regex = [r"weight_to_keep.*"]
+
+        load_hf_weights(
+            vllm_config=vllm_config,
+            model=model,
+            metadata_map=metadata_map,
+            mesh=mesh,
+            keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+        )
+
+        self.assertEqual(model.weight_to_cast.value.dtype, model_dtype)
+        self.assertEqual(model.weight_to_keep.value.dtype, jnp.float32)
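
The mapping keys in test_transfer_state use "*" as a layer-index wildcard. The real expansion logic lives inside weight_utils; the intent can be sketched with a small self-contained helper (expand_mapping is hypothetical and shown only to illustrate the pattern):

import re

def expand_mapping(src_key: str, mappings: dict):
    # Hypothetical illustration of wildcard expansion: "*" in a source
    # pattern matches one layer index, and each matched index is substituted
    # back into the corresponding "*" of the target path.
    for pattern, (target, _sharding) in mappings.items():
        m = re.fullmatch(re.escape(pattern).replace(r"\*", r"(\d+)"), src_key)
        if m:
            for group in m.groups():
                target = target.replace("*", group, 1)
            return target
    return None

mappings = {"layers.*.kernel": ("model.layers.*.mlp.up_proj.kernel", (None, ))}
assert expand_mapping("layers.0.kernel",
                      mappings) == "model.layers.0.mlp.up_proj.kernel"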

tests/models/jax/utils/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tests/models/jax/utils/test_multi_modal_utils.py
ADDED
@@ -0,0 +1,212 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# test_multi_modal_utils.py
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from tpu_inference.models.jax.utils.multi_modal_utils import (
+    MultiModalEmbeddings, NestedTensors, flatten_embeddings,
+    merge_multimodal_embeddings, sanity_check_mm_encoder_outputs)
+
+# --- Tests for sanity_check_mm_encoder_outputs ---
+
+
+def test_sanity_check_valid_list():
+    """Tests sanity_check with a valid list of 2D embeddings."""
+    embeddings: MultiModalEmbeddings = [
+        jnp.ones((10, 128)), jnp.ones((15, 128))
+    ]
+    sanity_check_mm_encoder_outputs(embeddings, 2)
+    # No assertion error expected
+
+
+def test_sanity_check_valid_tuple():
+    """Tests sanity_check with a valid tuple of 2D embeddings."""
+    embeddings: MultiModalEmbeddings = (jnp.ones((10, 128)), jnp.ones(
+        (15, 128)))
+    sanity_check_mm_encoder_outputs(embeddings, 2)
+    # No assertion error expected
+
+
+def test_sanity_check_valid_3d_jax_array():
+    """Tests sanity_check with a valid 3D jax.Array."""
+    embeddings: MultiModalEmbeddings = jnp.ones((2, 10, 128))
+    # This is valid because mm_embeddings is iterable, and each item (e)
+    # in the first dimension has e.ndim == 2.
+    sanity_check_mm_encoder_outputs(embeddings, 2)
+    # No assertion error expected
+
+
+def test_sanity_check_invalid_type():
+    """Tests sanity_check with an invalid type for embeddings."""
+    with pytest.raises(
+            AssertionError,
+            match=
+            "Expected multimodal embeddings to be a list/tuple of 2D tensors"):
+        sanity_check_mm_encoder_outputs("not a tensor", 1)
+
+
+def test_sanity_check_wrong_num_items():
+    """Tests sanity_check with a mismatch in the number of embeddings."""
+    embeddings: MultiModalEmbeddings = [jnp.ones((10, 128))]
+    with pytest.raises(
+            AssertionError,
+            match="Expected number of multimodal embeddings to match number of"
+    ):
+        sanity_check_mm_encoder_outputs(embeddings, 2)
+
+
+def test_sanity_check_wrong_dimensions_in_list():
+    """Tests sanity_check with non-2D tensors within the list."""
+    embeddings: MultiModalEmbeddings = [jnp.ones((10, 128, 1))]
+    with pytest.raises(
+            AssertionError,
+            match=
+            "Expected multimodal embeddings to be a sequence of 2D tensors"):
+        sanity_check_mm_encoder_outputs(embeddings, 1)
+
+
+# --- Tests for flatten_embeddings ---
+
+
+def test_flatten_single_array():
+    """Tests flatten_embeddings with a single 2D array."""
+    emb: NestedTensors = jnp.arange(12).reshape((3, 4))
+    result = flatten_embeddings(emb)
+    np.testing.assert_array_equal(result, emb)
+
+
+def test_flatten_single_3d_array():
+    """Tests flatten_embeddings with a single 3D array."""
+    emb: NestedTensors = jnp.arange(24).reshape((2, 3, 4))
+    result = flatten_embeddings(emb)
+    expected = jnp.arange(24).reshape((6, 4))
+    np.testing.assert_array_equal(result, expected)
+
+
+def test_flatten_list_of_arrays():
+    """Tests flatten_embeddings with a list of 2D arrays."""
+    emb: NestedTensors = [
+        jnp.arange(12).reshape((3, 4)),
+        jnp.arange(12, 20).reshape((2, 4))
+    ]
+    result = flatten_embeddings(emb)
+    expected = jnp.arange(20).reshape((5, 4))
+    np.testing.assert_array_equal(result, expected)
+
+
+def test_flatten_nested_list():
+    """Tests flatten_embeddings with a nested list of arrays."""
+    emb: NestedTensors = [
+        jnp.arange(6).reshape((2, 3)),
+        [
+            jnp.arange(6, 12).reshape((2, 3)),
+            jnp.arange(12, 15).reshape((1, 3))
+        ]
+    ]
+    result = flatten_embeddings(emb)
+    expected = jnp.arange(15).reshape((5, 3))
+    np.testing.assert_array_equal(result, expected)
+
+
+# --- Tests for merge_multimodal_embeddings ---
+
+EMBED_DIM = 4
+
+
+@pytest.fixture
+def base_embeds():
+    return jnp.zeros((8, EMBED_DIM))
+
+
+def test_merge_single_placeholder(base_embeds):
+    """Tests merging with a single integer placeholder ID."""
+    input_ids = jnp.array([1, 2, -1, -1, 3, 4, -1, 5])
+    inputs_embeds = base_embeds[:len(input_ids)]
+    mm_embeds: NestedTensors = jnp.arange(3 * EMBED_DIM).reshape(
+        (3, EMBED_DIM))
+    result = merge_multimodal_embeddings(input_ids,
+                                         inputs_embeds,
+                                         mm_embeds,
+                                         placeholder_token_id=-1)
+    expected = np.array(inputs_embeds)
+    expected[input_ids == -1] = mm_embeds
+    np.testing.assert_array_equal(result, expected)
+
+
+def test_merge_no_placeholders(base_embeds):
+    """Tests merging when no placeholder tokens are in input_ids."""
+    input_ids = jnp.array([1, 2, 3, 4])
+    inputs_embeds = jnp.arange(len(input_ids) * EMBED_DIM).reshape(
+        (len(input_ids), EMBED_DIM))
+    mm_embeds: NestedTensors = jnp.empty((0, EMBED_DIM))
+
+    # Based on the provided traceback, this raises a TypeError within JAX's gather.
+    with pytest.raises(
+            TypeError,
+            match="Slice size at index 0 in gather op is out of range"):
+        merge_multimodal_embeddings(input_ids,
+                                    inputs_embeds,
+                                    mm_embeds,
+                                    placeholder_token_id=-1)
+
+
+@pytest.mark.parametrize("placeholder_id", [-1, [-1, -2]])
+def test_merge_mm_embeds_count_too_few(placeholder_id, base_embeds):
+    """
+    Tests behavior when fewer embeddings are provided than placeholders.
+    Based on the test results provided, this scenario does NOT raise an error
+    in the testing environment.
+    """
+    input_ids = jnp.array([1, 2, -1, -1, 3])  # 2 placeholders
+    inputs_embeds = base_embeds[:len(input_ids)]
+    mm_embeds_too_few: NestedTensors = jnp.ones((1, EMBED_DIM))
+
+    try:
+        # We are only asserting that this call does not crash.
+        # The actual output in this unexpected case is not being tested.
+        merge_multimodal_embeddings(input_ids,
+                                    inputs_embeds,
+                                    mm_embeds_too_few,
+                                    placeholder_token_id=placeholder_id)
+    except Exception as e:
+        pytest.fail(
+            f"Did not expect an exception based on test logs, but got {type(e).__name__}: {e}"
+        )
+
+
+@pytest.mark.parametrize("placeholder_id", [-1, [-1, -2]])
+def test_merge_mm_embeds_count_too_many_no_raise(placeholder_id, base_embeds):
+    """Tests that no error is raised if mm_embeds are too many; extras are ignored."""
+    input_ids = jnp.array([1, 2, -1, -1, 3])  # 2 placeholders
+    inputs_embeds = base_embeds[:len(input_ids)]
+    mm_embeds_too_many: NestedTensors = jnp.arange(3 * EMBED_DIM).reshape(
+        (3, EMBED_DIM))
+
+    try:
+        result = merge_multimodal_embeddings(
+            input_ids,
+            inputs_embeds,
+            mm_embeds_too_many,
+            placeholder_token_id=placeholder_id)
+        # Check that the first 2 embeddings from mm_embeds_too_many were used.
+        expected = np.array(inputs_embeds)
+        is_mm = np.isin(input_ids, np.array(placeholder_id))
+        expected[is_mm] = flatten_embeddings(mm_embeds_too_many)[:2]
+        np.testing.assert_array_equal(result, expected)
+    except Exception as e:
+        pytest.fail(
+            f"Did not expect an exception, but got {type(e).__name__}: {e}")
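
The flatten and merge tests above pin the observable semantics down well enough to restate them in plain NumPy. A reference sketch under that assumption (these helpers are illustrations, not the implementations in multi_modal_utils, which operate on jax.Array values):

import numpy as np

def flatten_reference(emb):
    # Nested lists/tuples are flattened depth-first and concatenated along
    # the token axis; a single array is reshaped down to 2D.
    if isinstance(emb, (list, tuple)):
        return np.concatenate([flatten_reference(e) for e in emb])
    arr = np.asarray(emb)
    return arr.reshape(-1, arr.shape[-1])

def merge_reference(input_ids, inputs_embeds, mm_embeds, placeholder_token_id):
    # Rows whose token id is a placeholder are overwritten, in order, by the
    # flattened multimodal embeddings; surplus embeddings are ignored.
    merged = np.array(inputs_embeds)
    is_mm = np.isin(input_ids, np.array(placeholder_token_id))
    merged[is_mm] = flatten_reference(mm_embeds)[:is_mm.sum()]
    return merged

input_ids = np.array([1, 2, -1, -1, 3, 4, -1, 5])
merged = merge_reference(input_ids, np.zeros((8, 4)),
                         np.arange(12).reshape(3, 4), placeholder_token_id=-1)
np.testing.assert_array_equal(merged[input_ids == -1],
                              np.arange(12).reshape(3, 4))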

tests/platforms/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tests/platforms/test_tpu_platform.py
ADDED
@@ -0,0 +1,54 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from vllm.config import CacheConfig, VllmConfig
+
+from tpu_inference.platforms.tpu_platform import TpuPlatform
+
+
+class TestTpuPlatform:
+
+    @pytest.fixture
+    def vllm_config(self):
+        cache_config = CacheConfig(block_size=16,
+                                   gpu_memory_utilization=0.9,
+                                   swap_space=4,
+                                   cache_dtype="fp8")
+
+        vllm_config = MagicMock(spec=VllmConfig)
+        vllm_config.cache_config = cache_config
+        vllm_config.model_config = MagicMock(dtype='bfloat16')
+        vllm_config.scheduler_config = MagicMock(is_multimodal_model=False)
+        vllm_config.parallel_config = MagicMock()
+        vllm_config.compilation_config = MagicMock(mode="dynamo_trace_once",
+                                                   backend="openxla")
+        vllm_config.kv_transfer_config = None
+        return vllm_config
+
+    @pytest.mark.parametrize("chip_name,expected_dtype", [
+        ("v6e", torch.float8_e5m2),
+        ("v5e", torch.float8_e4m3fn),
+    ])
+    def test_fp8_dtype(self, chip_name, expected_dtype):
+        mock_chip_type = MagicMock()
+        mock_chip_type.name = chip_name
+
+        with patch('tpu_inference.platforms.tpu_platform.init_logger'), \
+             patch('tpu_inference.platforms.tpu_platform.device.get_local_chips', return_value=(mock_chip_type, None)), \
+             patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False):
+            assert TpuPlatform.fp8_dtype() == expected_dtype
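
The parametrization above fixes the per-chip contract of TpuPlatform.fp8_dtype. Restated as a minimal sketch (fp8_dtype_for_chip is a hypothetical helper; only the v6e/v5e pairs asserted by the test are covered, other chip generations are not):

import torch

def fp8_dtype_for_chip(chip_name: str) -> torch.dtype:
    # v6e advertises float8_e5m2; v5e falls back to float8_e4m3fn.
    return torch.float8_e5m2 if chip_name == "v6e" else torch.float8_e4m3fn

assert fp8_dtype_for_chip("v6e") == torch.float8_e5m2
assert fp8_dtype_for_chip("v5e") == torch.float8_e4m3fn
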
tests/runner/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.