tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/models/jax/test_llama_guard_4.py
@@ -0,0 +1,242 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import field
from types import SimpleNamespace
from typing import Any, Tuple
from unittest.mock import MagicMock, patch

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax.typing import PRNGKey
from jax.sharding import Mesh
from vllm.config import ModelConfig

from tpu_inference.models.jax.llama_guard_4 import (LlamaGuard4ForCausalLM,
                                                    LlamaGuard4WeightLoader)


class MockParamLlamaGuard4:
    """A mock for a parameter used in the LlamaGuard4 model."""
    shape: Tuple[int, ...]
    dtype: jnp.dtype = jnp.bfloat16
    sharding_spec: Tuple[str | None, ...] | None = None
    value: Any = field(init=False)
    sharding: Any = field(init=False)

    def __init__(self, shape=(32, 128)):
        self.shape = shape
        self.value = jnp.zeros(self.shape, dtype=self.dtype)
        # The sharding spec is accessed during weight loading
        self.sharding = SimpleNamespace(spec=self.sharding_spec)

    # Allow the mock parameter's value to be updated
    def __setattr__(self, name, value):
        if name in ['value', 'shape', 'dtype', 'sharding', 'sharding_spec']:
            self.__dict__[name] = value
        else:
            super().__setattr__(name, value)


class MockVllmConfig:
    """A mock VllmConfig sufficient for testing the LlamaGuard4 model."""

    def __init__(self,
                 model_name: str,
                 random_weights: bool = False,
                 tensor_parallelism: int = 1):
        self.model_config = MagicMock(spec=ModelConfig)
        self.load_config = MagicMock()
        self.load_config.download_dir = None

        # Downsizing the following to avoid OOM
        self.model_config.get_vocab_size.return_value = 1024
        self.model_config.get_hidden_size.return_value = 128
        self.model_config.model = model_name

        self.additional_config = {
            "random_weights": random_weights,
            "sharding": {
                "sharding_strategy": {
                    "tensor_parallelism": tensor_parallelism
                }
            }
        }

        self.cache_config = MagicMock(cache_dtype="auto")

        # Mock the underlying HF config values for parameter detection
        # Downsized to avoid OOM
        text_config_mock = MagicMock()
        text_config_mock.num_attention_heads = 4
        text_config_mock.num_key_value_heads = 2
        text_config_mock.head_dim = 32

        hf_config_mock = MagicMock()
        hf_config_mock.text_config = text_config_mock

        self.model_config.hf_config = hf_config_mock


@pytest.fixture(scope="module")
def mesh():
    """
    Creates a mesh with all required axes for testing.
    """
    if not jax.devices():
        pytest.skip("No JAX devices available for mesh creation.")

    devices = np.array(jax.local_devices())

    num_devices = len(devices)
    device_mesh = devices.reshape((num_devices, 1, 1, 1))

    with Mesh(device_mesh,
              axis_names=('data', 'attn_dp', 'model', 'expert')) as m:
        yield m


@pytest.fixture
def rng() -> PRNGKey:
    """Provides a reusable JAX PRNGKey."""
    return jax.random.PRNGKey(42)


@pytest.fixture
def mock_vllm_config_llama_guard_4() -> MockVllmConfig:
    return MockVllmConfig(model_name="meta-llama/Llama-Guard-4-12B")


class TestLlamaGuard4ForCausalLM:
    """Tests for the main LlamaGuard4ForCausalLM model class."""

    def test_init_llama_guard_4(self, mock_vllm_config_llama_guard_4, rng,
                                mesh):
        """Tests correct initialization and parameter detection."""
        model = LlamaGuard4ForCausalLM(mock_vllm_config_llama_guard_4, rng,
                                       mesh)

        # Check model name is correctly set in the config
        assert "llama-guard-4" in model.vllm_config.model_config.model.lower()

        assert model.hidden_size == 128

    def test_create_model_with_random_weights(self,
                                              mock_vllm_config_llama_guard_4,
                                              rng, mesh):
        """
        Tests that random weight initialization creates concrete, non-zero-variance arrays.
        """
        with jax.set_mesh(mesh):
            model = LlamaGuard4ForCausalLM(
                vllm_config=mock_vllm_config_llama_guard_4,
                rng=rng,
                mesh=mesh,
                force_random_weights=True)

        embedding_weight = model.embedder.input_embedding_table_VD.value
        attention_q_kernel = model.layers[0].attn.kernel_q_proj_DNH.value
        final_norm_scale = model.final_norm.scale.value

        assert isinstance(embedding_weight, jax.Array)
        assert isinstance(attention_q_kernel, jax.Array)
        assert isinstance(final_norm_scale, jax.Array)

        assert jnp.std(embedding_weight) > 0
        assert jnp.std(attention_q_kernel) > 0

        assert jnp.all(final_norm_scale == 1.0)

    @patch("tpu_inference.models.jax.llama_guard_4.LlamaGuard4WeightLoader")
    def test_load_weights_called_correctly(self, mock_loader_cls, rng, mesh):
        """Tests that the weight loader is called correctly for checkpoint loading."""
        vllm_config = MockVllmConfig(model_name="llama-guard-4-test",
                                     random_weights=False)
        model = LlamaGuard4ForCausalLM(vllm_config, rng, mesh)

        mock_loader_instance = MagicMock()
        mock_loader_cls.return_value = mock_loader_instance
        model.load_weights(rng)

        mock_loader_cls.assert_called_once_with(vllm_config=vllm_config,
                                                hidden_size=128,
                                                attn_heads=4,
                                                num_key_value_heads=2,
                                                attn_head_dim=32)
        mock_loader_instance.load_weights.assert_called_once_with(model)


class TestLlamaGuard4WeightLoader:
    """Tests for the LlamaGuard4WeightLoader class."""

    @pytest.fixture
    def weight_loader(self):
        return LlamaGuard4WeightLoader(
            vllm_config=MockVllmConfig("test-model"),
            hidden_size=5120,
            attn_heads=40,
            num_key_value_heads=8,
            attn_head_dim=128)

    @pytest.mark.parametrize("hf_key, expected", [
        ("language_model.model.layers.15.self_attn.q_proj.weight",
         "layers.15.attn.kernel_q_proj_DNH"),
        ("language_model.model.layers.0.feed_forward.gate_proj.weight",
         "layers.0.custom_module.kernel_gating_DF"),
        ("language_model.model.embed_tokens.weight",
         "embedder.input_embedding_table_VD"),
        ("language_model.model.norm.weight", "final_norm.scale"),
        ("language_model.lm_head.weight", "lm_head.input_embedding_table_DV"),
        ("unmapped.key.name", "unmapped.key.name"),
    ])
    def test_map_loaded_to_standardized_name(self, weight_loader, hf_key,
                                             expected):
        """Tests the mapping from HuggingFace key names to internal names."""
        assert weight_loader.map_loaded_to_standardized_name(
            hf_key) == expected

    def test_load_weights_transformation(self, weight_loader, rng, mesh):
        """Tests that weights are correctly reshaped, transposed, and loaded."""
        vllm_config = MockVllmConfig(model_name="llama-guard-4-small-test",
                                     random_weights=False)

        model = LlamaGuard4ForCausalLM(vllm_config, rng, mesh)

        hidden_size = 5120
        vocab_size = 202048

        original_weight = jnp.ones((vocab_size, hidden_size))
        dummy_weights = [
            ("language_model.model.embed_tokens.weight", original_weight),
        ]
        weight_loader.names_and_weights_generator = dummy_weights

        # Mock get_param to return a mock param with the target shape
        mock_param = MockParamLlamaGuard4(shape=(vocab_size, hidden_size))

        with patch("tpu_inference.models.jax.llama_guard_4.get_param", return_value=mock_param), \
             patch("tpu_inference.models.jax.llama_guard_4.shard_put", return_value=jnp.ones(mock_param.value.shape)) as mock_shard_put:

            weight_loader.load_weights(model)

            # Assert that shard_put was called with the correctly transposed weight
            mock_shard_put.assert_called_once()

            # Get the actual array passed to shard_put
            called_with_weight = mock_shard_put.call_args[0][0]

            # Check if the shape of the array passed to shard_put matches the model's expected shape.
            assert called_with_weight.shape == mock_param.value.shape
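One detail worth a note: MockParamLlamaGuard4 overrides __setattr__ so that a whitelisted set of field names is written straight into the instance __dict__, which is what lets the weight loader overwrite .value on a plain mock object. Below is a minimal, self-contained sketch of the same interception pattern; FakeParam and its fields are hypothetical names for illustration, not tpu-inference APIs.

class FakeParam:
    """Illustrative stand-in for a model parameter (hypothetical class)."""

    _DIRECT_FIELDS = ("shape", "value")

    def __init__(self, shape):
        self.shape = shape        # routed through __setattr__ below
        self.value = [0] * shape  # placeholder "weights"

    def __setattr__(self, name, value):
        # Whitelisted fields bypass any extra bookkeeping and land
        # directly in the instance dict, mirroring the mock above.
        if name in self._DIRECT_FIELDS:
            self.__dict__[name] = value
        else:
            super().__setattr__(name, value)


param = FakeParam(4)
param.value = [1, 2, 3, 4]  # what a weight loader effectively does
assert param.value == [1, 2, 3, 4] and param.shape == 4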
tests/models/jax/test_qwen2.py
@@ -0,0 +1,172 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh
from vllm.config import ModelConfig

from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
from tpu_inference.runner.kv_cache import create_kv_caches


class MockVllmConfig:
    """A mock VllmConfig sufficient for testing the Qwen2 model."""

    def __init__(self, model: str, kv_cache_dtype: str):
        self.model_config = ModelConfig(model)
        self.model_config.dtype = jnp.bfloat16
        self.load_config = MagicMock()
        self.load_config.download_dir = None
        self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)


@pytest.fixture(scope="module")
def mesh():
    """
    Creates a mesh with 1 device.
    """
    if not jax.devices():
        pytest.skip("No JAX devices available for mesh creation.")

    devices = np.array(jax.local_devices()[:1])
    num_devices = len(devices)
    assert num_devices == 1
    device_mesh = devices.reshape((num_devices, 1, 1, 1))

    with Mesh(device_mesh,
              axis_names=('data', 'attn_dp', 'expert', 'model')) as m:
        yield m


@pytest.fixture
def mock_model_inputs():
    num_tokens = 8
    num_reqs = 1
    max_num_blocks_per_req = 4
    input_ids = jnp.ones((num_tokens, ), dtype=jnp.int32)
    positions = jnp.ones((num_tokens, ), dtype=jnp.int32)
    block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req),
                             dtype=jnp.int32).reshape(-1)
    seq_lens = jnp.ones((num_reqs, ), dtype=jnp.int32)
    query_start_loc = jnp.ones((num_reqs + 1, ), dtype=jnp.int32)
    request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32)

    attention_metadata = AttentionMetadata(
        input_positions=positions,
        block_tables=block_tables,
        seq_lens=seq_lens,
        query_start_loc=query_start_loc,
        request_distribution=request_distribution,
    )
    indices_do_sample = jnp.ones((num_reqs, ), dtype=jnp.int32)

    return (input_ids, attention_metadata, indices_do_sample)


@pytest.fixture
def rng() -> PRNGKey:
    """Provides a reusable JAX PRNGKey."""
    return jax.random.PRNGKey(42)


class TestQwen2ForCausalLM:
    """Tests for the main Qwen2ForCausalLM model class."""

    @pytest.mark.parametrize("mock_vllm_config", [
        MockVllmConfig("Qwen/Qwen2.5-1.5B", "auto"),
        MockVllmConfig("Qwen/Qwen2.5-1.5B", "fp8")
    ])
    def test_qwen25_1_5b(self, mock_vllm_config, rng, mesh, mock_model_inputs):
        """Tests model init and model forward for the 1.5B model variant."""

        # Test model init
        model = Qwen2ForCausalLM(mock_vllm_config, rng, mesh)
        assert "1.5b" in model.vllm_config.model_config.model.lower()

        model_config = mock_vllm_config.model_config
        hf_config = model_config.hf_config

        assert model.mesh.shape == {
            "data": 1,
            "attn_dp": 1,
            "expert": 1,
            "model": 1
        }

        layers = model.model.layers
        assert len(layers) == hf_config.num_hidden_layers
        assert isinstance(model.rng, nnx.Rngs)
        assert model.model.lm_head == model.model.embed.embedding

        attn = layers[0].self_attn
        hidden_size = hf_config.hidden_size
        num_heads = hf_config.num_attention_heads
        num_kv_heads = hf_config.num_key_value_heads
        rope_theta = hf_config.rope_theta
        original_head_dim = hidden_size // num_heads
        head_dim = 128
        intermediate_size = hf_config.intermediate_size

        assert attn.hidden_size == hidden_size
        assert attn.num_heads == num_heads
        assert attn.num_kv_heads == num_kv_heads
        assert attn.rope_theta == rope_theta
        assert attn.head_dim_original == original_head_dim
        assert attn.head_dim == head_dim
        assert attn.q_proj.kernel.shape == (hidden_size, num_heads, head_dim)
        assert attn.k_proj.kernel.shape == (hidden_size, num_kv_heads,
                                            head_dim)
        assert attn.v_proj.kernel.shape == (hidden_size, num_kv_heads,
                                            head_dim)
        assert attn.o_proj.kernel.shape == (num_heads, head_dim, hidden_size)

        mlp = layers[0].mlp
        assert mlp.gate_proj.kernel.shape == (hidden_size, intermediate_size)
        assert mlp.up_proj.kernel.shape == (hidden_size, intermediate_size)
        assert mlp.down_proj.kernel.shape == (intermediate_size, hidden_size)

        # Test model load
        model.load_weights(rng)

        # Test model forward
        kv_caches = create_kv_caches(
            num_blocks=4,
            block_size=32,
            num_kv_heads=num_kv_heads,
            head_size=head_dim,
            mesh=mesh,
            layer_names=["layer"] * hf_config.num_hidden_layers,
            cache_dtype=jnp.float8_e4m3fn
            if mock_vllm_config.cache_config.cache_dtype == "fp8" else
            jnp.bfloat16)
        # 1 seq with 8 tokens
        input_ids, attention_metadata, indices_do_sample = mock_model_inputs
        kv_caches, hidden_states, aux_hidden_states = model(
            kv_caches, input_ids, attention_metadata)
        assert hidden_states.shape == (8, hidden_size)
        assert len(aux_hidden_states) == 0

        hidden_states = hidden_states[indices_do_sample]
        assert hidden_states.shape == (1, hidden_size)

        logits = model.compute_logits(hidden_states)
        assert logits.shape == (1, hf_config.vocab_size)
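The tail of this test exercises the usual decode-time shape flow: the model returns one hidden state per packed token, indices_do_sample gathers one row per request, and logits are computed only for those rows. A minimal JAX sketch of that gather step follows, with illustrative sizes; the matmul is a stand-in projection, not the model's real compute_logits.

import jax.numpy as jnp

num_tokens, num_reqs, hidden_size, vocab_size = 8, 1, 64, 1024

# One hidden state per packed token in the batch.
hidden_states = jnp.ones((num_tokens, hidden_size))

# One sampling position per request (here the token at index 1).
indices_do_sample = jnp.ones((num_reqs, ), dtype=jnp.int32)

sampled = hidden_states[indices_do_sample]  # -> (num_reqs, hidden_size)
logits = sampled @ jnp.zeros((hidden_size, vocab_size))  # stand-in head

assert sampled.shape == (num_reqs, hidden_size)
assert logits.shape == (num_reqs, vocab_size)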