tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/models/jax/test_llama4.py (new file)
@@ -0,0 +1,298 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import field
+from types import SimpleNamespace
+from typing import Any, Tuple
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from vllm.config import ModelConfig
+
+from tpu_inference.models.jax.llama4 import (Llama4ForCausalLM,
+                                             Llama4WeightLoader)
+
+
+class MockParamLlama4:
+    """A mock for a parameter used in the Llama4 model."""
+    shape: Tuple[int, ...]
+    dtype: jnp.dtype = jnp.bfloat16
+    sharding_spec: Tuple[str | None, ...] | None = None
+    value: Any = field(init=False)
+    sharding: Any = field(init=False)
+
+    def __init__(self, shape=(32, 128)):
+        self.shape = shape
+        self.value = jnp.zeros(self.shape, dtype=self.dtype)
+        # The sharding spec is accessed during weight loading
+        self.sharding = SimpleNamespace(spec=self.sharding_spec)
+
+    # Allow the mock parameter's value to be updated
+    def __setattr__(self, name, value):
+        if name in ['value', 'shape', 'dtype', 'sharding', 'sharding_spec']:
+            self.__dict__[name] = value
+        else:
+            super().__setattr__(name, value)
+
+
+class MockVllmConfig:
+    """A mock VllmConfig sufficient for testing the Llama4 model."""
+
+    def __init__(self,
+                 model_name: str,
+                 random_weights: bool = False,
+                 tensor_parallelism: int = 1):
+        self.model_config = MagicMock(spec=ModelConfig)
+        self.load_config = MagicMock()
+        self.load_config.download_dir = None
+
+        # Choose small amount of layers to avoid OOM.
+        self.model_config.get_vocab_size.return_value = 202048
+        self.model_config.get_hidden_size.return_value = 32
+        self.model_config.model = model_name
+
+        self.additional_config = {
+            "random_weights": random_weights,
+            "sharding": {
+                "sharding_strategy": {
+                    "tensor_parallelism": tensor_parallelism
+                }
+            }
+        }
+
+        self.cache_config = MagicMock(cache_dtype="auto")
+
+        text_config_mock = MagicMock()
+        text_config_mock.interleave_moe_layer_step = 1
+        text_config_mock.num_attention_heads = 40
+        text_config_mock.num_key_value_heads = 8
+        text_config_mock.head_dim = 128
+
+        hf_config_mock = MagicMock()
+        hf_config_mock.text_config = text_config_mock
+
+        self.model_config.hf_config = hf_config_mock
+
+
+@pytest.fixture(scope="module")
+def mesh():
+    """
+    Creates a mesh with all required axes for testing.
+    """
+    if not jax.devices():
+        pytest.skip("No JAX devices available for mesh creation.")
+
+    devices = np.array(jax.local_devices())
+    # Reshape devices into a 3D array to name 3 axes: data, model, and expert.
+    # The 'model' and 'expert' axes will have a size of 1.
+    num_devices = len(devices)
+    device_mesh = devices.reshape((num_devices, 1, 1, 1))
+
+    with Mesh(device_mesh,
+              axis_names=('data', 'attn_dp', 'model', 'expert')) as m:
+        yield m
+
+
+@pytest.fixture
+def rng() -> PRNGKey:
+    """Provides a reusable JAX PRNGKey."""
+    return jax.random.PRNGKey(42)
+
+
+@pytest.fixture
+def mock_vllm_config_llama4() -> MockVllmConfig:
+    return MockVllmConfig(model_name="meta-llama/Llama-4-Scout-17B-16E")
+
+
+class TestLlama4ForCausalLM:
+    """Tests for the main LlamaForCausalLM model class."""
+
+    def test_init_llama4(self, mock_vllm_config_llama4, rng, mesh):
+        """Tests correct parameter detection for the Llama4 model variant."""
+        model = Llama4ForCausalLM(mock_vllm_config_llama4, rng, mesh)
+        assert model.hidden_size == 32
+        assert "llama-4" in model.vllm_config.model_config.model.lower()
+
+    def test_create_model_with_random_weights(self, mock_vllm_config_llama4,
+                                              rng, mesh):
+        """
+        Tests that random weight initialization creates concrete, non-zero-variance arrays.
+        """
+        with jax.set_mesh(mesh):
+            model = Llama4ForCausalLM(vllm_config=mock_vllm_config_llama4,
+                                      rng=rng,
+                                      mesh=mesh,
+                                      force_random_weights=True)
+            embedding_weight = model.embedder.input_embedding_table_VD.value
+            attention_q_kernel = model.layers[0].attn.kernel_q_proj_DNH.value
+            final_norm_scale = model.final_norm.scale.value
+
+            assert isinstance(embedding_weight, jax.Array)
+            assert isinstance(attention_q_kernel, jax.Array)
+            assert isinstance(final_norm_scale, jax.Array)
+
+            assert jnp.std(embedding_weight) > 0
+            assert jnp.std(attention_q_kernel) > 0
+
+            assert jnp.all(final_norm_scale == 1.0)
+
+    @patch("tpu_inference.models.jax.llama4.Llama4WeightLoader")
+    def test_load_weights_called_correctly(self, mock_loader_cls, rng, mesh):
+        """Tests that the weight loader is called correctly for checkpoint loading."""
+        vllm_config = MockVllmConfig(model_name="llama4-scout",
+                                     random_weights=False)
+        model = Llama4ForCausalLM(vllm_config, rng, mesh)
+
+        mock_loader_instance = MagicMock()
+        mock_loader_cls.return_value = mock_loader_instance
+        model.load_weights(rng)
+
+        mock_loader_cls.assert_called_once_with(vllm_config=vllm_config,
+                                                hidden_size=32,
+                                                attn_heads=40,
+                                                num_key_value_heads=8,
+                                                attn_head_dim=128)
+        mock_loader_instance.load_weights.assert_called_once_with(model)
+
+
+class TestLlama4WeightLoader:
+    """Tests for the Llama4WeightLoader class."""
+
+    @pytest.fixture
+    def weight_loader(self):
+        # Patch the superclass's setup to isolate the Llama4 loader's logic
+        return Llama4WeightLoader(vllm_config=MockVllmConfig("test-model"),
+                                  hidden_size=32,
+                                  attn_heads=40,
+                                  num_key_value_heads=8,
+                                  attn_head_dim=128)
+
+    @pytest.mark.parametrize("hf_key, expected_num", [
+        ("language_model.model.layers.15.self_attn.q_proj.weight", 15),
+        ("layers.0.feed_forward.router.weight", 0),
+        ("language_model.model.layers.99.norm.weight", 99),
+        ("language_model.model.norm.weight", None),
+        ("language_model.model.embed_tokens.weight", None),
+    ])
+    def test_get_layer_num(self, weight_loader, hf_key, expected_num):
+        """Tests the private _get_layer_num utility function."""
+        assert weight_loader._get_layer_num(hf_key) == expected_num
+
+    @pytest.mark.parametrize("hf_key, expected_num", [
+        ("language_model.model.layers.10.feed_forward.experts.4.down_proj.weight",
+         4),
+        ("language_model.model.layers.0.feed_forward.experts.0.gate_proj.weight_scale",
+         0),
+        ("language_model.model.layers.5.feed_forward.experts.128.up_proj.weight",
+         128),
+        ("language_model.model.norm.weight", None),
+        ("language_model.model.layers.15.self_attn.q_proj.weight", None),
+    ])
+    def test_get_expert_num(self, weight_loader, hf_key, expected_num):
+        """Tests the private _get_expert_num utility function to extract the expert index."""
+        assert weight_loader._get_expert_num(hf_key) == expected_num
+
+    @pytest.mark.parametrize("hf_key, expected", [
+        ("language_model.model.layers.15.self_attn.q_proj.weight",
+         "layers.15.attn.kernel_q_proj_DNH"),
+        ("language_model.model.layers.0.feed_forward.shared_expert.down_proj.weight",
+         "layers.0.shared_experts.kernel_down_proj_FD"),
+        ("language_model.model.embed_tokens.weight",
+         "embedder.input_embedding_table_VD"),
+        ("language_model.model.norm.weight", "final_norm.scale"),
+        ("language_model.lm_head.weight", "lm_head.input_embedding_table_DV"),
+        ("unmapped.key.name", "unmapped.key.name"),
+    ])
+    def test_map_loaded_to_standardized_name(self, weight_loader, hf_key,
+                                             expected):
+        """Tests the mapping from HuggingFace key names to internal names."""
+        assert weight_loader.map_loaded_to_standardized_name(
+            hf_key) == expected
+
+    def test_load_weights_transformation(self, weight_loader, rng, mesh):
+        """Tests that weights are correctly reshaped, transposed, and loaded."""
+        vllm_config = MockVllmConfig(model_name="llama4-small-test",
+                                     random_weights=False)
+
+        model = Llama4ForCausalLM(vllm_config, rng, mesh)
+
+        # Original weight shape is (vocab_size, hidden_size)
+        original_weight = jnp.ones((128, 32))
+        dummy_weights = [
+            ("language_model.model.embed_tokens.weight", original_weight),
+        ]
+        weight_loader.names_and_weights_generator = dummy_weights
+
+        # Mock get_param to return a mock param with the target shape (vocab_size, hidden_size)
+        mock_param = MockParamLlama4(shape=(128, 32))
+
+        with patch("tpu_inference.models.jax.llama4.get_param", return_value=mock_param), \
+                patch("tpu_inference.models.jax.llama4.shard_put", return_value=jnp.ones(mock_param.value.shape)) as mock_shard_put:
+
+            # This will now pass after the code fix
+            weight_loader.load_weights(model)
+
+            # Assert that shard_put was called with the correctly transposed weight
+            mock_shard_put.assert_called_once()
+
+            # Get the actual array passed to shard_put
+            called_with_weight = mock_shard_put.call_args[0][0]
+
+            # Check if the shape of the array passed to shard_put matches the model's expected shape.
+            assert called_with_weight.shape == mock_param.value.shape
+
+    def test_map_llama4_gate_up_proj(self, weight_loader, rng, mesh):
+        """Tests that gate_up_proj weights are correctly split, reshaped, transposed, and loaded."""
+        # Set up a dummy model and its config
+        model = Llama4ForCausalLM(MockVllmConfig("test-model"), rng, mesh)
+
+        # Create a dummy fused gate_up_proj weight tensor
+        hidden_size = 32
+        intermediate_size_moe = 8192
+        num_local_experts = 2
+        dummy_weight = jnp.ones(
+            (num_local_experts, hidden_size, 2 * intermediate_size_moe))
+
+        # Set up mocks and patches
+        mock_model_params = nnx.state(model)
+        mock_param = MockParamLlama4(shape=(2, hidden_size,
+                                            intermediate_size_moe))
+
+        # Create a dummy WeightLoader and set up the necessary attributes
+        weight_loader.is_verbose = False
+        layer_num = 0
+        weight_loader.names_and_weights_generator = [
+            (f"language_model.model.layers.{layer_num}.feed_forward.experts.gate_up_proj.weight",
+             dummy_weight),
+        ]
+
+        with patch("tpu_inference.models.jax.llama4.get_param", return_value=mock_param), \
+                patch("tpu_inference.models.jax.llama4.shard_put", return_value=jnp.ones(mock_param.value.shape)) as mock_shard_put:
+
+            # Call _map_llama4_gate_up_proj directly
+            weight_loader._map_llama4_gate_up_proj(
+                model, mock_model_params,
+                f"language_model.model.layers.{layer_num}.feed_forward.experts.gate_up_proj.weight",
+                dummy_weight)
+            # Check if shard_put was called the correct number of times and with the correct weight shapes
+            assert mock_shard_put.call_count == 2
+            # call_args_list gives you a list of all the calls with their arguments.
+            for call in mock_shard_put.call_args_list:
+                assert call[0][0].shape == (num_local_experts, 32, 8192)
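The mesh fixture above illustrates the pattern these new tests rely on: all local devices are placed on the 'data' axis and the remaining named axes are given size 1, so the model's sharding annotations resolve even on a single host. A minimal standalone sketch of that pattern (not part of the diff; it uses only public jax/numpy APIs and runs on any backend, including CPU):

import jax
import numpy as np
from jax.sharding import Mesh

# Every local device goes on the 'data' axis; the other axes have size 1.
devices = np.array(jax.local_devices()).reshape((-1, 1, 1, 1))
with Mesh(devices, axis_names=('data', 'attn_dp', 'model', 'expert')) as mesh:
    print(dict(mesh.shape))  # e.g. {'data': 1, 'attn_dp': 1, 'model': 1, 'expert': 1}

Keeping the extra axes at size 1 lets the same parameter shardings scale up unchanged when more devices are available.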
tests/models/jax/test_llama_eagle3.py (new file)
@@ -0,0 +1,197 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from vllm.config import ModelConfig
+
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.models.jax.llama_eagle3 import (Eagle3LlamaDecoderLayer,
+                                                   EagleLlama3ForCausalLM)
+from tpu_inference.runner.kv_cache import create_kv_caches
+
+
+class MockSpeculativeConfig:
+
+    def __init__(self):
+        self.num_speculative_tokens = 3
+        self.method = "eagle3"
+        self.draft_model_config = None
+
+
+class MockVllmConfig:
+
+    def __init__(self, model: str, draft_model: str, kv_cache_dtype):
+        self.model_config = ModelConfig(model)
+        self.model_config.dtype = jnp.bfloat16
+        self.load_config = MagicMock()
+        self.load_config.download_dir = None
+        self.speculative_config = MockSpeculativeConfig()
+        self.speculative_config.draft_model_config = ModelConfig(
+            draft_model,
+            dtype="bfloat16",
+            max_model_len=2048,
+            skip_tokenizer_init=True,
+            trust_remote_code=True)
+        self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
+
+
+@pytest.fixture
+def mock_vllm_config() -> MockVllmConfig:
+    return MockVllmConfig(model="meta-llama/Meta-Llama-3-8B-Instruct",
+                          draft_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
+                          kv_cache_dtype="auto")
+
+
+@pytest.fixture(scope="module")
+def mesh():
+    """Creates a mesh with 1 device."""
+    if not jax.devices():
+        pytest.skip("No JAX devices available for mesh creation.")
+
+    devices = np.array(jax.local_devices()[:1])
+    device_mesh = devices.reshape((1, 1, -1))
+
+    with Mesh(device_mesh, axis_names=('data', 'attn_dp', 'model')) as m:
+        yield m
+
+
+@pytest.fixture
+def mock_model_inputs(mock_vllm_config: MockVllmConfig):
+    """Provides mock inputs for the EagleLlama3 model."""
+    batch_size = 2
+    seq_len = 16
+    target_hidden_size = mock_vllm_config.model_config.get_hidden_size()
+
+    input_ids = jnp.ones((batch_size * seq_len, ), dtype=jnp.int32)
+    hidden_states = jnp.ones((batch_size * seq_len, target_hidden_size),
+                             dtype=jnp.bfloat16)
+    attention_metadata = AttentionMetadata(
+        input_positions=jnp.arange(batch_size * seq_len, dtype=jnp.int32),
+        block_tables=jnp.zeros((batch_size, 1), dtype=jnp.int32).reshape(-1),
+        seq_lens=jnp.full((batch_size, ), seq_len, dtype=jnp.int32),
+        query_start_loc=jnp.arange(0, (batch_size + 1) * seq_len,
+                                   seq_len,
+                                   dtype=jnp.int32),
+        request_distribution=jnp.array([0, 0, batch_size], dtype=jnp.int32),
+    )
+    return input_ids, hidden_states, attention_metadata
+
+
+@pytest.fixture
+def rng() -> PRNGKey:
+    """Provides a reusable JAX PRNGKey."""
+    return jax.random.PRNGKey(42)
+
+
+class TestEagleLlama3ForCausalLM:
+    """Tests for the EagleLlama3ForCausalLM model."""
+
+    def test_eagle3_decoder_layer_init(self, mock_vllm_config: MockVllmConfig,
+                                       rng: PRNGKey, mesh: Mesh):
+        """Tests the initialization of the Eagle3LlamaDecoderLayer."""
+        hf_config = mock_vllm_config.speculative_config.draft_model_config.hf_config
+        dtype = jnp.bfloat16
+        rngs = nnx.Rngs(rng)
+
+        layer = Eagle3LlamaDecoderLayer(
+            hf_config,
+            dtype,
+            rngs,
+            mesh,
+            kv_cache_dtype=mock_vllm_config.cache_config.cache_dtype)
+
+        # Check if projection layers are overridden with correct input size
+        original_hidden_size = hf_config.hidden_size
+        expected_input_size = 2 * original_hidden_size
+
+        assert layer.self_attn.q_proj.kernel.value.shape[
+            0] == expected_input_size
+        assert layer.self_attn.k_proj.kernel.value.shape[
+            0] == expected_input_size
+        assert layer.self_attn.v_proj.kernel.value.shape[
+            0] == expected_input_size
+        assert isinstance(layer.hidden_norm, nnx.RMSNorm)
+
+    @pytest.mark.parametrize("mock_vllm_config", [
+        MockVllmConfig("meta-llama/Meta-Llama-3-8B-Instruct",
+                       "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "auto"),
+        MockVllmConfig("meta-llama/Meta-Llama-3-8B-Instruct",
+                       "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "fp8"),
+    ])
+    def test_forward_pass(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
+                          mesh: Mesh, mock_model_inputs):
+        """Tests the forward pass of the EagleLlama3ForCausalLM model."""
+
+        draft_model_config = mock_vllm_config.speculative_config.draft_model_config
+        hf_config = draft_model_config.hf_config
+        model = EagleLlama3ForCausalLM(mock_vllm_config, rng, mesh)
+
+        input_ids, hidden_states, attention_metadata = mock_model_inputs
+
+        kv_caches = create_kv_caches(
+            num_blocks=4,
+            block_size=16,
+            num_kv_heads=hf_config.num_key_value_heads,
+            head_size=hf_config.hidden_size // hf_config.num_attention_heads,
+            mesh=mesh,
+            layer_names=["layer"] * hf_config.num_hidden_layers,
+            cache_dtype=jnp.float8_e4m3fn
+            if mock_vllm_config.cache_config.cache_dtype == "fp8" else
+            jnp.bfloat16)
+
+        _, output_hidden_states, aux_hidden_states = model(
+            kv_caches, input_ids, hidden_states, attention_metadata)
+
+        logits = model.compute_logits(output_hidden_states)
+
+        target_model_config = mock_vllm_config.model_config
+
+        assert output_hidden_states.shape == (
+            input_ids.shape[0], draft_model_config.get_hidden_size())
+        assert logits.shape == (input_ids.shape[0],
+                                target_model_config.get_vocab_size())
+        assert len(aux_hidden_states) == 1
+        assert aux_hidden_states[0].shape == output_hidden_states.shape
+
+    @patch("tpu_inference.models.jax.llama_eagle3.load_hf_weights")
+    def test_load_weights(self, mock_load_hf_weights: MagicMock,
+                          mock_vllm_config: MockVllmConfig, rng: PRNGKey,
+                          mesh: Mesh):
+        """Tests that the load_weights function is called correctly."""
+        model = EagleLlama3ForCausalLM(mock_vllm_config, rng, mesh)
+        model.load_weights(rng)
+
+        mock_load_hf_weights.assert_called_once()
+        call_args = mock_load_hf_weights.call_args.kwargs
+
+        assert call_args["vllm_config"] is mock_vllm_config
+        assert call_args["model"] is model
+        assert call_args["mesh"] is mesh
+        assert call_args["is_draft_model"] is True
+
+        metadata_map = call_args["metadata_map"]
+        assert "midlayer.hidden_norm" in metadata_map.name_map
+        assert "lm_head" in metadata_map.name_map
+        assert "d2t" in metadata_map.name_map
+        assert "q_proj" in metadata_map.reshape_map
+        assert metadata_map.reshape_map["q_proj"][-1] == (
+            2 * mock_vllm_config.model_config.get_hidden_size())
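The expected_input_size = 2 * original_hidden_size assertions reflect the EAGLE-style draft layout, in which the draft decoder layer consumes the draft token embedding concatenated with the target model's hidden state, so its q/k/v kernels have a leading dimension of twice the hidden size. A small shape-only sketch of that relationship (an illustration based on the test's assertions, not code from the diff; all names here are made up):

import jax.numpy as jnp

hidden_size = 4096
num_tokens = 32
token_embeds = jnp.zeros((num_tokens, hidden_size), dtype=jnp.bfloat16)   # draft token embeddings
target_hidden = jnp.zeros((num_tokens, hidden_size), dtype=jnp.bfloat16)  # from the target model

# The concatenated features feed the draft layer's projections,
# hence kernels shaped (2 * hidden_size, ...).
x = jnp.concatenate([token_embeds, target_hidden], axis=-1)
assert x.shape == (num_tokens, 2 * hidden_size)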