tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/models/jax/test_deepseek_v3.py (new file)
@@ -0,0 +1,401 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import jax
import jax.numpy as jnp
import numpy as np
import pytest
import torch
from flax import nnx
from jax.sharding import Mesh
from vllm.config import ModelConfig

# Assuming the model file is named deepseek_v3.py
from tpu_inference.models.jax.deepseek_v3 import (DeepSeekV3,
                                                  DeepSeekV3WeightLoader)


class MockVariable:
    """Mocks an nnx.Variable or a QArray structure."""

    def __init__(self, shape, dtype=jnp.bfloat16, sharding=None):
        self.value = jnp.zeros(shape, dtype=dtype)
        self.sharding = sharding or (None, ) * len(shape)
        self.nbytes = self.value.nbytes
        # Handle the QArray structure used in the loader
        self.array = SimpleNamespace(
            qvalue=self,
            scale=SimpleNamespace(
                value=jnp.ones((1, )),
                nbytes=4,
                sharding=None,
                addressable_shards=[SimpleNamespace(data=jnp.ones((1, )))]))
        self.addressable_shards = [SimpleNamespace(data=self.value)]


class MockVllmConfig:
    """Mock VllmConfig for DeepSeekV3."""

    def __init__(self,
                 model_name: str = "deepseek-ai/DeepSeek-V3",
                 use_mla: bool = False):
        self.model_config = MagicMock(spec=ModelConfig)
        self.model_config.model = model_name
        self.model_config.use_mla = use_mla

        # DeepSeek V3 specific config
        hf_config = MagicMock()
        hf_config.num_hidden_layers = 1  # Small for testing
        hf_config.num_nextn_predict_layers = 1
        self.model_config.hf_config = hf_config

        self.load_config = MagicMock()
        self.load_config.download_dir = None

        self.cache_config = MagicMock()
        self.cache_config.cache_dtype = "auto"

        self.additional_config = {
            "random_weights": False,
            "sparse_matmul": False,
            "is_verbose": True
        }


@pytest.fixture(scope="module")
def mesh():
    if not jax.devices():
        pytest.skip("No JAX devices available.")
    devices = np.array(jax.local_devices())
    num_devices = len(devices)
    device_mesh = devices.reshape((num_devices, 1, 1, 1))
    # Simplify axis names for testing
    with Mesh(device_mesh,
              axis_names=('data', 'attn_dp', 'model', 'expert')) as m:
        yield m


@pytest.fixture
def rng():
    return jax.random.PRNGKey(0)


@pytest.fixture
def mock_config():
    return MockVllmConfig()


class TestDeepSeekV3:

    def test_init(self, mock_config, rng, mesh):
        """Tests if the model initializes with the correct hierarchy."""
        model = DeepSeekV3(mock_config, rng, mesh)
        assert len(model.layers) == 3  # num_layers from mock
        assert isinstance(model.embedder, nnx.Module)
        assert model.vllm_config.model_config.hf_config.num_hidden_layers == 1

    def test_random_weights(self, mock_config, rng, mesh):
        """Tests that force_random_weights initializes non-zero weights."""
        with jax.set_mesh(mesh):
            model = DeepSeekV3(mock_config,
                               rng,
                               mesh,
                               force_random_weights=True)
            # Check embedding
            weight = model.embedder.input_embedding_table_VD.value
            assert jnp.std(weight) > 0
            # Check a layer norm (should be 1s usually, but check existence)
            assert model.final_norm.scale.value.shape == (7168, )

    @patch("tpu_inference.models.jax.deepseek_v3.DeepSeekV3WeightLoader")
    def test_load_weights_called(self, mock_loader_cls, mock_config, rng,
                                 mesh):
        model = DeepSeekV3(mock_config, rng, mesh)
        mock_loader_instance = mock_loader_cls.return_value

        model.load_weights(rng)

        mock_loader_instance.load_weights.assert_called_once_with(model)


