tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -1,9 +1,10 @@
-from typing import
+from typing import Union
 
 import jax
 import jax.numpy as jnp
 import torch
 import torch.nn.functional as F
+from compressed_tensors.quantization import QuantizationArgs
 from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
@@ -12,52 +13,89 @@ from torchax.interop import call_jax, torch_view
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.
-
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
-    CompressedTensorsW8A8Fp8MoEMethod
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
-    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+    CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedFusedMoEMethod
 
 logger = init_logger(__name__)
 
 
+class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+        layer: torch.nn.Module,
+        layer_name: str,
+    ) -> CompressedTensorsMoEMethod:
+
+        assert isinstance(layer, FusedMoE)
+
+        # FusedMoE was made by combining multiple Linears so need to
+        # make sure quantization config for Linear can target it
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            layer_name + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # multiple schemes found
+        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+            raise ValueError("All MoE projections need to have same "
+                             "quantization scheme but found multiple")
+
+        if scheme_dict is None:
+            return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                 quant_config.mesh)
+
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
+
+        if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
 class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
                                             JaxCommonConfig):
 
-    def __init__(self,
-
-
+    def __init__(self, weight_quant: QuantizationArgs,
+                 input_quant: QuantizationArgs, moe: FusedMoEConfig,
+                 mesh: Mesh):
+        super().__init__(weight_quant, input_quant, moe)
         self.mesh = mesh
-        self.quant_config = quant_config
-
-        # disable GPU paths
-        self.use_marlin = False
-        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
-        self.is_fp8_w8a8_sm100 = False
-        self.use_cutlass = False
-        self.disable_expert_map = False
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
 
-
-
-        w3_weight = layer.w13_weight[:, intermediate_size:]
-        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
-        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
-
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-        w2_weight_scale = t2j(layer.w2_weight_scale
-
-
-
-
-
-
-
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+
+        w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
+        w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
+
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+        assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
+        assert w13_weight.shape == (num_experts, 2 * intermediate_size,
+                                    hidden_size)
+        assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
+                                          1)
+
+        w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
+        w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, 1)
 
         if layer.use_ep:
             format = Format(Layout((0, 1, 2)),
@@ -69,16 +107,9 @@ class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
             w2_weight = jax.device_put(w2_weight, format)
             w2_weight_scale = jax.device_put(w2_weight_scale, format)
         else:
-            assert intermediate_size == w2_weight.shape[-1]
             n_shards = self.mesh.shape["model"]
             assert intermediate_size % n_shards == 0
 
-            # TODO: enable this if using fused weights
-            # output_sizes = [intermediate_size, intermediate_size]
-            # w13_weight = reorder_concatenated_tensor_for_sharding(
-            #     w13_weight, output_sizes, n_shards, dim=1
-            # )
-
             w13_format = Format(
                 Layout((0, 1, 2)),
                 NamedSharding(self.mesh, P(None, "model", None)))
@@ -119,45 +150,23 @@ class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if activation != "silu":
+        if layer.activation != "silu":
             raise NotImplementedError(
                 "Only silu is supported for activation function.")
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
            raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        # import sys
-        # sys.stdin = open(0)
-        # breakpoint()
-
         # TODO: Use MoE kernel when it supports fp8
-
         seqlen = x.shape[0]
 
         expert_weights = F.softmax(router_logits, dim=-1)
         expert_weights, expert_indices = torch.topk(expert_weights,
-                                                    top_k,
+                                                    layer.top_k,
                                                     dim=-1)
-        if renormalize:
+        if layer.renormalize:
             expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
 
         # cond ffn
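For context on the hunk above: the new `VllmCompressedTensorsMoEMethod.get_moe_method` is a small dispatcher that resolves one scheme dict per unfused projection, requires them to agree, and then selects either the unquantized or the fp8 W8A8 MoE method. The sketch below mirrors only that selection logic under simplifying assumptions: `scheme_dicts` stands in for the results of `quant_config.get_scheme_dict(...)`, the string return values stand in for the real method classes, and the `"fp8"` equality check stands in for `_is_fp8_w8a8`; it is not the package's actual API.

```python
# Minimal sketch of the scheme-selection logic added in this diff.
# `scheme_dicts` is a hypothetical stand-in for the per-projection
# results of quant_config.get_scheme_dict(...).

def select_moe_method(scheme_dicts: list[dict | None]) -> str:
    scheme = scheme_dicts[0]
    # All projections (gate/up/down) must share one quantization scheme.
    if any(d != scheme for d in scheme_dicts[1:]):
        raise ValueError("All MoE projections need the same quantization scheme")
    if scheme is None:
        return "VllmUnquantizedFusedMoEMethod"
    weights = scheme.get("weights")
    activations = scheme.get("input_activations")
    if weights == "fp8" and activations == "fp8":  # stand-in for _is_fp8_w8a8
        return "VllmCompressedTensorsW8A8Fp8MoEMethod"
    raise RuntimeError(f"Unsupported FusedMoe scheme: {weights}, {activations}")


# Example: all three projections quantized fp8/fp8 -> fp8 W8A8 MoE method.
print(select_moe_method([{"weights": "fp8", "input_activations": "fp8"}] * 3))
```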
tpu_inference/layers/vllm/quantization/mxfp4.py

@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -24,9 +24,11 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped
 
+from tpu_inference import envs
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                         get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import \
     reorder_concatenated_tensor_for_sharding
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
@@ -85,17 +87,14 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return VllmUnquantizedLinearMethod(linear_config)
-            # TODO: Add support for MXFP4 Linear Method.
-            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
-            # implementation if you are interested in enabling MXFP4 here.
             logger.warning_once(
                 "MXFP4 linear layer is not implemented - falling back to "
                 "UnquantizedLinearMethod.")
             return VllmUnquantizedLinearMethod(linear_config)
         elif isinstance(layer, FusedMoE):
-
+            moe_config = self.get_moe_config(layer)
+            return VllmMxfp4MoEMethod(moe_config, self.mesh)
         elif isinstance(layer, Attention):
-            # TODO: Add support for MXFP4 Attention.
             logger.warning_once("MXFP4 attention layer is not implemented. "
                                 "Skipping quantization for this layer.")
             return None
@@ -103,13 +102,30 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
-    def __init__(self,
+    def __init__(self,
+                 moe: FusedMoEConfig,
+                 mesh: Mesh,
+                 ep_axis_name: str = 'model'):
         FusedMoEMethodBase.__init__(self, moe)
 
         # We piggyback on triton implementation as it applies minimal hardware
         # specific post processing to the weights.
         self.mxfp4_backend = Mxfp4Backend.TRITON
+
         self.mesh = mesh
+        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
+        self.ep_axis_name = ep_axis_name
+        # TODO: Use autotune table once we have it.
+        self.block_size = {
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
+        }
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
@@ -122,6 +138,7 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
+        assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
         w13_weight_scale = e8m0_to_fp32(
@@ -140,6 +157,8 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
 
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+
         # Because we have dequantized weights, scales are not used anymore.
         delattr(layer, "w13_weight_scale")
         delattr(layer, "w2_weight_scale")
@@ -157,110 +176,137 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w3_bias = w13_bias[:, 1::2]
         w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-
-
+        if self.use_kernel:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+
+            # Transpose non-constracting dim to right most dim
+            w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
+            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
+
+            # Apply EP sharding
+            ep_sharding = NamedSharding(self.mesh, P("model"))
+
             w13_weight = jax.device_put(
-
-
-
-
-
-
-
-
-
-
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
+                w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
+                                              ep_sharding))
+            w2_weight = jax.device_put(w2_weight_transposed,
+                                       Format(Layout((0, 1, 2)), ep_sharding))
+
+            w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+            w13_bias = jax.device_put(w13_bias,
+                                      Format(Layout((0, 1, 2)), ep_sharding))
+            w2_bias = jax.device_put(w2_bias,
+                                     Format(Layout((0, 1)), ep_sharding))
 
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if layer.use_ep:
+                ep_sharding = NamedSharding(self.mesh, P("model"))
+                w13_weight = jax.device_put(
+                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
+                w2_weight = jax.device_put(
+                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+
+                w13_bias = jax.device_put(w13_bias,
+                                          Format(Layout((0, 1)), ep_sharding))
+                w2_bias = jax.device_put(w2_bias,
+                                         Format(Layout((0, 1)), ep_sharding))
+
+            else:
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight,
+                    output_sizes,
+                    n_shards,
+                    dim=1,
+                )
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                w13_bias = reorder_concatenated_tensor_for_sharding(
+                    w13_bias,
+                    output_sizes,
+                    n_shards,
+                    dim=1,
+                )
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, "model"))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
-        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-
         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
-
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        x = jax_view(x)
+        w13_weight = jax_view(layer.w13_weight)
+        w2_weight = jax_view(layer.w2_weight)
+        w13_bias = jax_view(layer.w13_bias)
+        w2_bias = jax_view(layer.w2_bias)
+        gating_output = jax_view(router_logits)
+
+        if self.use_kernel:
+            output = fused_ep_moe(
+                mesh=self.mesh,
+                tokens=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                b1=w13_bias,
+                b2=w2_bias,
+                gating_output=gating_output,
+                top_k=layer.top_k,
+                ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
+                **self.block_size,
+            )
+        else:
+            output = fused_moe_func(
+                hidden_states=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_bias=w13_bias,
+                w2_bias=w2_bias,
+                gating_output=gating_output,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
+                mesh=self.mesh,
+                use_ep=layer.use_ep,
+                activation=layer.activation,
+            )
 
         return torch_view(output)
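The rewritten `VllmMxfp4MoEMethod.apply` above now branches between the expert-parallel kernel (`fused_ep_moe`) and the existing `fused_moe_func` fallback, with the choice fixed at construction time from `envs.USE_MOE_EP_KERNEL` and `moe.use_ep`. The sketch below illustrates only that dispatch shape under stated assumptions: the two stub functions are hypothetical stand-ins for the real kernels, and `apply_moe` is not an actual tpu-inference API.

```python
# Minimal sketch of the kernel-vs-fallback dispatch added to
# VllmMxfp4MoEMethod.apply. The stubs stand in for the real
# fused_ep_moe / fused_moe_func calls shown in the diff.
from typing import Any, Callable


def fused_ep_moe_stub(**kwargs: Any) -> str:
    return "fused_ep_moe"  # expert-parallel kernel path


def fused_moe_func_stub(**kwargs: Any) -> str:
    return "fused_moe_func"  # generic fallback path


def apply_moe(use_moe_ep_kernel: bool,
              use_ep: bool,
              ep_kernel: Callable[..., str] = fused_ep_moe_stub,
              fallback: Callable[..., str] = fused_moe_func_stub,
              **moe_args: Any) -> str:
    # Same condition as `self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep`.
    use_kernel = use_moe_ep_kernel and use_ep
    return ep_kernel(**moe_args) if use_kernel else fallback(**moe_args)


# The EP kernel is only selected when both the env flag and EP are enabled.
assert apply_moe(use_moe_ep_kernel=True, use_ep=True) == "fused_ep_moe"
assert apply_moe(use_moe_ep_kernel=True, use_ep=False) == "fused_moe_func"
```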