tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/experimental/test_llama3_jax_stashed.py
@@ -0,0 +1,208 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+
+from tpu_inference.experimental.llama3_jax_stashed import (Llama3WeightLoader,
+                                                           LlamaForCausalLM)
+
+
+class MockParam:
+    """A mock for a parameter used in the Llama model."""
+
+    def __init__(self, shape=(32, 128)):
+        self.value = SimpleNamespace(shape=shape)
+        # The sharding spec is accessed during weight loading
+        self.sharding = SimpleNamespace(spec=None)
+
+    # Allow the mock parameter's value to be updated
+    def __setattr__(self, name, value):
+        if name == "value":
+            self.__dict__[name] = value
+        else:
+            super().__setattr__(name, value)
+
+
+class MockVllmConfig:
+    """A mock VllmConfig sufficient for testing the Llama3 model."""
+
+    def __init__(self,
+                 model_name: str,
+                 random_weights: bool = False,
+                 tensor_parallelism: int = 1):
+        self.model_config = SimpleNamespace(model=model_name,
+                                            dtype="bfloat16",
+                                            hf_overrides={},
+                                            override_generation_config={})
+        self.load_config = MagicMock()
+        self.additional_config = {
+            "random_weights": random_weights,
+            "sharding": {
+                "sharding_strategy": {
+                    "tensor_parallelism": tensor_parallelism
+                }
+            }
+        }
+
+        # NOTE (jacobplatin): we could add a quantized KV cache test, but
+        # we'll skip it for now.
+        self.cache_config = MagicMock(cache_dtype="auto")
+
+
+@pytest.fixture(scope="module")
+def mesh():
+    """
+    Creates a mesh with all required axes for testing.
+    FIX: The sharding logic expects 'data', 'model', and 'expert' axes.
+    This creates a 3D mesh to satisfy the sharding rules, even on a single device.
+    """
+    if not jax.devices():
+        pytest.skip("No JAX devices available for mesh creation.")
+
+    devices = np.array(jax.local_devices())
+    # Reshape devices into a 3D array to name 3 axes: data, model, and expert.
+    # The 'model' and 'expert' axes will have a size of 1.
+    num_devices = len(devices)
+    device_mesh = devices.reshape((num_devices, 1, 1))
+
+    with Mesh(device_mesh, axis_names=('data', 'model', 'expert')) as m:
+        yield m
+
+
+@pytest.fixture
+def rng() -> PRNGKey:
+    """Provides a reusable JAX PRNGKey."""
+    return jax.random.PRNGKey(42)
+
+
+@pytest.fixture
+def mock_vllm_config_8b() -> MockVllmConfig:
+    return MockVllmConfig(model_name="meta-llama/Llama-3-8B")
+
+
+@pytest.fixture
+def mock_vllm_config_70b() -> MockVllmConfig:
+    return MockVllmConfig(model_name="meta-llama/Llama-3-70B-Instruct")
+
+
+@pytest.fixture
+def mock_vllm_config_unknown() -> MockVllmConfig:
+    return MockVllmConfig(model_name="some-other-model")
+
+
+# --- Test Cases ---
+
+
+class TestLlamaForCausalLM:
+    """Tests for the main LlamaForCausalLM model class."""
+
+    def test_init_8b_variant(self, mock_vllm_config_8b, rng, mesh):
+        """Tests correct parameter detection for the 8B model variant."""
+        model = LlamaForCausalLM(mock_vllm_config_8b, rng, mesh)
+        assert model.hidden_size == 4096
+        assert "8b" in model.vllm_config.model_config.model.lower()
+
+    def test_init_70b_variant(self, mock_vllm_config_70b, rng, mesh):
+        """Tests correct parameter detection for the 70B model variant."""
+        model = nnx.eval_shape(
+            lambda: LlamaForCausalLM(mock_vllm_config_70b, rng, mesh))
+        assert model.hidden_size == 8192
+        assert "70b" in model.vllm_config.model_config.model.lower()
+
+    def test_init_unknown_variant_raises_error(self, mock_vllm_config_unknown,
+                                               rng, mesh):
+        """Tests that an unknown model variant raises a ValueError."""
+        with pytest.raises(ValueError,
+                           match="Could not determine Llama3 variant"):
+            LlamaForCausalLM(mock_vllm_config_unknown, rng, mesh)
+
+    def test_create_model_with_random_weights(self, mock_vllm_config_8b, rng,
+                                              mesh):
+        """
+        Tests that random weight initialization creates concrete, non-zero-variance arrays.
+        """
+        with jax.set_mesh(mesh):
+            model = LlamaForCausalLM(vllm_config=mock_vllm_config_8b,
+                                     rng=rng,
+                                     mesh=mesh,
+                                     force_random_weights=True)
+
+            embedding_weight = model.embedder.input_embedding_table_VD.value
+            attention_q_kernel = model.layers[0].attn.kernel_q_proj_DNH.value
+            final_norm_scale = model.final_norm.scale.value
+
+            assert isinstance(embedding_weight, jax.Array)
+            assert isinstance(attention_q_kernel, jax.Array)
+            assert isinstance(final_norm_scale, jax.Array)
+
+            assert jnp.std(embedding_weight) > 0
+            assert jnp.std(attention_q_kernel) > 0
+
+            assert jnp.all(final_norm_scale == 1.0)
+
+    @patch("tpu_inference.experimental.llama3_jax_stashed.Llama3WeightLoader")
+    def test_load_weights_called_correctly(self, mock_loader_cls, rng, mesh):
+        """Tests that the weight loader is called correctly for checkpoint loading."""
+        vllm_config = MockVllmConfig(model_name="llama3-8b",
+                                     random_weights=False)
+        model = LlamaForCausalLM(vllm_config, rng, mesh)
+
+        mock_loader_instance = MagicMock()
+        mock_loader_cls.return_value = mock_loader_instance
+        model.load_weights(rng, cache_dir="/tmp/cache")
+        mock_loader_cls.assert_called_once_with(vllm_config=vllm_config,
+                                                hidden_size=4096,
+                                                attn_heads=32,
+                                                num_key_value_heads=8,
+                                                attn_head_dim=128)
+        mock_loader_instance.load_weights.assert_called_once_with(model)
+
+
+class TestLlama3WeightLoader:
+    """Tests for the Llama3WeightLoader class."""
+
+    @pytest.fixture
+    def weight_loader(self):
+        # Patch the superclass's setup to isolate the Llama3 loader's logic
+        return Llama3WeightLoader(vllm_config=MockVllmConfig("test-model"),
+                                  hidden_size=32,
+                                  attn_heads=4,
+                                  num_key_value_heads=2,
+                                  attn_head_dim=8)
+
+    def test_load_weights_transformation(self, weight_loader, rng, mesh):
+        """Tests that weights are correctly reshaped, transposed, and loaded."""
+        vllm_config = MockVllmConfig("llama3-8b-small-test",
+                                     random_weights=False)
+
+        # Create a model instance but override its config for the test.
+        model = LlamaForCausalLM(vllm_config, rng, mesh)
+
+        with patch(
+                "tpu_inference.experimental.llama3_jax_stashed.load_hf_weights"
+        ) as mock_load:
+            # This will now pass after the code fix
+            weight_loader.load_weights(model)
+
+            # Assert that shard_put was called with the correctly transposed weight
+            mock_load.assert_called_once()
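
Reviewer note: the 70B test above wraps model construction in nnx.eval_shape, which traces __init__ abstractly so the assertions run without materializing 70B-sized arrays. A minimal sketch of that pattern, assuming a recent flax.nnx (the Tiny module here is hypothetical, not part of this package):

import jax
from flax import nnx

class Tiny(nnx.Module):

    def __init__(self, rngs: nnx.Rngs):
        # Would allocate ~256 MB of float32 if created eagerly.
        self.w = nnx.Param(jax.random.normal(rngs.params(), (8192, 8192)))

# eval_shape returns the module with abstract (shape/dtype-only) leaves:
# attributes are inspectable, but no device memory is allocated.
abstract = nnx.eval_shape(lambda: Tiny(nnx.Rngs(0)))
assert abstract.w.value.shape == (8192, 8192)
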
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tests/kernels/collectives/all_gather_matmul_kernel_test.py
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+
+import jax
+import jax.numpy as jnp
+from absl.testing import absltest, parameterized
+from jax._src import test_util as jtu
+
+from tpu_inference import utils
+from tpu_inference.kernels.collectives import all_gather_matmul
+
+jax.config.parse_flags_with_absl()
+
+P = jax.sharding.PartitionSpec
+
+SpongeDir: str | None = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', None)
+
+
+@jtu.with_config(jax_numpy_dtype_promotion='standard')
+class AllGatherMatmulTest(jtu.JaxTestCase):
+
+    @parameterized.product(
+        grid_k=[1, 2, 3],
+        grid_n=[1, 2, 3],
+        rhs_transpose=[True, False],
+    )
+    def test_all_gather_matmul(self, grid_k, grid_n, rhs_transpose):
+        if jax.device_count() != 8:
+            self.skipTest('Not enough devices for test')
+
+        axis_name = 'x'
+        num_devices = jax.device_count()
+        mesh = utils.make_optimized_mesh((num_devices, ), (axis_name, ))
+        bk, bn = 1024, 1024
+        m, k, n = 1024, bk * grid_k, bn * grid_n * num_devices
+
+        # Run the test 10 times to expose race conditions as much as possible.
+        for i in range(10):
+            # Create input data
+            prng_key = jax.random.key(1234 + i)
+            k0, k1 = jax.random.split(prng_key, 2)
+            x = jax.random.normal(k0, (m, k), dtype=jnp.bfloat16)
+            y_shape = (n, k) if rhs_transpose else (k, n)
+            y_sharding = P(axis_name, None) if rhs_transpose else P(
+                None, axis_name)
+            y = jax.random.normal(k1, y_shape, dtype=jnp.bfloat16)
+            sharded_x = jax.device_put(
+                x, jax.sharding.NamedSharding(mesh, P(axis_name, None)))
+            sharded_y = jax.device_put(
+                y, jax.sharding.NamedSharding(mesh, y_sharding))
+
+            # Run the all_gather_matmul function
+            output = all_gather_matmul.all_gather_matmul(
+                sharded_x,
+                sharded_y,
+                mesh,
+                axis_name,
+                bk=bk,
+                bn=bn,
+                rhs_transpose=rhs_transpose,
+            )
+            y_for_dot = sharded_y.T if rhs_transpose else sharded_y
+            expected_output = jnp.dot(sharded_x, y_for_dot)
+            self.assertAllClose(output, expected_output, atol=1e-2, rtol=1e-2)
+
+
+if __name__ == "__main__":
+    absltest.main(testLoader=jtu.JaxTestLoader())
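
Reviewer note: the semantics this test checks can be written as an unfused two-step reference: all-gather the row-sharded lhs across the mesh axis, then a local matmul against the column-sharded rhs. A minimal sketch of that reference for the non-transposed case (an illustration only, not the package's kernel; assumes a JAX version that exposes jax.shard_map publicly):

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P


def unfused_all_gather_matmul(x, y, mesh, axis_name):
    # x: (m, k) sharded on rows; y: (k, n) sharded on columns.
    @partial(jax.shard_map, mesh=mesh,
             in_specs=(P(axis_name, None), P(None, axis_name)),
             out_specs=P(None, axis_name))
    def f(x_blk, y_blk):
        # Materialize the full lhs on every device, then matmul against
        # the local column block of the rhs.
        x_full = jax.lax.all_gather(x_blk, axis_name, axis=0, tiled=True)
        return jnp.dot(x_full, y_blk)

    return f(x, y)
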
tests/kernels/fused_moe_v1_test.py
@@ -0,0 +1,388 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from absl.testing import absltest, parameterized
+from jax._src import test_util as jtu
+from jax.sharding import Mesh
+
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe
+
+jax.config.parse_flags_with_absl()
+
+
+def cdiv(a, b):
+    assert b != 0
+    return (a + b - 1) // b
+
+
+def align_to(x, a):
+    return cdiv(x, a) * a
+
+
+def gen_moe_inputs(
+    dtype,
+    top_k,
+    num_experts,
+    hidden_size,
+    intermediate_size,
+    num_tokens,
+    *,
+    seed=1234,
+    has_bias=False,
+):
+    key = jax.random.key(seed)
+    k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7)
+
+    a = jax.random.normal(k0, (num_tokens, hidden_size),
+                          dtype=jnp.float32).astype(dtype) / 10
+
+    w1 = (jax.random.normal(
+        k1,
+        (num_experts, 2, hidden_size, intermediate_size),
+        dtype=jnp.float32,
+    ) / 10).astype(dtype)
+    w2 = (jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
+                            dtype=jnp.float32) / 10).astype(dtype)
+
+    if has_bias:
+        b1 = (jax.random.normal(k3, (num_experts, 2, intermediate_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+        b2 = (jax.random.normal(k4, (num_experts, hidden_size),
+                                dtype=jnp.float32) / 10).astype(dtype)
+    else:
+        b1 = b2 = None
+
+    gating_output = (
+        jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) +
+        jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(
+            num_tokens, num_experts) / 100)
+
+    # To generate unique top-k!
+    top_k_indices = jax.random.randint(k6, (num_tokens, top_k),
+                                       minval=0,
+                                       maxval=num_experts - 1,
+                                       dtype=jnp.int32)
+
+    one_hot = (jnp.sum(
+        jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32),
+        axis=1,
+    ) * 30)
+
+    gating_output = (gating_output + one_hot).astype(dtype)
+
+    return a, w1, w2, b1, b2, gating_output
+
+
+def sub_channel_quantize(x, quant_dtype, wsz=256):
+    """Quantizes x with sub-channel quantization on the 2nd minor."""
+    if jnp.issubdtype(quant_dtype, jnp.floating):
+        dtype_info = jnp.finfo(quant_dtype)
+    else:
+        dtype_info = jnp.iinfo(quant_dtype)
+    dtype_max = float(dtype_info.max)
+    w_lst, scale_lst = [], []
+    assert len(x.shape) >= 2
+    assert x.shape[-2] % wsz == 0
+    for i in range(0, x.shape[-2], wsz):
+        y = x[..., i:i + wsz, :]
+        abs_max = jnp.abs(y).max(axis=-2, keepdims=True)
+        scale = (abs_max / dtype_max).astype(jnp.float32)
+        w = (y / scale).astype(quant_dtype)
+        w_lst.append(w)
+        scale_lst.append(scale)
+    return jnp.concat(w_lst, axis=-2), jnp.concat(scale_lst, axis=-2)
+
+
+@jtu.with_config(jax_numpy_dtype_promotion="standard")
+class MoEKernelTest(jtu.JaxTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.mesh_devices = sorted(
+            jax.devices(),
+            key=lambda x: (
+                x.coords[0],
+                (-1 if x.coords[0] % 2 else 1) * x.coords[1],
+            ),
+        )
+        self.mesh = Mesh(np.array(self.mesh_devices).reshape(1, -1),
+                         axis_names=("data", "model"))
+
+    def _test_moe(
+        self,
+        dtype,
+        top_k,
+        num_experts,
+        hidden_size,
+        intermediate_size,
+        num_tokens,
+        seed,
+        renormalize_topk_logits,
+        bt,
+        bf,
+        bd1,
+        bd2,
+        btc,
+        bfc,
+        bd1c,
+        bd2c,
+        act_fn="silu",
+        w_dtype=None,
+        subc_quant_wsz=None,
+        has_bias=False,
+        atol=2e-1,
+        rtol=2e-1,
+    ):
+        a, w1, w2, b1, b2, gating_output = gen_moe_inputs(
+            dtype,
+            top_k,
+            num_experts,
+            hidden_size,
+            intermediate_size,
+            num_tokens,
+            seed=seed,
+            has_bias=has_bias,
+        )
+        w1_scale = None
+        w2_scale = None
+        if w_dtype is not None:
+            if subc_quant_wsz is None:
+                subc_quant_wsz = 256
+            w1, w1_scale = sub_channel_quantize(w1, w_dtype, subc_quant_wsz)
+            w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz)
+
+        actual = fused_ep_moe(
+            mesh=self.mesh,
+            tokens=a,
+            w1=w1,
+            w2=w2,
+            gating_output=gating_output,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            b1=b1,
+            b2=b2,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        )
+        expected = ref_moe(
+            a,
+            w1,
+            w2,
+            gating_output,
+            top_k,
+            b1=b1,
+            b2=b2,
+            renormalize_topk_logits=renormalize_topk_logits,
+            activation=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+        )
+        self.assertAllClose(actual, expected, atol=atol, rtol=rtol)
+
+    @parameterized.product(renormalize_topk_logits=[True, False], )
+    def test_basic(self, renormalize_topk_logits):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=32,
+            bf=1024,
+            bd1=1024,
+            bd2=1024,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    @parameterized.product(act_fn=["silu", "gelu", "swigluoai"], )
+    def test_activation(self, act_fn):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=True,
+            act_fn=act_fn,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    def test_benchmark_qwen_235(self):
+        num_experts = 128
+        top_k = 8
+        hidden_size = 4096
+        intermediate_size = 1536
+        dtype = jnp.bfloat16
+        num_tokens = 8 * 64
+        seed = 54321
+        renormalize_topk_logits = True
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=seed,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=64,
+            bf=768,
+            bd1=2048,
+            bd2=2048,
+            btc=64,
+            bfc=768,
+            bd1c=2048,
+            bd2c=2048,
+            act_fn="silu",
+            atol=5e-2,
+            rtol=5e-2,
+        )
+
+    def test_benchmark_qwen_30b_a3b(self):
+        num_experts = 128
+        top_k = 8
+        hidden_size = 2048
+        intermediate_size = 768
+        dtype = jnp.bfloat16
+        num_tokens = 512
+        seed = 54321
+        renormalize_topk_logits = True
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=seed,
+            renormalize_topk_logits=renormalize_topk_logits,
+            bt=16,
+            bf=384,
+            bd1=512,
+            bd2=512,
+            btc=16,
+            bfc=384,
+            bd1c=256,
+            bd2c=256,
+            act_fn="silu",
+            atol=5e-2,
+            rtol=5e-2,
+        )
+
+    @parameterized.product(
+        w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], )
+    def test_sub_channel_quantization(self, w_dtype):
+        if w_dtype in (
+                jnp.float8_e5m2,
+                jnp.float4_e2m1fn,
+        ) and not jtu.is_device_tpu_at_least(version=7):
+            self.skipTest("Expect TPUv7+")
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            w_dtype=w_dtype,
+            subc_quant_wsz=256,
+            bt=32,
+            bf=1024,
+            bd1=1024,
+            bd2=1024,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+    def test_bias(self):
+        dtype = jnp.bfloat16
+        top_k = 8
+        num_experts = 128
+        hidden_size = 1024
+        intermediate_size = 1024
+        num_tokens = 8 * 32
+        self._test_moe(
+            dtype=dtype,
+            top_k=top_k,
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            num_tokens=num_tokens,
+            seed=1234,
+            renormalize_topk_logits=False,
+            has_bias=True,
+            bt=32,
+            bf=512,
+            bd1=512,
+            bd2=512,
+            btc=32,
+            bfc=256,
+            bd1c=256,
+            bd2c=256,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main(testLoader=jtu.JaxTestLoader())
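
Reviewer note: in sub_channel_quantize above, each window of wsz rows along the second-minor axis shares a single float32 scale, stored one row per window. A minimal dequantization counterpart that recovers the weights up to rounding (an illustration only; ref_moe applies the scales internally, and this helper is hypothetical):

import jax.numpy as jnp


def sub_channel_dequantize(w, scale, wsz=256):
    # w: quantized values; scale: one float32 row per wsz-sized window
    # along axis -2, as produced by sub_channel_quantize.
    parts = []
    for i in range(0, w.shape[-2], wsz):
        win = i // wsz
        s = scale[..., win:win + 1, :]  # broadcasts over the window
        parts.append(w[..., i:i + wsz, :].astype(jnp.float32) * s)
    return jnp.concatenate(parts, axis=-2)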