class TestDeepSeekV3WeightLoader:

    @pytest.fixture
    def loader(self, mock_config):
        # We need to mock the generator so it doesn't try to download files
        with patch(
                "tpu_inference.models.jax.deepseek_v3.model_weights_generator",
                return_value=[]):
            return DeepSeekV3WeightLoader(vllm_config=mock_config,
                                          num_layers=2,
                                          hidden_size=7168,
                                          q_lora_rank=1536,
                                          kv_lora_rank=512,
                                          attn_heads=128,
                                          qk_nope_head_dim=128,
                                          qk_rope_head_dim=64,
                                          v_head_dim=128,
                                          num_local_experts=256,
                                          model_dtype=jnp.bfloat16)

    @pytest.mark.parametrize("loaded_key, expected_mapped", [
        ("model.embed_tokens.weight", "embedder.input_embedding_table_VD"),
        ("model.layers.0.self_attn.q_a_proj.weight",
         "layers.0.attn.kernel_q_down_proj_DA"),
        ("model.layers.5.mlp.experts.10.gate_proj.weight",
         "layers.5.custom_module.kernel_gating_EDF"),
        ("model.layers.1.mlp.shared_experts.down_proj.weight",
         "layers.1.shared_experts.kernel_down_proj_FD"),
        ("model.norm.weight", "final_norm.scale"),
    ])
    def test_key_mapping(self, loader, loaded_key, expected_mapped):
        assert loader.map_loaded_to_standardized_name(
            loaded_key) == expected_mapped

    def test_transpose_params(self, loader):
        # Test a standard MLP transpose (1, 0)
        dummy_weight = jnp.ones((100, 200))
        transposed = loader._transpose_params("mlp.down_proj", dummy_weight)
        assert transposed.shape == (200, 100)

        # Test MLA kernel transpose (2, 0, 1)
        dummy_mla = jnp.ones((10, 20, 30))
        transposed_mla = loader._transpose_params("k_b_proj", dummy_mla)
        assert transposed_mla.shape == (30, 10, 20)

    def test_moe_stacking_logic(self, loader):
        """Tests that individual expert weights are collected and stacked correctly."""
        weights_dict = {}
        layer_num = "0"
        loader.num_routed_experts = 4  # Small for test

        # Simulate loading 4 experts
        for i in range(4):
            name = f"model.layers.0.mlp.experts.{i}.gate_proj.weight"
            weight = torch.ones((10, 20)) * i
            result = loader._process_moe_weights(name, weight, weights_dict)

            if i < 3:
                assert result is None
                assert weights_dict[layer_num][1] == i + 1
            else:
                # On the last expert, it should return stacked tensor
                assert result is not None
                assert result.shape == (4, 10, 20)
                assert layer_num not in weights_dict  # Should be cleaned up

    def test_mla_kernel_weight_splitting(self, loader, mesh):
        """Tests that kv_b_proj is split into k_b_proj and v_b_proj for MLA kernel."""
        loader.use_mla_kernel = True
        loader.attn_heads = 2
        loader.qk_nope_head_dim = 4
        loader.v_head_dim = 4
        loader.kv_lora_rank = 8

        # Total rows = heads * (nope_dim + v_dim) = 2 * (4 + 4) = 16
        # Cols = kv_lora_rank = 8
        kv_b_proj_weight = torch.randn((16, 8))

        # Mocking the load_individual_weight to capture what gets passed
        with patch.object(loader,
                          '_load_individual_weight',
                          return_value=(0, 0)):
            model_mock = MagicMock()
            model_mock.mesh = mesh

            # Simulate the splitting logic in the loader
            weight_reshaped = kv_b_proj_weight.view(2, 4 + 4, 8)
            k_weight = weight_reshaped[:, :4, :]
            v_weight = weight_reshaped[:, 4:, :]

            # Verify shapes of split parts
            assert k_weight.shape == (2, 4, 8)
            assert v_weight.shape == (2, 4, 8)

    def test_load_individual_weight_with_mxfp4(self, loader, mesh):
        """Tests the logic for unpacking MXFP4 weights."""
        name = "layers.0.attn.kernel_q_down_proj_DA"
        # Mocking torch tensor as uint8 (packed fp4)
        expected_weight_shape = (128, 128)  # Unpacked
        expected_scale_shape = (128, 1)

        weight = torch.zeros(expected_weight_shape, dtype=torch.uint8)
        scale = torch.ones(expected_scale_shape, dtype=torch.float32)

        # Mock model parameters
        mock_var = MockVariable(
            (128, 128),
            dtype=jnp.float4_e2m1fn,
            sharding=(None, ('attn_dp', 'model',
                             'expert')))  # Unpacked shape (64 * 2)
        mock_params = {
            "layers": {
                "0": {
                    "attn": {
                        "kernel_q_down_proj_DA": mock_var
                    }
                }
            }
        }

        with patch("tpu_inference.models.jax.deepseek_v3.get_param", return_value=mock_var), \
             patch("tpu_inference.models.jax.deepseek_v3.u8_unpack_e2m1") as mock_unpack, \
             patch("jax.make_array_from_callback") as mock_make_array:

            def side_effect_router(shape, *args, **kwargs):
                if shape == expected_scale_shape:
                    # Return FP32 for the scale call
                    return jnp.ones(shape, dtype=jnp.float32)
                elif shape == expected_weight_shape:
                    # Return FP4 for the weight call
                    return jnp.zeros(shape, dtype=jnp.float4_e2m1fn)
                return jnp.zeros(shape)  # Fallback

            mock_make_array.side_effect = side_effect_router
            mock_unpack.return_value = torch.zeros(expected_weight_shape)

            loader._load_individual_weight(name,
                                           weight,
                                           mock_params,
                                           mesh,
                                           scale=scale)

            mock_unpack.assert_called_once()
            (actual_arg, ), _ = mock_unpack.call_args
            # The implementation converts the torch weight to a JAX array
            expected_arg = jnp.array(weight.cpu().numpy())
            assert jnp.array_equal(actual_arg, expected_arg).item()
            assert mock_make_array.called

    def test_load_weights_full_flow(self, loader, mesh):
        """Integrative test for the load_weights loop."""
        model = MagicMock(spec=nnx.Module)
        model.mesh = mesh

        # Setup generator to return one normal weight
        loader.names_and_weights_generator = [("model.embed_tokens.weight",
                                               torch.ones((10, 10)))]

        mock_var = MockVariable((10, 10))

        with patch("tpu_inference.models.jax.deepseek_v3.nnx.state"), \
             patch("tpu_inference.models.jax.deepseek_v3.get_param", return_value=mock_var), \
             patch("tpu_inference.models.jax.deepseek_v3.nnx.update"), \
             patch.object(loader, '_load_individual_weight', return_value=(1.0, 0.5)):

            loader.load_weights(model)
            # Verify verbose logging worked if enabled
            assert loader.is_verbose is True

    def test_load_individual_weight_unpacked(self, loader, mesh):
        """
        Tests the logic for loading 'unpacked' weights (e.g., standard FP8).
        This verifies the branch that uses DTYPE_VIEW_MAP for raw memory conversion.
        """
        name = "layers.0.attn.kernel_q_down_proj_DA"

        # 1. Setup a standard 'unpacked' FP8 torch tensor
        # DeepSeek V3 weights are often float8_e4m3fn
        weight_shape = (128, 128)
        weight = torch.randn(weight_shape).to(torch.float8_e4m3fn)

        # 2. Mock model parameters to expect jnp.float8_e4m3fn
        # We reuse the MockVariable helper but specify the dtype
        mock_var = MockVariable(weight_shape, dtype=jnp.float8_e4m3fn)
        mock_params = {
            "layers": {
                "0": {
                    "attn": {
                        "kernel_q_down_proj_DA": mock_var
                    }
                }
            }
        }

        # 3. Patch the necessary JAX/Utility functions
        with patch("tpu_inference.models.jax.deepseek_v3.get_param", return_value=mock_var), \
             patch("tpu_inference.models.jax.deepseek_v3.u8_unpack_e2m1") as mock_unpack, \
             patch("jax.make_array_from_callback") as mock_make_array:

            # Mock the JAX array creation to return a dummy
            mock_make_array.return_value = jnp.zeros(weight_shape,
                                                     dtype=jnp.float8_e4m3fn)

            # Execute the loader method
            loader._load_individual_weight(name,
                                           weight,
                                           mock_params,
                                           mesh,
                                           scale=None)

            # VERIFICATIONS:
            # - u8_unpack_e2m1 should NOT be called for standard FP8 (only for packed uint8 + scale)
            mock_unpack.assert_not_called()

            # - make_array_from_callback should be called with the correct shape and sharding
            # The first argument to make_array_from_callback is the shape
            assert mock_make_array.call_args[0][0] == weight_shape

            # - Verify the model weight value was updated (even if with our dummy)
            assert mock_var.value.dtype == jnp.float8_e4m3fn

    def test_load_individual_weight_with_scale(self, loader, mesh):
        """
        Tests loading an unpacked weight that also has a quantization scale.
        """
        name = "layers.0.custom_module.kernel_gating_DF"
        weight_shape = (64, 128)
        scale_shape = (64, 1)

        # Use BF16 for this test to verify DTYPE_VIEW_MAP handles multiple types
        weight = torch.randn(weight_shape).to(torch.bfloat16)
        scale = torch.ones(scale_shape, dtype=torch.float32)

        mock_var = MockVariable(weight_shape, dtype=jnp.bfloat16)
        mock_params = {
            "layers": {
                "0": {
                    "custom_module": {
                        "kernel_gating_DF": mock_var
                    }
                }
            }
        }

        with patch("tpu_inference.models.jax.deepseek_v3.get_param", return_value=mock_var), \
             patch("jax.make_array_from_callback") as mock_make_array:

            def side_effect_router(shape, *args, **kwargs):
                if shape == scale_shape:
                    # Return FP32 for the scale call
                    return jnp.ones(shape, dtype=jnp.float32)
                elif shape == weight_shape:
                    # Return BF16 for the weight call
                    return jnp.zeros(shape, dtype=jnp.bfloat16)
                return jnp.zeros(shape)  # Fallback

            mock_make_array.side_effect = side_effect_router

            loader._load_individual_weight(name,
                                           weight,
                                           mock_params,
                                           mesh,
                                           scale=scale)

            # Verify the scale was applied to the MockVariable's internal QArray structure
            # (In the model code: base_model_weight.array.scale.value = maybe_sharded_scale)
            assert mock_var.array.scale.value is not None

