tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
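
Shown in full below are two of the new files: the DeepSeek-V3 MoE test suite, tests/layers/jax/moe/test_deepseek_moe.py (+235 lines), followed by one of the 13-line Apache license headers that account for the many +13 -0 entries above.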
@@ -0,0 +1,235 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax.sharding import Mesh, PartitionSpec

from tpu_inference.layers.jax.moe.deepseek_v3_moe import (DeepSeekV3Router,
                                                          SparseMoE)


class TestDeepSeekV3Router(unittest.TestCase):

    def setUp(self):
        self.cpu_mesh = Mesh(jax.devices('cpu'), axis_names=('data', ))

    def test_get_topk_indices_single_group(self):
        """Test get_topk_indices with a single expert group."""
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=4,
                                      num_experts_per_tok=2,
                                      n_groups=1,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
            router.bias_E = jnp.zeros((4, ))

            scores = jnp.array([[0.1, 0.3, 0.2, 0.4]])  # shape: (1, 4)
            indices = router.get_topk_indices(scores)

            # Should return indices of the top 2 experts
            expected_indices = jnp.array([[3, 1]])  # scores 0.4, 0.3
            self.assertTrue(jnp.array_equal(indices, expected_indices))

    def test_get_topk_indices_2_groups(self):
        """Test get_topk_indices with 2 expert groups."""
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=4,
                                      num_experts_per_tok=2,
                                      n_groups=2,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
            router.bias_E = jnp.zeros((4, ))

            # 4 experts, 2 groups, 2 experts per group
            scores = jnp.array([[[0.1, 0.3, 0.2, 0.4]]])  # shape: (1, 1, 4)
            indices = router.get_topk_indices(scores)

            # Should return indices of the top 2 experts
            expected_indices = jnp.array([[[3, 2]]])
            self.assertTrue(jnp.array_equal(indices, expected_indices))

    def test_router_e2e(self):
        with jax.set_mesh(self.cpu_mesh):
            router = DeepSeekV3Router(random_init=True,
                                      hidden_size=512,
                                      num_experts=8,
                                      num_experts_per_tok=2,
                                      n_groups=2,
                                      topk_groups=1,
                                      norm_topk_prob=True,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(42))
            x = jnp.ones((2, 512))
            weights, indices = router(x)
            self.assertEqual(weights.shape, (2, 2))
            self.assertEqual(indices.shape, (2, 2))


class TestSparseMoE(unittest.TestCase):

    def setUp(self):
        """Set up a multi-device mesh and a sample MoE layer for testing."""
        devices = jax.devices()
        self.device_count = len(devices)
        if self.device_count < 8:
            self.skipTest("This test requires at least 8 simulated devices.")

        # This mesh will have a 'model' axis for expert parallelism
        mesh_shape = (self.device_count, 1)
        device_mesh_array = np.array(devices).reshape(mesh_shape)

        # Define the axis names
        axis_names = ('model', 'data')

        # Create the 2D mesh
        self.mesh = Mesh(device_mesh_array, axis_names=axis_names)

        # --- Model Configuration ---
        self.B, self.S, self.D = 2, 4, 16  # Batch, Sequence, Hidden Dim
        self.E, self.K = 16, 8  # Num Experts, Experts per Token
        self.moe_intermediate_size = 32  # FFN Dim
        self.num_expert_parallelism = 8  # Shard experts across 8 devices

        self.key = jax.random.PRNGKey(42)
        self.x = jax.random.normal(self.key, (self.B * self.S, self.D),
                                   dtype=jnp.bfloat16)

        # --- Instantiate MoE Layer ---
        # This must happen inside the mesh context.
        with self.mesh:
            router = DeepSeekV3Router(hidden_size=self.D,
                                      num_experts=self.E,
                                      num_experts_per_tok=self.K,
                                      n_groups=1,
                                      topk_groups=1,
                                      norm_topk_prob=False,
                                      routed_scaling_factor=1.0,
                                      dtype=jnp.bfloat16,
                                      rngs=nnx.Rngs(self.key),
                                      ed_sharding=PartitionSpec(),
                                      e_sharding=PartitionSpec(),
                                      activation_ffw_td=PartitionSpec(
                                          'data', None))
            self.moe = SparseMoE(
                hidden_size=self.D,
                intermediate_size_moe=self.moe_intermediate_size,
                num_local_experts=self.E,
                hidden_act="silu",
                num_experts_per_tok=self.K,
                router=router,
                dtype=jnp.bfloat16,
                rngs=nnx.Rngs(self.key),
                mesh=self.mesh,
                apply_expert_weight_before_computation=False,
                # Expert weights are sharded across the 'model' axis;
                # activations are replicated along it.
                edf_sharding=PartitionSpec('model', None, None),
                efd_sharding=PartitionSpec('model', None, None),
                activation_ffw_ted=PartitionSpec('data', None),
                activation_ffw_td=PartitionSpec('data', None))

    def test_token_replicated_expert_parallel_fwd(self):
        """
        Validates the MoE forward pass against a simple, dense equivalent.
        This specifically tests the is_batch_sharded_by_expert=False path.
        """
        # --- 1. Get the ACTUAL output from the distributed MoE layer ---
        # __call__ triggers the shard_map, which requires the mesh context.
        with self.mesh:
            actual_output = self.moe(self.x)

        # --- 2. Calculate the EXPECTED output sequentially ---
        # This serves as the "ground truth".

        # Get router decisions (router params are replicated, so this is fine)
        router_weights, selected_experts = self.moe.router(self.x)

        # Gather the full, unsharded weights from all devices:
        # .value on a sharded param gives the *local* shard, while
        # jax.device_get() retrieves the *full* global array to the host.
        gating_kernel_full = jax.device_get(self.moe.kernel_gating_EDF.value)
        up_proj_kernel_full = jax.device_get(self.moe.kernel_up_proj_EDF.value)
        down_proj_kernel_full = jax.device_get(
            self.moe.kernel_down_proj_EFD.value)

        # Check that we really got the full weights
        self.assertEqual(gating_kernel_full.shape,
                         (self.E, self.D, self.moe_intermediate_size))

        # Flatten inputs for easier iteration
        flat_x = self.x.reshape(self.B * self.S, self.D)
        flat_weights = router_weights.reshape(self.B * self.S, self.K)
        flat_experts = selected_experts.reshape(self.B * self.S, self.K)

        expected_output = jnp.zeros_like(flat_x)

        # Manually apply each expert to each token sequentially
        for i in range(self.B * self.S):  # For each token
            token_input = flat_x[i]
            combined_expert_output = jnp.zeros(self.D, dtype=jnp.bfloat16)

            for k in range(self.K):  # For each chosen expert for that token
                expert_idx = flat_experts[i, k]
                weight = flat_weights[i, k]

                # Get kernels from the *full* gathered arrays
                gating_kernel = gating_kernel_full[expert_idx]
                up_proj_kernel = up_proj_kernel_full[expert_idx]
                down_proj_kernel = down_proj_kernel_full[expert_idx]

                # Perform the expert computation (dense matmuls)
                gating_proj = jnp.dot(token_input, gating_kernel)
                up_proj = jnp.dot(token_input, up_proj_kernel)

                # 'silu' activation, as specified in the MoE init
                fused = nnx.silu(gating_proj) * up_proj

                expert_output = jnp.dot(fused, down_proj_kernel)

                # Apply router weight after computation (matches implementation)
                combined_expert_output += weight * expert_output

            expected_output = expected_output.at[i].set(combined_expert_output)

        # --- 3. Compare the results ---
        self.assertTrue(
            jnp.allclose(actual_output, expected_output, atol=1e-2, rtol=1e-2),
            "The output of the distributed MoE does not match the dense equivalent.\n"
            f"Actual:\n{actual_output}\n"
            f"Expected:\n{expected_output}")
        print(
            "\n✅ Test Passed: Distributed MoE output matches the dense ground truth."
        )
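
The two get_topk_indices tests above differ only in expert grouping, and the expected indices pin the behavior down: with one group, the top-2 of [0.1, 0.3, 0.2, 0.4] are experts 3 and 1; with two groups and topk_groups=1, selection is first restricted to the winning group. A minimal sketch of that grouped selection follows; the exact rule inside DeepSeekV3Router is assumed here (DeepSeek-V3 style: score the groups, keep the best topk_groups, then take the top-k experts within them), reconstructed from the test expectations rather than from the router source.

import jax.numpy as jnp

# Hypothetical grouped top-k, matching the expected values in the tests above.
scores = jnp.array([0.1, 0.3, 0.2, 0.4])  # 4 experts
n_groups, topk = 2, 2
group_size = scores.shape[0] // n_groups
group_scores = scores.reshape(n_groups, group_size).max(axis=-1)  # [0.3, 0.4]
best_group = jnp.argmax(group_scores)  # group 1 -> experts {2, 3}
masked = jnp.where(jnp.arange(scores.shape[0]) // group_size == best_group,
                   scores, -jnp.inf)
print(jnp.argsort(masked)[::-1][:topk])  # [3 2], as the 2-group test expects

TestSparseMoE skips itself unless at least 8 devices are visible. On a CPU-only machine, one way to satisfy that (a standard JAX/XLA trick, not something this package provides) is the simulated host-device flag, which must be set before the first import of jax:

import os

# Force 8 simulated CPU devices so TestSparseMoE does not skip.
os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "") +
                           " --xla_force_host_platform_device_count=8")

import jax

assert jax.device_count() >= 8, jax.devices()

With that in place the module runs as usual, e.g. python -m unittest tests.layers.jax.moe.test_deepseek_moe.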
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.