tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/layers/jax/attention/test_llama4_attention.py (new, +135 lines)

@@ -0,0 +1,135 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from dataclasses import dataclass, field
+
+import chex
+
+os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from jax.sharding import NamedSharding
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import build_mesh
+from tpu_inference.layers.jax.attention.llama4_attention import (
+    L2Norm, Llama4Attention)
+
+
+@dataclass
+class SimpleVLLMConfig:
+    additional_config: dict = field(default_factory=dict)
+
+
+class Llama4AttentionTest(unittest.TestCase):
+    """Unit test suite for Llama4-specific attention components."""
+
+    def setUp(self):
+        devices = jax.devices()[:1]
+        sharding_strategy = {"tensor_parallelism": len(devices)}
+        self.mesh = build_mesh(devices, sharding_strategy)
+
+    def test_l2norm_forward_pass(self):
+        """Tests the forward pass of the L2Norm module with hardcoded values."""
+        eps = 1e-5
+        l2_norm = L2Norm(eps=eps)
+        x = jnp.array([[1.0, 2.0, 3.0, 4.0]], dtype=jnp.float32)
+
+        output = l2_norm(x)
+
+        self.assertEqual(output.shape, x.shape)
+        self.assertEqual(output.dtype, x.dtype)
+
+        # Expected values calculated manually:
+        # mean_sq = (1^2 + 2^2 + 3^2 + 4^2) / 4 = (1+4+9+16)/4 = 30/4 = 7.5
+        # norm_val = sqrt(7.5 + 1e-5)
+        # expected = x / norm_val
+        expected_output = jnp.array([[0.365148, 0.730297, 1.095445, 1.460594]],
+                                    dtype=jnp.float32)
+        self.assertTrue(jnp.allclose(output, expected_output, atol=1e-6))
+
+    def test_l2norm_with_zeros(self):
+        """Tests L2Norm with an all-zero input."""
+        l2_norm = L2Norm(eps=1e-5)
+        x = jnp.zeros((4, 8, 16))
+        output = l2_norm(x)
+        self.assertEqual(output.shape, x.shape)
+        # Output should be all zeros.
+        self.assertTrue(jnp.all(output == 0))
+
+    def test_l2norm_eps_effect(self):
+        """Tests the effect of the epsilon value in L2Norm."""
+        eps = 1e-3
+        l2_norm = L2Norm(eps=eps)
+        x = jax.random.normal(jax.random.PRNGKey(0), (1, 1, 128))
+        output = l2_norm(x)
+
+        mean_sq = jnp.mean(x**2, axis=-1, keepdims=True)
+        expected_output = x * jax.lax.rsqrt(mean_sq + eps)
+
+        self.assertTrue(jnp.allclose(output, expected_output))
+
+    def test_apply_temperature_tuning(self):
+        with jax.set_mesh(self.mesh):
+            hidden_size = 64
+            num_attention_heads = 4
+            head_dim = hidden_size // num_attention_heads
+
+            # Create dummy sharding objects
+            dummy_sharding = NamedSharding(self.mesh, P())
+
+            llama4_attention = Llama4Attention(
+                hidden_size=hidden_size,
+                num_attention_heads=num_attention_heads,
+                num_key_value_heads=num_attention_heads,
+                head_dim=head_dim,
+                rope_theta=10000.0,
+                rope_scaling={},
+                dtype=jnp.bfloat16,
+                kv_cache_dtype="auto",
+                use_qk_norm=False,
+                temperature_tuning=True,
+                temperature_tuning_scale=2.0,
+                temperature_tuning_floor_scale=2.0,
+                mesh=self.mesh,
+                random_init=True,
+                activation_attention_td=dummy_sharding,
+                activation_attention_out_td=dummy_sharding,
+                rngs=nnx.Rngs(42),
+            )
+
+            seq_len = 8
+            input_arr_TNH = jnp.ones((seq_len, num_attention_heads, head_dim),
+                                     dtype=jnp.bfloat16)
+            attention_metadata = AttentionMetadata(
+                input_positions=jnp.arange(seq_len, dtype=jnp.int32))
+            expected_scales = jnp.array(
+                [1, 2.375, 2.375, 3.20312, 3.20312, 3.76562, 3.76562, 4.21875],
+                dtype=jnp.bfloat16)
+            output = llama4_attention.apply_temperature_tuning(
+                attention_metadata, input_arr_TNH)
+            chex.assert_shape(output, (seq_len, num_attention_heads, head_dim))
+
+            expected_output = jnp.ones_like(
+                input_arr_TNH) * expected_scales[:, None, None]
+            chex.assert_trees_all_close(output, expected_output, atol=1e-3)
+
+
+if __name__ == "__main__":
+    unittest.main()
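Note on the expected_scales above: the eight values are consistent with Llama 4 style attention temperature tuning, where scale(p) = 1 + attn_scale * log(floor((p + 1) / floor_scale) + 1) for token position p. The sketch below reproduces them under that assumption; the formula and the helper name are inferred from the test's parameters and expected values, not taken from this package.

```python
# Hypothetical sketch: reproduce expected_scales from test_apply_temperature_tuning.
# Assumes scale(p) = 1 + attn_scale * log(floor((p + 1) / floor_scale) + 1),
# with attn_scale = temperature_tuning_scale = 2.0 and
# floor_scale = temperature_tuning_floor_scale = 2.0, as in the test.
import jax.numpy as jnp

def temperature_scales(positions, attn_scale=2.0, floor_scale=2.0):
    floored = jnp.floor((positions + 1.0) / floor_scale)
    return 1.0 + attn_scale * jnp.log(floored + 1.0)

scales = temperature_scales(jnp.arange(8, dtype=jnp.float32))
print(scales.astype(jnp.bfloat16))
# ~[1, 2.39, 2.39, 3.2, 3.2, 3.77, 3.77, 4.22], i.e. the test's
# expected_scales up to bfloat16 rounding.
```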
tests/layers/jax/moe/__init__.py (new, +13 lines; license header only)

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/layers/jax/moe/test_deepseek_moe.py (new, +235 lines)

@@ -0,0 +1,235 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax import nnx
+from jax.sharding import Mesh, PartitionSpec
+
+from tpu_inference.layers.jax.moe.deepseek_v3_moe import (DeepSeekV3Router,
+                                                          SparseMoE)
+
+
+class TestDeepSeekV3Router(unittest.TestCase):
+
+    def setUp(self):
+        self.cpu_mesh = Mesh(jax.devices('cpu'), axis_names=('data', ))
+
+    def test_get_topk_indices_single_group(self):
+        """Tests get_topk_indices with a single expert group."""
+        with jax.set_mesh(self.cpu_mesh):
+            router = DeepSeekV3Router(random_init=True,
+                                      hidden_size=512,
+                                      num_experts=4,
+                                      num_experts_per_tok=2,
+                                      n_groups=1,
+                                      topk_groups=1,
+                                      norm_topk_prob=True,
+                                      routed_scaling_factor=1.0,
+                                      dtype=jnp.bfloat16,
+                                      rngs=nnx.Rngs(42))
+            router.bias_E = jnp.zeros((4, ))
+
+            scores = jnp.array([[0.1, 0.3, 0.2, 0.4]])  # shape: (1, 4)
+            indices = router.get_topk_indices(scores)
+
+            # Should return the indices of the top 2 experts.
+            expected_indices = jnp.array([[3, 1]])  # experts with scores 0.4, 0.3
+            self.assertTrue(jnp.array_equal(indices, expected_indices))
+
+    def test_get_topk_indices_2_groups(self):
+        """Tests get_topk_indices with 2 expert groups."""
+        with jax.set_mesh(self.cpu_mesh):
+            router = DeepSeekV3Router(random_init=True,
+                                      hidden_size=512,
+                                      num_experts=4,
+                                      num_experts_per_tok=2,
+                                      n_groups=2,
+                                      topk_groups=1,
+                                      norm_topk_prob=True,
+                                      routed_scaling_factor=1.0,
+                                      dtype=jnp.bfloat16,
+                                      rngs=nnx.Rngs(42))
+            router.bias_E = jnp.zeros((4, ))
+
+            # 4 experts, 2 groups, 2 experts per group.
+            scores = jnp.array([[[0.1, 0.3, 0.2, 0.4]]])  # shape: (1, 1, 4)
+            indices = router.get_topk_indices(scores)
+
+            # Should return the indices of the top 2 experts in the winning group.
+            expected_indices = jnp.array([[[3, 2]]])
+            self.assertTrue(jnp.array_equal(indices, expected_indices))
+
+    def test_router_e2e(self):
+        with jax.set_mesh(self.cpu_mesh):
+            router = DeepSeekV3Router(random_init=True,
+                                      hidden_size=512,
+                                      num_experts=8,
+                                      num_experts_per_tok=2,
+                                      n_groups=2,
+                                      topk_groups=1,
+                                      norm_topk_prob=True,
+                                      routed_scaling_factor=1.0,
+                                      dtype=jnp.bfloat16,
+                                      rngs=nnx.Rngs(42))
+            x = jnp.ones((2, 512))
+            weights, indices = router(x)
+            self.assertEqual(weights.shape, (2, 2))
+            self.assertEqual(indices.shape, (2, 2))
+
+
+class TestSparseMoE(unittest.TestCase):
+
+    def setUp(self):
+        """Set up a multi-device mesh and a sample MoE layer for testing."""
+        devices = jax.devices()
+        self.device_count = len(devices)
+        if self.device_count < 8:
+            self.skipTest("This test requires at least 8 simulated devices.")
+
+        # This mesh has a 'model' axis for expert parallelism.
+        mesh_shape = (self.device_count, 1)
+        device_mesh_array = np.array(devices).reshape(mesh_shape)
+
+        # Define the axis names.
+        axis_names = ('model', 'data')
+
+        # Create the 2D mesh.
+        self.mesh = Mesh(device_mesh_array, axis_names=axis_names)
+
+        # --- Model Configuration ---
+        self.B, self.S, self.D = 2, 4, 16  # Batch, Sequence, Hidden Dim
+        self.E, self.K = 16, 8  # Num Experts, Experts per Token
+        self.moe_intermediate_size = 32  # FFN Dim
+        self.num_expert_parallelism = 8  # Shard experts across 8 devices
+
+        self.key = jax.random.PRNGKey(42)
+        self.x = jax.random.normal(self.key, (self.B * self.S, self.D),
+                                   dtype=jnp.bfloat16)
+
+        # --- Instantiate MoE Layer ---
+        # This must happen inside the mesh context.
+        with self.mesh:
+            router = DeepSeekV3Router(hidden_size=self.D,
+                                      num_experts=self.E,
+                                      num_experts_per_tok=self.K,
+                                      n_groups=1,
+                                      topk_groups=1,
+                                      norm_topk_prob=False,
+                                      routed_scaling_factor=1.0,
+                                      dtype=jnp.bfloat16,
+                                      rngs=nnx.Rngs(self.key),
+                                      ed_sharding=PartitionSpec(),
+                                      e_sharding=PartitionSpec(),
+                                      activation_ffw_td=PartitionSpec(
+                                          'data', None))
+            self.moe = SparseMoE(
+                hidden_size=self.D,
+                intermediate_size_moe=self.moe_intermediate_size,
+                num_local_experts=self.E,
+                hidden_act="silu",
+                num_experts_per_tok=self.K,
+                router=router,
+                dtype=jnp.bfloat16,
+                rngs=nnx.Rngs(self.key),
+                mesh=self.mesh,
+                apply_expert_weight_before_computation=False,
+                # Expert kernels are sharded along the 'model' axis;
+                # activations are replicated across it.
+                edf_sharding=PartitionSpec('model', None, None),
+                efd_sharding=PartitionSpec('model', None, None),
+                activation_ffw_ted=PartitionSpec('data', None),
+                activation_ffw_td=PartitionSpec('data', None))
+
+    def test_token_replicated_expert_parallel_fwd(self):
+        """
+        Validates the MoE forward pass against a simple, dense equivalent.
+        This specifically tests the is_batch_sharded_by_expert=False path.
+        """
+        # --- 1. Get the ACTUAL output from the distributed MoE layer ---
+        # The __call__ method triggers the shard_map, which requires the mesh context.
+        with self.mesh:
+            actual_output = self.moe(self.x)
+
+        # --- 2. Calculate the EXPECTED output using a simple, sequential process ---
+        # This serves as the "ground truth".
+
+        # Get router decisions (router params are replicated, so this is fine).
+        router_weights, selected_experts = self.moe.router(self.x)
+
+        # Gather the full, unsharded weights from all devices:
+        # .value on a sharded param gives the *local* shard;
+        # jax.device_get() retrieves the *full* global array to the host.
+        gating_kernel_full = jax.device_get(self.moe.kernel_gating_EDF.value)
+        up_proj_kernel_full = jax.device_get(self.moe.kernel_up_proj_EDF.value)
+        down_proj_kernel_full = jax.device_get(
+            self.moe.kernel_down_proj_EFD.value)
+
+        # Check that we really got the full weights.
+        self.assertEqual(gating_kernel_full.shape,
+                         (self.E, self.D, self.moe_intermediate_size))
+
+        # Flatten inputs for easier iteration.
+        flat_x = self.x.reshape(self.B * self.S, self.D)
+        flat_weights = router_weights.reshape(self.B * self.S, self.K)
+        flat_experts = selected_experts.reshape(self.B * self.S, self.K)
+
+        expected_output = jnp.zeros_like(flat_x)
+
+        # Manually apply each expert to each token sequentially.
+        for i in range(self.B * self.S):  # For each token
+            token_input = flat_x[i]
+            combined_expert_output = jnp.zeros(self.D, dtype=jnp.bfloat16)
+
+            for k in range(self.K):  # For each chosen expert for that token
+                expert_idx = flat_experts[i, k]
+                weight = flat_weights[i, k]
+
+                # Get kernels from the *full* gathered arrays.
+                gating_kernel = gating_kernel_full[expert_idx]
+                up_proj_kernel = up_proj_kernel_full[expert_idx]
+                down_proj_kernel = down_proj_kernel_full[expert_idx]
+
+                # Perform the expert computation (dense matmuls).
+                gating_proj = jnp.dot(token_input, gating_kernel)
+                up_proj = jnp.dot(token_input, up_proj_kernel)
+
+                # Note: assumes 'silu' activation, as specified in the MoE init.
+                fused = nnx.silu(gating_proj) * up_proj
+
+                expert_output = jnp.dot(fused, down_proj_kernel)
+
+                # Apply the router weight after computation (matches implementation).
+                combined_expert_output += weight * expert_output
+
+            expected_output = expected_output.at[i].set(combined_expert_output)
+
+        expected_output = expected_output.reshape(self.B * self.S, self.D)
+
+        # --- 3. Compare the results ---
+        self.assertTrue(
+            jnp.allclose(actual_output, expected_output, atol=1e-2, rtol=1e-2),
+            f"The output of the distributed MoE does not match the dense equivalent.\n"
+            f"Actual:\n{actual_output}\n"
+            f"Expected:\n{expected_output}")
+        print(
+            "\n✅ Test Passed: Distributed MoE output matches the dense ground truth."
+        )
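Note on test_get_topk_indices_2_groups: DeepSeek-V3 style routing first ranks expert groups per token, masks out all but the topk_groups best groups, and only then takes the per-token top-k experts among the survivors. The sketch below is an illustrative reimplementation of that selection with simplified 2-D score shapes; the helper name is hypothetical, and it ranks a group by its single best expert (DeepSeek-V3 proper scores a group by the sum of its top-2 experts), so it is not the code under test in DeepSeekV3Router.get_topk_indices.

```python
# Illustrative grouped top-k expert selection (not the package's implementation).
import jax.numpy as jnp

def grouped_topk_indices(scores, n_groups, topk_groups, top_k):
    # scores: (num_tokens, num_experts)
    num_tokens, num_experts = scores.shape
    grouped = scores.reshape(num_tokens, n_groups, num_experts // n_groups)
    # Rank groups by their best expert; keep the topk_groups best groups.
    group_scores = grouped.max(axis=-1)
    top_groups = jnp.argsort(group_scores, axis=-1)[:, -topk_groups:]
    keep = jnp.zeros((num_tokens, n_groups), dtype=bool)
    keep = keep.at[jnp.arange(num_tokens)[:, None], top_groups].set(True)
    # Mask experts in dropped groups, then take the per-token top-k.
    masked = jnp.where(keep[..., None], grouped, -jnp.inf)
    flat = masked.reshape(num_tokens, num_experts)
    return jnp.argsort(flat, axis=-1)[:, ::-1][:, :top_k]

scores = jnp.array([[0.1, 0.3, 0.2, 0.4]])
print(grouped_topk_indices(scores, n_groups=2, topk_groups=1, top_k=2))
# [[3 2]]: the (0.2, 0.4) group wins, matching the 2-group test above.
print(grouped_topk_indices(scores, n_groups=1, topk_groups=1, top_k=2))
# [[3 1]]: matches the single-group test.
```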
tests/layers/jax/sample/__init__.py (new, +13 lines; license header only)

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.