sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py  (+24 -16)

@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.utils import is_hip, set_weight_attrs
 
 logger = logging.getLogger(__name__)
@@ -114,6 +114,8 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
     ):
         super().__init__()
 
@@ -140,6 +142,8 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.custom_routing_function = custom_routing_function
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +170,7 @@ class EPMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"
 
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
@@ -181,6 +186,7 @@ class EPMoE(torch.nn.Module):
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
        )
 
        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -254,16 +260,20 @@ class EPMoE(torch.nn.Module):
             dtype=torch.float32,
             device=hidden_states.device,
         )
-        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-            gateup_output,
-            down_input,
-            gateup_output.shape[1],
-            reorder_topk_ids,
-            self.w2_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            BLOCK_SIZE=512,
-        )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
 
         # GroupGemm-1
         down_output = torch.empty(
@@ -309,7 +319,6 @@ class EPMoE(torch.nn.Module):
         ckpt_up_proj_name: str,
         num_experts: int,
     ) -> List[Tuple[str, str, int, str]]:
-
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
@@ -354,7 +363,6 @@ class EPMoE(torch.nn.Module):
             )
             return
 
-        expert_data = param.data[expert_id]
         if shard_id == "w2":
             param.data[expert_id] = loaded_weight
         elif shard_id == "w1":
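
The EPMoE changes above thread a new `custom_routing_function` argument down to expert selection and gate the Triton SiLU-and-mul kernel on the new `activation` argument (only "silu" is accepted on this path). A minimal sketch of what such a routing callable could look like, assuming the vllm-style `(hidden_states, gating_output, topk, renormalize)` signature; the exact signature expected by `select_experts` is not shown in this diff:

```python
import torch


def plain_softmax_routing(
    hidden_states: torch.Tensor,  # [num_tokens, hidden_size], unused by this router
    gating_output: torch.Tensor,  # [num_tokens, num_experts] router logits
    topk: int,
    renormalize: bool,
):
    # Softmax over experts, then keep the top-k weights per token.
    probs = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Toy check: 4 tokens, 8 experts, top-2 routing.
weights, ids = plain_softmax_routing(
    torch.randn(4, 16), torch.randn(4, 8), topk=2, renormalize=True
)
```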
sglang/srt/layers/moe/fused_moe_native.py  (+84 -1)

@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -22,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -40,7 +42,88 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+def moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
+) -> torch.Tensor:
+
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
+    len_experts = layer.num_experts
+
+    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = topk_ids.view(-1).argsort()
+
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+    tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+        layer_w13_weight = layer.w13_weight[i]
+        layer_w2_weight = layer.w2_weight[i]
+
+        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+        gate_up = act(gate_up)
+        expert_out = F.linear(gate_up, layer_w2_weight)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .type(topk_weights.dtype)
+        .mul_(topk_weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return final_out
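
The new `moe_forward_native` is a pure-PyTorch MoE path (used further below as the CPU fallback): it sorts tokens by expert, runs each expert's fused gate/up and down projections with `F.linear`, then scatters the results back and mixes them with the routing weights. A self-contained sketch of that grouping pattern with toy weights; the plain softmax top-k here stands in for `select_experts`, and all shapes are illustrative:

```python
import torch
import torch.nn.functional as F

T, H, I, E, K = 6, 16, 32, 4, 2          # tokens, hidden, intermediate, experts, top-k
x = torch.randn(T, H)
w13 = torch.randn(E, 2 * I, H)           # fused gate/up projection per expert
w2 = torch.randn(E, H, I)                # down projection per expert

# Plain softmax top-k in place of select_experts.
probs = torch.softmax(torch.randn(T, E), dim=-1)
topk_weights, topk_ids = torch.topk(probs, K, dim=-1)

# Group tokens by expert, as the per-expert loop above does.
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // K]
tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=E).tolist()

outs = torch.empty(T * K, H)
start = 0
for e, n in enumerate(tokens_per_expert):
    if n == 0:
        continue
    chunk = sorted_tokens[start : start + n]
    gate, up = F.linear(chunk, w13[e]).chunk(2, dim=-1)   # SiluAndMul reference
    outs[idxs[start : start + n]] = F.linear(F.silu(gate) * up, w2[e])
    start += n

# Weighted sum over each token's top-k expert outputs.
y = (outs.view(T, K, H) * topk_weights.unsqueeze(-1)).sum(dim=1)
print(y.shape)  # torch.Size([6, 16])
```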
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json  (new file, +164 -0)

@@ -0,0 +1,164 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}
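
This new JSON file is an AMD MI300X tuning table for the fused MoE Triton kernel: each top-level key is a token-batch size and maps to the Triton launch parameters to use at that size. A sketch of how such a table can be consulted, assuming the common nearest-key lookup rule (the actual selection logic lives in `fused_moe.py`):

```python
import json


def load_moe_config(path: str, num_tokens: int) -> dict:
    """Pick the tuning entry whose batch-size key is closest to num_tokens."""
    with open(path) as f:
        table = {int(k): v for k, v in json.load(f).items()}
    best_m = min(table, key=lambda m: abs(m - num_tokens))
    return table[best_m]


# cfg = load_moe_config(
#     "E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json",
#     num_tokens=50,
# )
# -> the "48" entry: {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, ...}
```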
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  (+27 -7)

@@ -15,15 +15,18 @@ from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)
 
-is_hip_flag = False
-if not is_hip():
+is_cuda = is_cuda_available()
+is_hip_flag = is_hip()
+if is_cuda:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
-    is_hip_flag = False
-else:
-    is_hip_flag = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -708,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -723,6 +727,7 @@ def inplace_fused_experts(
        topk_weights,
        topk_ids,
        True,
+        activation,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
@@ -739,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -764,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -779,6 +786,7 @@ def outplace_fused_experts(
        topk_weights,
        topk_ids,
        False,
+        activation,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
@@ -795,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -821,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -836,6 +846,7 @@ def fused_experts(
        w2,
        topk_weights,
        topk_ids,
+        activation,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
@@ -852,6 +863,7 @@ def fused_experts(
        w2,
        topk_weights,
        topk_ids,
+        activation,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
@@ -869,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -983,7 +996,12 @@ def fused_experts_impl(
            block_shape=block_shape,
        )
 
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")
 
        invoke_fused_moe_kernel(
            intermediate_cache2,
@@ -1039,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1108,6 +1127,7 @@ def fused_moe(
        topk_weights,
        topk_ids,
        inplace=inplace,
+        activation=activation,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        w1_scale=w1_scale,
sglang/srt/layers/moe/fused_moe_triton/layer.py  (+38 -5)

@@ -5,14 +5,15 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.custom_op import CustomOp
-
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -125,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -137,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )
 
     def forward_cuda(
@@ -151,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -168,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            import ater
            from ater.fused_moe import fused_experts_ck
 
+            assert activation == "silu", f"{activation=} is not supported."
+
            return fused_experts_ck(
                hidden_states=x,
                w1=layer.w13_weight,
@@ -183,10 +189,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
+                activation=activation,
            )
 
-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError("The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            custom_routing_function,
+            correction_bias,
+        )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
@@ -232,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -255,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -565,6 +597,7 @@ class FusedMoE(torch.nn.Module):
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            correction_bias=self.correction_bias,
+            activation=self.activation,
        )
 
        if self.reduce_results and self.tp_size > 1:
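
With this change `forward_cpu` stops raising `NotImplementedError` and delegates to `moe_forward_native`, so unquantized MoE layers can run on CPU. A hypothetical call, assuming sglang 0.4.2 is installed; `layer` here is a stand-in namespace carrying only the attributes `moe_forward_native` reads, with weight shapes following the `[num_experts, 2 * intermediate, hidden]` / `[num_experts, hidden, intermediate]` layout used above:

```python
from types import SimpleNamespace

import torch

# from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

E, H, I = 4, 16, 32
layer = SimpleNamespace(
    num_experts=E,
    w13_weight=torch.randn(E, 2 * I, H),  # fused gate/up projection per expert
    w2_weight=torch.randn(E, H, I),       # down projection per expert
)
x = torch.randn(6, H)
router_logits = torch.randn(6, E)

# out = moe_forward_native(
#     layer, x, use_grouped_topk=False, top_k=2,
#     router_logits=router_logits, renormalize=True,
# )
# out.shape -> torch.Size([6, 16])
```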
sglang/srt/layers/parameter.py  (+18 -8)

@@ -6,7 +6,8 @@ from typing import Callable, Optional, Union
 
 import torch
 from torch.nn import Parameter
-from vllm.distributed import get_tensor_model_parallel_rank
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 
 __all__ = [
     "BasevLLMParameter",
@@ -123,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
+    def load_qkv_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+        **kwargs,
+    ):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -141,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, shard_id * shard_size, shard_size
-        )
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, shard_id * shard_size, shard_size
+            )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
 
@@ -291,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
@@ -335,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
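
The `parameter.py` change adds a `use_presharded_weights` switch to `load_qkv_weight`: with a full checkpoint tensor each tensor-parallel rank narrows out its own shard, while presharded checkpoints are copied as-is. A toy illustration of that branch; the helper name and the 1-D sharding here are made up for the example and are not the sglang API:

```python
import torch


def load_shard(
    param_shard: torch.Tensor,    # destination, already sized [shard_size, hidden]
    loaded_weight: torch.Tensor,  # checkpoint tensor (full or presharded)
    tp_rank: int,
    shard_size: int,
    use_presharded_weights: bool = False,
) -> None:
    if not use_presharded_weights:
        # Full checkpoint: cut out this rank's slice along the output dimension.
        loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
    assert (
        param_shard.shape == loaded_weight.shape
    ), f"{param_shard.shape=}, {loaded_weight.shape=}"
    param_shard.copy_(loaded_weight)


hidden, shard, tp = 8, 4, 2
full = torch.randn(shard * tp, hidden)            # unsharded checkpoint
dst = torch.empty(shard, hidden)
load_shard(dst, full, tp_rank=1, shard_size=shard)                    # narrows
load_shard(dst, full.narrow(0, shard, shard), tp_rank=1,
           shard_size=shard, use_presharded_weights=True)             # copies as-is
```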