sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/dp_attention.py +3 -1
  12. sglang/srt/layers/layernorm.py +5 -5
  13. sglang/srt/layers/linear.py +24 -9
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  16. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  20. sglang/srt/layers/parameter.py +16 -7
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  22. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  23. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/fp8.py +4 -1
  31. sglang/srt/layers/rotary_embedding.py +6 -1
  32. sglang/srt/layers/sampler.py +28 -8
  33. sglang/srt/layers/torchao_utils.py +12 -6
  34. sglang/srt/managers/detokenizer_manager.py +1 -0
  35. sglang/srt/managers/io_struct.py +36 -5
  36. sglang/srt/managers/schedule_batch.py +31 -25
  37. sglang/srt/managers/scheduler.py +61 -35
  38. sglang/srt/managers/tokenizer_manager.py +4 -0
  39. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  40. sglang/srt/model_executor/forward_batch_info.py +5 -7
  41. sglang/srt/model_executor/model_runner.py +7 -4
  42. sglang/srt/model_loader/loader.py +75 -0
  43. sglang/srt/model_loader/weight_utils.py +91 -5
  44. sglang/srt/models/commandr.py +14 -2
  45. sglang/srt/models/dbrx.py +9 -1
  46. sglang/srt/models/deepseek_v2.py +3 -3
  47. sglang/srt/models/gemma2.py +9 -1
  48. sglang/srt/models/grok.py +1 -0
  49. sglang/srt/models/minicpm3.py +3 -3
  50. sglang/srt/models/torch_native_llama.py +17 -4
  51. sglang/srt/openai_api/adapter.py +139 -37
  52. sglang/srt/openai_api/protocol.py +5 -4
  53. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  54. sglang/srt/sampling/sampling_batch_info.py +4 -14
  55. sglang/srt/server.py +2 -2
  56. sglang/srt/server_args.py +20 -1
  57. sglang/srt/speculative/eagle_utils.py +37 -15
  58. sglang/srt/speculative/eagle_worker.py +11 -13
  59. sglang/srt/utils.py +62 -65
  60. sglang/test/test_programs.py +1 -0
  61. sglang/test/test_utils.py +81 -22
  62. sglang/version.py +1 -1
  63. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
  64. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
  65. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  66. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  67. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -114,6 +114,8 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
     ):
         super().__init__()

@@ -140,6 +142,8 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.custom_routing_function = custom_routing_function
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +170,7 @@ class EPMoE(torch.nn.Module):

     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"

         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
@@ -181,6 +186,7 @@ class EPMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
         )

         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -254,16 +260,20 @@ class EPMoE(torch.nn.Module):
             dtype=torch.float32,
             device=hidden_states.device,
         )
-        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-            gateup_output,
-            down_input,
-            gateup_output.shape[1],
-            reorder_topk_ids,
-            self.w2_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            BLOCK_SIZE=512,
-        )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")

         # GroupGemm-1
         down_output = torch.empty(
@@ -309,7 +319,6 @@ class EPMoE(torch.nn.Module):
         ckpt_up_proj_name: str,
         num_experts: int,
     ) -> List[Tuple[str, str, int, str]]:
-
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
@@ -354,7 +363,6 @@ class EPMoE(torch.nn.Module):
             )
             return

-        expert_data = param.data[expert_id]
         if shard_id == "w2":
             param.data[expert_id] = loaded_weight
         elif shard_id == "w1":
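The EPMoE hunks above add two constructor arguments, custom_routing_function and activation, store them on the layer, and forward the routing hook into expert selection (the expert-parallel forward path still asserts activation == "silu"). Below is a minimal, hedged sketch of what such a routing hook could look like; the (hidden_states, gating_output, topk, renormalize) -> (topk_weights, topk_ids) signature is assumed from the select_experts-style convention used elsewhere in the codebase, not stated in this diff.

# Hedged sketch of a custom routing hook of the kind EPMoE can now accept.
# The callback signature is an assumption for illustration.
import torch


def softmax_topk_router(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Plain softmax routing followed by a top-k selection per token.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        # Renormalize the selected expert weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids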
sglang/srt/layers/moe/fused_moe_native.py

@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts


@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -41,7 +42,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:

     topk_weights, topk_ids = select_experts(
@@ -84,6 +91,13 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()

+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@ def moe_forward_native(
         layer_w2_weight = layer.w2_weight[i]

         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx
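Both native paths now take an activation string and dispatch between a SiLU and a GELU "act-and-mul". The sketch below reproduces that pattern in plain PyTorch, under the usual convention that the fused gate/up projection stores the gate half first; any difference between exact and approximated GELU inside sglang's GeluAndMul is not assumed here.

# Plain-PyTorch sketch of the activation dispatch used by the native MoE paths above.
import torch
import torch.nn.functional as F


def act_and_mul(gate_up: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    # Split the fused gate/up projection, activate the gate half, multiply by the up half.
    gate, up = gate_up.chunk(2, dim=-1)
    if activation == "silu":
        return F.silu(gate) * up
    elif activation == "gelu":
        return F.gelu(gate) * up
    else:
        raise ValueError(f"Unsupported activation: {activation=}")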
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json

@@ -0,0 +1,164 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}
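The new JSON file above is a per-batch-size Triton tuning table for the fused MoE kernel on MI300X: each top-level key is a token count M, and each value is the tile and launch configuration to use near that size. Below is a hedged sketch of the nearest-key lookup such tables are typically consumed with; the nearest-key rule and the standalone helper are assumptions for illustration, and the real lookup lives in fused_moe.py's config helpers.

# Hedged sketch: pick the tuning entry whose key is closest to the current token count M.
import json


def pick_moe_config(table: dict, M: int) -> dict:
    # Keys are strings of batch sizes; choose the numerically nearest one.
    return table[min(table.keys(), key=lambda k: abs(int(k) - M))]


with open(
    "E=256,N=256,device_name=AMD_Instinct_MI300X,"
    "dtype=fp8_w8a8,block_shape=[128, 128].json"
) as f:
    table = json.load(f)

cfg = pick_moe_config(table, M=44)  # 44 tokens -> nearest key is "48"
print(cfg["BLOCK_SIZE_M"], cfg["BLOCK_SIZE_N"], cfg["num_warps"])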
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -711,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -726,6 +727,7 @@ def inplace_fused_experts(
         topk_weights,
         topk_ids,
         True,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -767,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -782,6 +786,7 @@ def outplace_fused_experts(
         topk_weights,
         topk_ids,
         False,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -824,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -839,6 +846,7 @@ def fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -855,6 +863,7 @@ def fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -872,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -986,7 +996,12 @@ def fused_experts_impl(
            block_shape=block_shape,
        )

-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")

        invoke_fused_moe_kernel(
            intermediate_cache2,
@@ -1042,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1111,6 +1127,7 @@ def fused_moe(
        topk_weights,
        topk_ids,
        inplace=inplace,
+        activation=activation,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        w1_scale=w1_scale,
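The activation keyword is threaded from fused_moe through fused_experts down to fused_experts_impl, where "gelu" now dispatches to ops.gelu_and_mul instead of ops.silu_and_mul. Below is a hedged usage sketch; the leading arguments (hidden_states, w1, w2, gating_output) and the weight layouts are assumptions based on the surrounding code, and only topk, renormalize, inplace, and activation appear in this diff.

# Hedged usage sketch of calling fused_moe with the new `activation` keyword.
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, H, I = 8, 1024, 2048                                              # experts, hidden, intermediate (illustrative)
x = torch.randn(16, H, dtype=torch.float16, device="cuda")           # 16 tokens
w13 = torch.randn(E, 2 * I, H, dtype=torch.float16, device="cuda")   # assumed fused gate+up layout
w2 = torch.randn(E, H, I, dtype=torch.float16, device="cuda")        # assumed down-projection layout
gating = torch.randn(16, E, dtype=torch.float32, device="cuda")      # router logits

out = fused_moe(
    x,
    w13,
    w2,
    gating,
    topk=2,
    renormalize=True,
    inplace=False,
    activation="gelu",  # new in 0.4.2: dispatches to gelu_and_mul instead of silu_and_mul
)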
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )

     def forward_cuda(
@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            import ater
            from ater.fused_moe import fused_experts_ck

+            assert activation == "silu", f"{activation=} is not supported."
+
            return fused_experts_ck(
                hidden_states=x,
                w1=layer.w13_weight,
@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
+                activation=activation,
            )

     def forward_cpu(
@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            correction_bias=self.correction_bias,
+            activation=self.activation,
        )

        if self.reduce_results and self.tp_size > 1:
sglang/srt/layers/parameter.py

@@ -124,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

-    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
+    def load_qkv_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+        **kwargs,
+    ):

         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -142,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, shard_id * shard_size, shard_size
-        )
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, shard_id * shard_size, shard_size
+            )

-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)

@@ -292,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
@@ -336,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
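The load_qkv_weight change above lets the loader skip the tensor-parallel narrow() along output_dim when the checkpoint already stores per-rank shards (use_presharded_weights=True). A minimal sketch of that control flow, with illustrative names rather than the real class hierarchy:

# Sketch of the presharded-weight path added to load_qkv_weight above.
import torch


def load_shard(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    output_dim: int,
    shard_id: int,
    shard_size: int,
    use_presharded_weights: bool = False,
) -> None:
    if not use_presharded_weights:
        # Full checkpoint tensor: cut out this rank's slice before copying.
        loaded_weight = loaded_weight.narrow(
            output_dim, shard_id * shard_size, shard_size
        )
    # Presharded checkpoints already match the local parameter shape.
    assert param_data.shape == loaded_weight.shape, (
        f"{param_data.shape=}, {loaded_weight.shape=}"
    )
    param_data.copy_(loaded_weight)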
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json

@@ -0,0 +1,164 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}