tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

This version of tpu-inference has been flagged as potentially problematic.
Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1,9 +1,10 @@
-from typing import Callable, Optional, Union
+from typing import Union
 
 import jax
 import jax.numpy as jnp
 import torch
 import torch.nn.functional as F
+from compressed_tensors.quantization import QuantizationArgs
 from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
@@ -12,52 +13,89 @@ from torchax.interop import call_jax, torch_view
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
-    CompressedTensorsConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import \
-    CompressedTensorsW8A8Fp8MoEMethod
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
-    WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+    CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
 
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedFusedMoEMethod
 
 logger = init_logger(__name__)
 
 
+class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+        layer: torch.nn.Module,
+        layer_name: str,
+    ) -> CompressedTensorsMoEMethod:
+
+        assert isinstance(layer, FusedMoE)
+
+        # FusedMoE was made by combining multiple Linears so need to
+        # make sure quantization config for Linear can target it
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            layer_name + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # multiple schemes found
+        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+            raise ValueError("All MoE projections need to have same "
+                             "quantization scheme but found multiple")
+
+        if scheme_dict is None:
+            return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                 quant_config.mesh)
+
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
+
+        if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
 class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
                                             JaxCommonConfig):
 
-    def __init__(self, quant_config: "CompressedTensorsConfig",
-                 moe: FusedMoEConfig, mesh: Mesh):
-        super().__init__(quant_config, moe)
+    def __init__(self, weight_quant: QuantizationArgs,
+                 input_quant: QuantizationArgs, moe: FusedMoEConfig,
+                 mesh: Mesh):
+        super().__init__(weight_quant, input_quant, moe)
         self.mesh = mesh
-        self.quant_config = quant_config
-
-        # disable GPU paths
-        self.use_marlin = False
-        self.rocm_aiter_moe_enabled = False  # is_rocm_aiter_moe_enabled()
-        self.is_fp8_w8a8_sm100 = False
-        self.use_cutlass = False
-        self.disable_expert_map = False
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
 
-        intermediate_size = layer.w13_weight.shape[1] // 2
-        w1_weight = layer.w13_weight[:, :intermediate_size]
-        w3_weight = layer.w13_weight[:, intermediate_size:]
-        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
-        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
-
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-        w2_weight_scale = t2j(layer.w2_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w1_weight = t2j(w1_weight, use_dlpack=False)
-        w1_weight_scale = t2j(w1_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-        w3_weight = t2j(w3_weight, use_dlpack=False)
-        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+
+        w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
+        w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
+
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+        assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
+        assert w13_weight.shape == (num_experts, 2 * intermediate_size,
+                                    hidden_size)
+        assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
+                                          1)
+
+        w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
+        w1_weight_scale, w3_weight_scale = jnp.split(w13_weight_scale, 2, 1)
 
         if layer.use_ep:
             format = Format(Layout((0, 1, 2)),
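For orientation, here is a minimal, standalone sketch of the fused-weight split introduced in the hunk above. The dimensions are illustrative placeholders, not values from a real checkpoint: w13 packs the gate and up projections along dimension 1, and jnp.split recovers both in one call.

import jax.numpy as jnp

# Illustrative sizes only; real models are much larger.
num_experts, intermediate_size, hidden_size = 4, 128, 256

# w13 stacks gate_proj (w1) and up_proj (w3) along the output dimension.
w13_weight = jnp.zeros((num_experts, 2 * intermediate_size, hidden_size))
w13_weight_scale = jnp.ones((num_experts, 2 * intermediate_size, 1),
                            dtype=jnp.bfloat16)

# Split along axis 1, as the new process_weights_after_loading does.
w1_weight, w3_weight = jnp.split(w13_weight, 2, 1)
w1_scale, w3_scale = jnp.split(w13_weight_scale, 2, 1)

assert w1_weight.shape == (num_experts, intermediate_size, hidden_size)
assert w3_weight.shape == (num_experts, intermediate_size, hidden_size)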
@@ -69,16 +107,9 @@ class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
             w2_weight = jax.device_put(w2_weight, format)
             w2_weight_scale = jax.device_put(w2_weight_scale, format)
         else:
-            assert intermediate_size == w2_weight.shape[-1]
             n_shards = self.mesh.shape["model"]
             assert intermediate_size % n_shards == 0
 
-            # TODO: enable this if using fused weights
-            # output_sizes = [intermediate_size, intermediate_size]
-            # w13_weight = reorder_concatenated_tensor_for_sharding(
-            #     w13_weight, output_sizes, n_shards, dim=1
-            # )
-
             w13_format = Format(
                 Layout((0, 1, 2)),
                 NamedSharding(self.mesh, P(None, "model", None)))
@@ -119,45 +150,23 @@ class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if activation != "silu":
+        if layer.activation != "silu":
             raise NotImplementedError(
                 "Only silu is supported for activation function.")
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        # import sys
-        # sys.stdin = open(0)
-        # breakpoint()
-
         # TODO: Use MoE kernel when it supports fp8
-
         seqlen = x.shape[0]
 
         expert_weights = F.softmax(router_logits, dim=-1)
         expert_weights, expert_indices = torch.topk(expert_weights,
-                                                    top_k,
+                                                    layer.top_k,
                                                     dim=-1)
-        if renormalize:
+        if layer.renormalize:
             expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
 
         # cond ffn
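The apply() hunk above also shows the routing math now reading top_k, renormalize, scoring_func and activation from the FusedMoE layer itself rather than from call arguments. A small self-contained sketch of that routing step, with illustrative sizes:

import torch
import torch.nn.functional as F

# Illustrative sizes; in the layer these come from the FusedMoE config.
num_tokens, num_experts, top_k, renormalize = 8, 16, 2, True

router_logits = torch.randn(num_tokens, num_experts)

# Softmax scoring followed by top-k expert selection.
expert_weights = F.softmax(router_logits, dim=-1)
expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
if renormalize:
    expert_weights /= expert_weights.sum(dim=-1, keepdim=True)

assert expert_weights.shape == (num_tokens, top_k)
assert expert_indices.shape == (num_tokens, top_k)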
tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -24,9 +24,11 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped
 
+from tpu_inference import envs
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import \
     reorder_concatenated_tensor_for_sharding
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
@@ -85,17 +87,14 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return VllmUnquantizedLinearMethod(linear_config)
-            # TODO: Add support for MXFP4 Linear Method.
-            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
-            # implementation if you are interested in enabling MXFP4 here.
             logger.warning_once(
                 "MXFP4 linear layer is not implemented - falling back to "
                 "UnquantizedLinearMethod.")
             return VllmUnquantizedLinearMethod(linear_config)
         elif isinstance(layer, FusedMoE):
-            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+            moe_config = self.get_moe_config(layer)
+            return VllmMxfp4MoEMethod(moe_config, self.mesh)
         elif isinstance(layer, Attention):
-            # TODO: Add support for MXFP4 Attention.
             logger.warning_once("MXFP4 attention layer is not implemented. "
                                 "Skipping quantization for this layer.")
             return None
@@ -103,13 +102,30 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
-    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+    def __init__(self,
+                 moe: FusedMoEConfig,
+                 mesh: Mesh,
+                 ep_axis_name: str = 'model'):
         FusedMoEMethodBase.__init__(self, moe)
 
         # We piggyback on triton implementation as it applies minimal hardware
         # specific post processing to the weights.
         self.mxfp4_backend = Mxfp4Backend.TRITON
+
         self.mesh = mesh
+        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
+        self.ep_axis_name = ep_axis_name
+        # TODO: Use autotune table once we have it.
+        self.block_size = {
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
+        }
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
@@ -122,6 +138,7 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
+        assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
         w13_weight_scale = e8m0_to_fp32(
@@ -140,6 +157,8 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
 
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+
         # Because we have dequantized weights, scales are not used anymore.
         delattr(layer, "w13_weight_scale")
         delattr(layer, "w2_weight_scale")
@@ -157,110 +176,137 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w3_bias = w13_bias[:, 1::2]
         w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
-        if layer.use_ep:
+        if self.use_kernel:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+
+            # Transpose non-constracting dim to right most dim
+            w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
+            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
+
+            # Apply EP sharding
+            ep_sharding = NamedSharding(self.mesh, P("model"))
+
             w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
-
-            w13_bias = jax.device_put(
-                w13_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
+                w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
+                                              ep_sharding))
+            w2_weight = jax.device_put(w2_weight_transposed,
+                                       Format(Layout((0, 1, 2)), ep_sharding))
+
+            w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+            w13_bias = jax.device_put(w13_bias,
+                                      Format(Layout((0, 1, 2)), ep_sharding))
+            w2_bias = jax.device_put(w2_bias,
+                                     Format(Layout((0, 1)), ep_sharding))
 
         else:
-            intermediate_size = w13_weight.shape[1] // 2
-            assert intermediate_size == w2_weight.shape[-1]
-            output_sizes = [intermediate_size, intermediate_size]
-            n_shards = self.mesh.shape["model"]
-            assert intermediate_size % n_shards == 0
-            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
                                                                   output_sizes,
-                                                                  n_shards,
-                                                                  dim=1)
-            w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, "model", None))))
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))))
-
-            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
-                                                                output_sizes,
-                                                                n_shards,
-                                                                dim=1)
-            w13_bias = jax.device_put(
-                w13_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P(None, "model"))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
-                                                                  None))))
+            if layer.use_ep:
+                ep_sharding = NamedSharding(self.mesh, P("model"))
+                w13_weight = jax.device_put(
+                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
+                w2_weight = jax.device_put(
+                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+
+                w13_bias = jax.device_put(w13_bias,
+                                          Format(Layout((0, 1)), ep_sharding))
+                w2_bias = jax.device_put(w2_bias,
+                                         Format(Layout((0, 1)), ep_sharding))
+
+            else:
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight,
+                    output_sizes,
+                    n_shards,
+                    dim=1,
+                )
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                w13_bias = reorder_concatenated_tensor_for_sharding(
+                    w13_bias,
+                    output_sizes,
+                    n_shards,
+                    dim=1,
+                )
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, "model"))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
-        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-
         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
-        pass
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        # Use the original implementation
-        output = fused_moe_func_padded(
-            jax_view(x),
-            jax_view(layer.w13_weight),
-            jax_view(layer.w2_weight),
-            jax_view(layer.w13_bias) if self.moe.has_bias else None,
-            jax_view(layer.w2_bias) if self.moe.has_bias else None,
-            jax_view(router_logits),
-            topk=top_k,
-            global_num_experts=global_num_experts,
-            renormalize=renormalize,
-            reduce_results=layer.reduce_results,
-            mesh=self.mesh,
-            use_ep=layer.use_ep,
-            activation=activation,
-        )
+        x = jax_view(x)
+        w13_weight = jax_view(layer.w13_weight)
+        w2_weight = jax_view(layer.w2_weight)
+        w13_bias = jax_view(layer.w13_bias)
+        w2_bias = jax_view(layer.w2_bias)
+        gating_output = jax_view(router_logits)
+
+        if self.use_kernel:
+            output = fused_ep_moe(
+                mesh=self.mesh,
+                tokens=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                b1=w13_bias,
+                b2=w2_bias,
+                gating_output=gating_output,
+                top_k=layer.top_k,
+                ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
+                **self.block_size,
+            )
+        else:
+            output = fused_moe_func(
+                hidden_states=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_bias=w13_bias,
+                w2_bias=w2_bias,
+                gating_output=gating_output,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
+                mesh=self.mesh,
+                use_ep=layer.use_ep,
+                activation=layer.activation,
+            )
 
         return torch_view(output)
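To make the EP-kernel weight layout in process_weights_after_loading easier to follow, here is a standalone sketch of just the reshape and transpose, with illustrative sizes: w13 goes from (E, 2*I, H) to (E, 2, H, I) and w2 from (E, H, I) to (E, I, H), so the non-contracting dimension ends up last, matching the layout described in the kernel-expects comment above.

import jax.numpy as jnp

# Illustrative sizes only.
num_experts, intermediate_size, hidden_size = 4, 128, 256

w13_weight = jnp.zeros((num_experts, 2 * intermediate_size, hidden_size))
w2_weight = jnp.zeros((num_experts, hidden_size, intermediate_size))

# Separate the fused gate/up pair, then move the output dim to the end.
w13_reshaped = w13_weight.reshape(num_experts, 2, intermediate_size,
                                  hidden_size)
w13_for_kernel = jnp.swapaxes(w13_reshaped, 2, 3)
w2_for_kernel = jnp.swapaxes(w2_weight, 1, 2)

assert w13_for_kernel.shape == (num_experts, 2, hidden_size,
                                intermediate_size)
assert w2_for_kernel.shape == (num_experts, intermediate_size, hidden_size)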