tests/models/jax/test_llama3.py (new file)
@@ -0,0 +1,184 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock, patch

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh
from vllm.config import ModelConfig

from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.models.jax.llama3 import LlamaForCausalLM
from tpu_inference.runner.kv_cache import create_kv_caches


class MockVllmConfig:

    def __init__(self, model: str, kv_cache_dtype: str):
        self.model_config = ModelConfig(model)
        self.model_config.dtype = jnp.bfloat16
        self.load_config = MagicMock()
        self.load_config.download_dir = None
        self.speculative_config = None
        self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)


@pytest.fixture(scope="module")
def mesh():
    """
    Creates a mesh with 1 device.
    """
    if not jax.devices():
        pytest.skip("No JAX devices available for mesh creation.")

    devices = np.array(jax.local_devices()[:1])
    num_devices = len(devices)
    assert num_devices == 1
    device_mesh = devices.reshape((num_devices, 1, 1, 1))

    with Mesh(device_mesh,
              axis_names=('data', 'attn_dp', 'expert', 'model')) as m:
        yield m


@pytest.fixture
def mock_model_inputs():
    num_tokens = 8
    num_reqs = 1
    max_num_blocks_per_req = 4
    input_ids = jnp.ones((num_tokens, ), dtype=jnp.int32)
    positions = jnp.ones((num_tokens, ), dtype=jnp.int32)
    block_tables = jnp.zeros((num_reqs, max_num_blocks_per_req),
                             dtype=jnp.int32).reshape(-1)
    seq_lens = jnp.ones((num_reqs, ), dtype=jnp.int32)
    query_start_loc = jnp.ones((num_reqs + 1, ), dtype=jnp.int32)
    request_distribution = jnp.array([0, 0, 0], dtype=jnp.int32)

    attention_metadata = AttentionMetadata(
        input_positions=positions,
        block_tables=block_tables,
        seq_lens=seq_lens,
        query_start_loc=query_start_loc,
        request_distribution=request_distribution,
    )
    indices_do_sample = jnp.ones((num_reqs, ), dtype=jnp.int32)

    return (input_ids, attention_metadata, indices_do_sample)


