tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/moe/gpt_oss_moe.py
@@ -0,0 +1,199 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import InitVar, dataclass

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jaxtyping import Float

from tpu_inference.layers.jax.base import create_param
from tpu_inference.layers.jax.layers import FlaxUtils
from tpu_inference.layers.jax.moe.moe import Router

modeling_flax_utils = FlaxUtils()


@dataclass(kw_only=True)
class GptOssRouter(Router):
    """Router module for Mixture-of-Experts (MoE) layers.

    This module determines which experts each token should be routed to.
    """
    e_sharding: Sharding = ()

    def __post_init__(self, rngs: nnx.Rngs):
        """Initializes the parent's kernel and adds the new bias parameter."""
        super().__post_init__(rngs)

        self.bias_E = create_param(rngs,
                                   shape=(self.num_experts, ),
                                   dtype=self.dtype,
                                   sharding=self.e_sharding,
                                   random_init=self.random_init)

    def __call__(self, x_TD: Float):
        """Overrides the parent's forward pass to include the bias."""
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)

        router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD,
                                      self.kernel_DE.value)
        router_logits_TE += self.bias_E.value

        weights_TX, selected_experts_TX = jax.lax.top_k(
            router_logits_TE, self.num_experts_per_tok)

        normalized_weights_TX = jax.nn.softmax(weights_TX.astype(self.dtype),
                                               axis=-1)

        return normalized_weights_TX, selected_experts_TX


def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
    """Implements the specific SwiGLU from the golden implementation."""
    x_glu, x_linear = x[..., ::2], x[..., 1::2]

    x_glu = jnp.clip(x_glu, a_max=limit)
    x_linear = jnp.clip(x_linear, a_min=-limit, a_max=limit)

    gated_activation = x_glu * jax.nn.sigmoid(alpha * x_glu)

    return gated_activation * (x_linear + 1)


@dataclass(kw_only=True)
class CombineExperts(nnx.Module):
    """Module for combining expert outputs with weighted sum."""
    dtype: jnp.dtype

    def __call__(self, down_proj_TED: Float, weights_TX: Float,
                 indices_TX: jax.Array) -> Float:
        """Combines expert outputs using weighted sum.

        Args:
            down_proj_TED: Expert outputs, shape (tokens, experts, hidden_dim)
            weights_TX: Router weights, shape (tokens, experts_per_token)
            indices_TX: Selected expert indices, shape (tokens, experts_per_token)

        Returns:
            Combined output, shape (tokens, hidden_dim)
        """
        with jax.named_scope("combine_experts"):
            indices_for_gather = indices_TX[..., None]
            gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
                                                         indices_for_gather,
                                                         axis=1)
            output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
                                   weights_TX)

        return output_TD.astype(self.dtype)


@dataclass(kw_only=True)
class GptOssMoE(nnx.Module):
    """JAX implementation of the GPT-OSS Mixture-of-Experts MLP block."""
    dtype: jnp.dtype
    hidden_size: int
    intermediate_size_moe: int
    num_local_experts: int
    router: GptOssRouter
    rngs: InitVar[nnx.Rngs]

    swiglu_limit: float = 7.0
    swiglu_alpha: float = 1.702

    # Sharding specifications
    activation_ffw_td: Sharding
    edf_sharding: Sharding
    efd_sharding: Sharding
    ed_sharding: Sharding

    random_init: bool = False

    def __call__(self, x_TD: Float) -> Float:
        """Performs the forward pass for the GPT-OSS MoE layer."""
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)

        weights_TX, indices_TX = self.router(x_TD)

        # First MLP layer (up-projection)
        with jax.named_scope("MLP #1"):
            up_proj_TEF2 = jnp.einsum('TD,EDF -> TEF', x_TD,
                                      self.mlp1_weight_EDF2.value)
            up_proj_TEF2 += self.mlp1_bias_EF2.value

        fuse_TEF = _swiglu(up_proj_TEF2,
                           alpha=self.swiglu_alpha,
                           limit=self.swiglu_limit)

        # Second MLP layer (down-projection)
        with jax.named_scope("MLP #2"):
            down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
                                       self.mlp2_weight_EFD.value)
            down_proj_TED += self.mlp2_bias_ED.value

        # Weighted sum of expert outputs
        output_TD = self.combine_experts(down_proj_TED, weights_TX, indices_TX)

        return output_TD

    def __post_init__(self, rngs: nnx.Rngs):
        """Initializes all weights and biases for the MoE block."""
        D, F, E = self.hidden_size, self.intermediate_size_moe, self.num_local_experts

        self.combine_experts = CombineExperts(dtype=self.dtype)

        # MLP #1 Weights (Combined Gate and Up-projection) and Bias
        self.mlp1_weight_EDF2 = create_param(
            rngs,
            shape=(E, D, F * 2),
            dtype=self.dtype,
            sharding=self.edf_sharding,
            random_init=self.random_init,
        )
        self.mlp1_bias_EF2 = create_param(
            rngs,
            shape=(E, F * 2),
            dtype=self.dtype,
            sharding=self.ed_sharding,
            random_init=self.random_init,
        )

        # MLP #2 Weights (Down-projection) and Bias
        self.mlp2_weight_EFD = create_param(
            rngs,
            shape=(E, F, D),
            dtype=self.dtype,
            sharding=self.efd_sharding,
            random_init=self.random_init,
        )
        self.mlp2_bias_ED = create_param(
            rngs,
            shape=(E, D),
            dtype=self.dtype,
            sharding=self.ed_sharding,
            random_init=self.random_init,
        )
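The interleaved gate/linear layout used by _swiglu above is easy to get backwards when reading the diff. Below is a minimal, self-contained sketch of the same computation (plain jax only; the swiglu name and the test shapes are illustrative, not part of the package):

    import jax
    import jax.numpy as jnp

    def swiglu(x, alpha=1.702, limit=7.0):
        # Gate and linear channels are interleaved along the last axis:
        # even indices feed the sigmoid gate, odd indices the linear branch.
        x_glu, x_linear = x[..., ::2], x[..., 1::2]
        x_glu = jnp.clip(x_glu, None, limit)          # clip gate inputs from above only
        x_linear = jnp.clip(x_linear, -limit, limit)  # clip linear branch on both sides
        return (x_glu * jax.nn.sigmoid(alpha * x_glu)) * (x_linear + 1)

    x = jnp.ones((4, 8))        # (tokens, 2 * intermediate_size)
    print(swiglu(x).shape)      # (4, 4): the split halves the last axis

Note that the F-dimension of mlp1_weight_EDF2 is F * 2 precisely because _swiglu consumes two interleaved channels per output feature.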
tpu_inference/layers/jax/moe/moe.py
@@ -0,0 +1,249 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import InitVar, dataclass

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jaxtyping import Float

from tpu_inference.layers.jax.base import create_param
from tpu_inference.layers.jax.layers import FlaxUtils

modeling_flax_utils = FlaxUtils()


@dataclass(kw_only=True)
class CombineExperts(nnx.Module):
    """Combines expert outputs with router weights.

    Supports `TED,TE -> TD` when passed expert outputs, using float32
    accumulation for numerical stability, then casting back to the target
    dtype.
    """

    dtype: jnp.dtype

    def __call__(self, expert_outputs_TED: Float, weights_TE: Float) -> Float:
        with jax.named_scope("combine_experts"):
            output_TD = jnp.einsum(
                "TED,TE -> TD",
                expert_outputs_TED.astype(jnp.float32),
                weights_TE.astype(jnp.float32),
                precision="float32",
            )

        return output_TD.astype(self.dtype)


@dataclass(kw_only=True)
class Router(nnx.Module):
    """Router module for Mixture-of-Experts (MoE) layers.

    This module determines which experts each token should be routed to
    based on the input.

    Attributes:
        dtype: Computation dtype for the routing logits.
        hidden_size: Model hidden dimension (D).
        num_experts: Total number of experts (E).
        num_experts_per_tok: Number of experts selected per token (X).
        router_act: Name of the activation used to normalize the top-k weights.
        activation_ffw_td: Sharding for the (T, D) input activations.
        ed_sharding: Sharding for the (D, E) router kernel.
        random_init: Whether to randomly initialize the router kernel.
    """
    dtype: jnp.dtype
    hidden_size: int
    num_experts: int
    num_experts_per_tok: int
    router_act: str
    rngs: InitVar[nnx.Rngs]
    activation_ffw_td: Sharding
    ed_sharding: Sharding
    random_init: bool = False

    def __call__(self, x_TD: Float):
        """Routes tokens to experts.

        Args:
            x_TD: Input array of shape (sequence_length, d_model).

        Returns:
            A tuple containing:
                - normalized_weights_TX: Normalized weights for selected experts,
                  shape (sequence_length, num_experts_per_tok).
                - selected_experts_TX: Indices of selected experts,
                  shape (sequence_length, num_experts_per_tok).
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
        router_act = modeling_flax_utils.ACT2FN[self.router_act]
        router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD,
                                      self.kernel_DE.value)
        weights_TX, selected_experts_TX = jax.lax.top_k(
            router_logits_TE, self.num_experts_per_tok)
        if self.router_act != "sigmoid":  # sigmoid does not accept an axis argument.
            normalized_weights_TX = router_act(weights_TX.astype(self.dtype),
                                               axis=-1)
        else:
            normalized_weights_TX = router_act(weights_TX.astype(self.dtype))
        return normalized_weights_TX, selected_experts_TX

    def __post_init__(self, rngs: nnx.Rngs):
        """Generates the router kernel (weights) for routing."""
        shape = (self.hidden_size, self.num_experts)
        self.kernel_DE = create_param(rngs,
                                      shape=shape,
                                      dtype=self.dtype,
                                      sharding=self.ed_sharding,
                                      random_init=self.random_init)


@dataclass(kw_only=True)
class MoE(nnx.Module):
    """Mixture-of-Experts (MoE) Routed MLP Layer.

    This module implements a MoE layer with a router and multiple expert MLPs.

    Attributes:
        router: The Router module.
    """
    dtype: jnp.dtype
    num_local_experts: int
    apply_expert_weight_before_computation: bool
    hidden_size: int
    intermediate_size_moe: int
    hidden_act: str
    rngs: InitVar[nnx.Rngs]
    router: nnx.Module
    activation_ffw_td: Sharding
    activation_ffw_ted: Sharding
    edf_sharding: Sharding
    efd_sharding: Sharding
    random_init: bool = False

    def __call__(self, x_TD: Float):
        """Performs the forward pass of the MoE layer.

        Args:
            x_TD: Input array of shape (sequence_length, d_model).

        Returns:
            Output array of shape (sequence_length, d_model) after passing through MoE.
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
        weights_TX, indices_TX = self.router(x_TD)
        one_hot_indices_TXE = jax.nn.one_hot(
            indices_TX, num_classes=self.num_local_experts, dtype=self.dtype)
        full_weights_TE = jnp.sum(one_hot_indices_TXE * weights_TX[..., None],
                                  axis=1)

        # Some models use the routing scores to weight the data instead of
        # weighting the expert outputs.
        if self.apply_expert_weight_before_computation:
            with jax.named_scope("pre_computing_weight"):
                return self._moe_fwd_preapply_router_weights(
                    x_TD, full_weights_TE)
        else:
            return self._moe_fwd(x_TD, full_weights_TE)

    def __post_init__(self, rngs: nnx.Rngs):
        """Generates the kernels (weights) for the router and experts
        (gating, up-projection, and down-projection layers)."""

        D = self.hidden_size
        F = self.intermediate_size_moe
        shape_gating = (self.num_local_experts, D, F)
        shape_up = (self.num_local_experts, D, F)
        shape_down = (self.num_local_experts, F, D)

        self.kernel_gating_EDF = create_param(rngs,
                                              shape=shape_gating,
                                              dtype=self.dtype,
                                              sharding=self.edf_sharding,
                                              random_init=self.random_init)
        self.kernel_up_proj_EDF = create_param(rngs,
                                               shape=shape_up,
                                               dtype=self.dtype,
                                               sharding=self.edf_sharding,
                                               random_init=self.random_init)
        self.kernel_down_proj_EFD = create_param(rngs,
                                                 shape=shape_down,
                                                 dtype=self.dtype,
                                                 sharding=self.efd_sharding,
                                                 random_init=self.random_init)

        # Shared combine module for the expert-combination path.
        self.combine_experts = CombineExperts(dtype=self.dtype)

    def _moe_fwd_preapply_router_weights(self, x_TD: jax.Array, weights_TE):
        """Performs the forward pass of the MoE experts with router weights
        pre-applied to the inputs.

        Args:
            x_TD: Input array for the experts, shape (sequence_length, hidden_size).
            weights_TE: Router weights, shape (sequence_length, num_experts).

        Returns:
            Output array of shape (sequence_length, d_model).
        """
        # Data needs to be replicated since it will be weighted by the router
        # scores before being passed to each expert.
        num_experts = weights_TE.shape[-1]
        x_TED = jnp.repeat(x_TD[:, None, :], num_experts, 1)
        weights_TED = weights_TE[..., None]
        x_TED = jnp.asarray(x_TED, self.dtype)

        with jax.named_scope("activation_expert_weighting"):
            x_TED = x_TED * weights_TED

        x_TED = nnx.with_sharding_constraint(x_TED, self.activation_ffw_ted)
        with jax.named_scope("gating"):
            gating_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
                                    self.kernel_gating_EDF.value)
            activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
                gating_TEF)
        with jax.named_scope("up_projection"):
            up_proj_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
                                     self.kernel_up_proj_EDF.value)

        fuse_TEF = activated_gating_TEF * up_proj_TEF

        with jax.named_scope("down_projection"):
            down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
                                       self.kernel_down_proj_EFD.value)
        with jax.named_scope("sum"):
            output_TD = down_proj_TED.sum(axis=1)
        return output_TD.astype(self.dtype)

    def _moe_fwd(self, x_TD: Float, weights):
        """Performs the basic forward pass of the MoE experts without dropping
        or megablocks.

        Args:
            x_TD: Input array for the experts, shape (sequence_length, d_model).
            weights: Weights for combining expert outputs, shape (sequence_length, num_experts).

        Returns:
            Output array of shape (sequence_length, d_model).
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
        with jax.named_scope("gating"):
            gating_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
                                    self.kernel_gating_EDF.value)
            activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
                gating_TEF)
        with jax.named_scope("up_projection"):
            up_proj_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
                                     self.kernel_up_proj_EDF.value)

        fuse_TEF = activated_gating_TEF * up_proj_TEF

        with jax.named_scope("down_projection"):
            down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
                                       self.kernel_down_proj_EFD.value)
        # Combine across experts.
        output_TD = self.combine_experts(down_proj_TED, weights)
        return output_TD
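MoE.__call__ above turns the router's sparse top-k output into a dense (T, E) weight matrix before the experts run, so the TED,TE -> TD combine only mixes selected experts (unselected ones carry weight 0). A small self-contained check of that densification, with made-up shapes and values for illustration:

    import jax
    import jax.numpy as jnp

    T, E, X = 2, 4, 2                                   # tokens, experts, experts per token
    weights_TX = jnp.array([[0.7, 0.3], [0.6, 0.4]])    # top-k router weights
    indices_TX = jnp.array([[1, 3], [0, 2]])            # selected expert indices

    # Scatter the top-k weights into a dense (T, E) matrix, as in MoE.__call__.
    one_hot_TXE = jax.nn.one_hot(indices_TX, num_classes=E)
    full_weights_TE = jnp.sum(one_hot_TXE * weights_TX[..., None], axis=1)
    print(full_weights_TE)
    # [[0.  0.7 0.  0.3]
    #  [0.6 0.  0.4 0. ]]

The trade-off is that every expert still computes every token (dense einsums over the E axis); the zero weights drop the unselected outputs only at the combine step.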
tpu_inference/layers/jax/pp_utils.py
@@ -0,0 +1,53 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Protocol

from flax import nnx
from vllm.distributed import get_pp_group
from vllm.distributed.utils import get_pp_indices


class PPMissingLayer(nnx.Module):
    """A placeholder layer for missing layers in a pipeline parallel model."""

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        """Return the first arg from args or the first value from kwargs."""
        return args[0] if args else next(iter(kwargs.values()))


class LayerFn(Protocol):

    def __call__(self) -> nnx.Module:
        ...


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
) -> tuple[int, int, List[nnx.Module]]:
    """Builds this pipeline stage's slice of the layer stack, padding layers
    owned by other stages with PPMissingLayer placeholders."""
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            get_pp_group().rank_in_group,
                                            get_pp_group().world_size)

    layers = [PPMissingLayer() for _ in range(start_layer)] \
        + [layer_fn() for _ in range(start_layer, end_layer)] \
        + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]

    return start_layer, end_layer, layers
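Since PPMissingLayer.__call__ forwards its input unchanged, each rank can run the full layer list and only its own slice does real work. A dependency-free sketch of the resulting layout (the even split below is an assumption standing in for vllm's get_pp_indices, which also handles uneven partitions, and the strings stand in for the modules):

    def pp_indices(num_layers: int, rank: int, world_size: int) -> tuple[int, int]:
        # Assumed even split for illustration only.
        per_stage = num_layers // world_size
        return rank * per_stage, (rank + 1) * per_stage

    num_layers, rank, world_size = 8, 1, 4
    start, end = pp_indices(num_layers, rank, world_size)
    layers = (["missing"] * start                          # PPMissingLayer slots
              + [f"layer_{i}" for i in range(start, end)]  # locally built layers
              + ["missing"] * (num_layers - end))
    print(start, end, layers)
    # 2 4 ['missing', 'missing', 'layer_2', 'layer_3', 'missing', 'missing', 'missing', 'missing']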