@pytest.fixture
def rng() -> PRNGKey:
    """Provides a reusable JAX PRNGKey."""
    return jax.random.PRNGKey(42)


@pytest.fixture(autouse=True)
def mock_get_pp_group():
    mock_pp = MagicMock(is_first_rank=True,
                        is_last_rank=True,
                        rank_in_group=0,
                        world_size=1)
    with patch("tpu_inference.models.jax.llama3.get_pp_group",
               return_value=mock_pp), patch(
                   "tpu_inference.layers.jax.pp_utils.get_pp_group",
                   return_value=mock_pp):
        yield


class TestLlamaForCausalLM:
    """Tests for the main LlamaForCausalLM model class."""

    @pytest.mark.parametrize("mock_vllm_config", [
        MockVllmConfig("meta-llama/Llama-3.2-1B", "auto"),
        MockVllmConfig("meta-llama/Llama-3.2-1B", "fp8")
    ])
    def test_llama32_1b(self, mock_vllm_config, rng, mesh, mock_model_inputs):
        """Tests model init and model forward for the Llama 3.2 1B model variant."""

        # Test model init
        model = LlamaForCausalLM(mock_vllm_config, rng, mesh)

        model_config = mock_vllm_config.model_config
        hf_config = model_config.hf_config

        assert model.mesh.shape == {
            "data": 1,
            "attn_dp": 1,
            "expert": 1,
            "model": 1
        }

        layers = model.model.layers
        assert len(layers) == hf_config.num_hidden_layers
        assert isinstance(model.rng, nnx.Rngs)
        assert model.model.lm_head == model.model.embed.embedding

        attn = layers[0].self_attn
        hidden_size = hf_config.hidden_size
        num_heads = hf_config.num_attention_heads
        num_kv_heads = hf_config.num_key_value_heads
        rope_theta = hf_config.rope_theta
        head_dim = hf_config.head_dim
        intermediate_size = hf_config.intermediate_size

        assert attn.hidden_size == hidden_size
        assert attn.num_heads == num_heads
        assert attn.num_kv_heads == num_kv_heads
        assert attn.rope_theta == rope_theta
        assert attn.head_dim_original == head_dim
        assert attn.head_dim == head_dim
        assert attn.q_proj.kernel.shape == (hidden_size, num_heads, head_dim)
        assert attn.k_proj.kernel.shape == (hidden_size, num_kv_heads,
                                            head_dim)
        assert attn.v_proj.kernel.shape == (hidden_size, num_kv_heads,
                                            head_dim)
        assert attn.o_proj.kernel.shape == (num_heads, head_dim, hidden_size)

        mlp = layers[0].mlp
        assert mlp.gate_proj.kernel.shape == (hidden_size, intermediate_size)
        assert mlp.up_proj.kernel.shape == (hidden_size, intermediate_size)
        assert mlp.down_proj.kernel.shape == (intermediate_size, hidden_size)

        # Test model load
        model.load_weights(rng)

        # Test model forward
        kv_caches = create_kv_caches(
            num_blocks=4,
            block_size=32,
            num_kv_heads=num_kv_heads,
            head_size=head_dim,
            mesh=mesh,
            layer_names=["layer"] * hf_config.num_hidden_layers,
            cache_dtype=jnp.float8_e4m3fn
            if mock_vllm_config.cache_config.cache_dtype == "fp8" else
            jnp.bfloat16)
        # 1 seq with 8 tokens
        input_ids, attention_metadata, indices_do_sample = mock_model_inputs
        kv_caches, hidden_states, aux_hidden_states = model(
            kv_caches, input_ids, attention_metadata, None, None, None, None,
            None, True, True)
        assert hidden_states.shape == (8, hidden_size)
        assert len(aux_hidden_states) == 0

        hidden_states = hidden_states[indices_do_sample]
        assert hidden_states.shape == (1, hidden_size)

        logits = model.compute_logits(hidden_states)
        assert logits.shape == (1, hf_config.vocab_size)