sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  12. sglang/srt/layers/attention/vision.py +243 -40
  13. sglang/srt/layers/dp_attention.py +3 -1
  14. sglang/srt/layers/layernorm.py +5 -5
  15. sglang/srt/layers/linear.py +24 -9
  16. sglang/srt/layers/logits_processor.py +1 -1
  17. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  18. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  22. sglang/srt/layers/parameter.py +16 -7
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/fp8.py +11 -1
  33. sglang/srt/layers/rotary_embedding.py +34 -13
  34. sglang/srt/layers/sampler.py +33 -10
  35. sglang/srt/layers/torchao_utils.py +12 -6
  36. sglang/srt/managers/detokenizer_manager.py +1 -0
  37. sglang/srt/managers/image_processor.py +77 -38
  38. sglang/srt/managers/io_struct.py +36 -5
  39. sglang/srt/managers/schedule_batch.py +31 -25
  40. sglang/srt/managers/scheduler.py +78 -38
  41. sglang/srt/managers/tokenizer_manager.py +4 -0
  42. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  43. sglang/srt/mem_cache/chunk_cache.py +3 -0
  44. sglang/srt/mem_cache/radix_cache.py +30 -1
  45. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  46. sglang/srt/model_executor/forward_batch_info.py +5 -7
  47. sglang/srt/model_executor/model_runner.py +7 -4
  48. sglang/srt/model_loader/loader.py +75 -0
  49. sglang/srt/model_loader/weight_utils.py +91 -5
  50. sglang/srt/models/commandr.py +14 -2
  51. sglang/srt/models/dbrx.py +9 -1
  52. sglang/srt/models/deepseek_v2.py +3 -3
  53. sglang/srt/models/gemma2.py +9 -1
  54. sglang/srt/models/grok.py +1 -0
  55. sglang/srt/models/minicpm3.py +3 -3
  56. sglang/srt/models/minicpmv.py +129 -76
  57. sglang/srt/models/mllama.py +16 -56
  58. sglang/srt/models/qwen2.py +4 -1
  59. sglang/srt/models/qwen2_vl.py +18 -8
  60. sglang/srt/models/torch_native_llama.py +17 -4
  61. sglang/srt/openai_api/adapter.py +139 -37
  62. sglang/srt/openai_api/protocol.py +5 -4
  63. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  64. sglang/srt/sampling/sampling_batch_info.py +4 -14
  65. sglang/srt/server.py +2 -2
  66. sglang/srt/server_args.py +26 -1
  67. sglang/srt/speculative/eagle_utils.py +37 -15
  68. sglang/srt/speculative/eagle_worker.py +11 -13
  69. sglang/srt/utils.py +62 -67
  70. sglang/test/test_programs.py +1 -0
  71. sglang/test/test_utils.py +81 -22
  72. sglang/utils.py +42 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
  75. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
  76. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, repeat
 
 from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
 
 
 class VisionAttention(nn.Module):
-    """Multi-headed attention without any cache, mostly used for ViT."""
+    r"""
+    Multi-headed attention without any cache, mostly used for ViT.
+
+
+    Args:
+        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+        use_context_forward (bool, default to True):
+            if ``True``, a flash_attn style attention will be applied
+            Otherwise, a full-sequence attention will be applied.
+        use_full_precision_softmax (bool, default to False):
+            if ``True``, the softmax will be performed in full-precision
+            Otherwise, it will be performed in half-precision
+
+    """
 
     def __init__(
         self,
@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
         projection_size: int,
         use_qkv_parallel: bool,
         quant_config: Optional[QuantizationConfig] = None,
+        dropout: float = 0.0,
+        use_context_forward: bool = True,
+        use_full_precision_softmax: bool = False,
+        flatten_batch: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_context_forward = use_context_forward
         world_size = parallel_state.get_tensor_model_parallel_world_size()
-
+        self.dropout = dropout
+        self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, world_size
         )
-        # self.tp_size = get_tensor_model_parallel_world_size()
-        # num_heads = self.num_heads_per_partition
+
+        if self.use_context_forward:
+            self.qkv_backend = VisionTritonAttention()
+        else:
+            self.qkv_backend = VisionSdpaAttention(
+                head_size=self.head_size,
+                dropout=dropout,
+                flatten_batch=flatten_batch,
+                use_full_precision_softmax=use_full_precision_softmax,
+            )
+
         self.use_qkv_parallel = use_qkv_parallel
         if use_qkv_parallel:
-            self.head_dim = embed_dim // num_heads
             self.qkv_proj = QKVParallelLinear(
                 hidden_size=embed_dim,
-                head_size=self.head_dim,
+                head_size=self.head_size,
                 total_num_heads=num_heads,
                 quant_config=quant_config,
                 prefix=f"{prefix}.qkv_proj",
@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
         rotary_pos_emb: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        r"""
+        Args:
+            x: [b, s, embed_dim]
+            cu_seqlens: [b]
+        Returns:
+            [s, b, num_heads * head]
         """
-        Input shape: [b, s, embed_dim]
-        Output shape: [s, b, num_heads * head_size]
-        """
-
         bsz, s, _ = x.shape
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
                 self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head,
            )
             qkv = qkv.view(*new_x_shape)
 
-            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
             q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
 
-            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
@@ -160,45 +191,217 @@
         if self.use_qkv_parallel:
             pass
         else:
-            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            # [b, s, head, head_size] --> [b * s, head, head_size]
             q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
 
-        # [b * s, num_heads, head_size]
-        output = torch.empty_like(q)
-
-        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
-        max_seqlen = seq_lens.max().item()
-
-        context_attention_fwd(
-            q,
-            k,
-            v,
-            output,
-            cu_seqlens.cuda(),
-            seq_lens,
-            max_seqlen,
-            is_causal=False,
-        )
+        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
 
         if self.use_qkv_parallel:
-
-            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            # [b * s, h, head_size] --> [b, s, h * head_size]
             output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
 
-            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            # [b, s, h * head_size] --> [b, s, h * head_size]
             output, _ = self.proj(output)
         else:
-            # [b * s, head, head_dim] --> [b, s, head, head_dim]
-            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
-            # [s, b, num_heads * head_size]
+            # [b * s, h, head_size] --> [s, b, h * head_size]
             context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
+                output, "(b s) h d -> s b (h d)", b=bsz, s=s
             ).contiguous()
 
-            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            # [s, b, h * head_size] --> [s, b, h * head_size]
             output, _ = self.proj(context_layer)
 
+        # [s, b, h * head_size] --> [b, s, h * head_size]
         output = output.view(bsz, s, -1)
 
         return output
+
+
+class VisionSdpaAttention(nn.Module):
+    r"""
+    Scaled Dot Product Attention inner product
+
+    """
+
+    # TODO: Should it be released after used?
+    _mask_cache = {}
+
+    def __init__(
+        self,
+        head_size: int,
+        dropout: float = 0.0,
+        flatten_batch: bool = False,
+        use_full_precision_softmax: bool = False,
+    ):
+        super().__init__()
+        self.head_size = head_size
+        self.flatten_batch = flatten_batch
+        self.use_full_precision_softmax = use_full_precision_softmax
+        self.dropout = dropout
+
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        bsz: int,
+        device,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+
+        When `flatten_batch` is True:
+            - All sequences in the batch are flattened into a single dimension
+            - `s` represents the total number of tokens across all sequences in the batch
+            - Returns a unified mask of shape `(1, 1, s, s)`
+
+        When `flatten_batch` is False:
+            - Each sequence has its own attention mask
+            - `s` represents the maximum sequence length in the batch
+            - Returns separate masks of shape `(b, 1, s, s)`
+
+        Args:
+            flatten_batch: (bool):
+                If True, treats all sequences in the batch as a single flattened sequence
+                If False, generates separate masks for each sequence
+
+        Returns:
+            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        """
+
+        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
+
+        if cache_key in VisionSdpaAttention._mask_cache:
+            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
+            # print(f"cache hit for key: {cache_key}")
+            return cached_mask.to(device=device, dtype=dtype)
+
+        if cu_seqlens is None:
+            raise ValueError("Internal Error: cu_seqlens cannot be None")
+
+        if flatten_batch:
+            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                start = cu_seqlens[i - 1]
+                end = cu_seqlens[i]
+                mask[
+                    ...,
+                    start:end,
+                    start:end,
+                ] = True
+        else:
+            # [1, 1, 1, s]
+            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            # [1, 1, s, 1]
+            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            # [b, 1, 1, 1]
+            seq_lens = (
+                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
+            )
+
+            mask = (row_indices < seq_lens) & (col_indices < seq_lens)
+
+        # Convert to attention mask format (False -> 0, True -> -inf)
+        mask = (~mask).to(dtype) * torch.finfo(dtype).min
+
+        VisionSdpaAttention._mask_cache[cache_key] = mask
+
+        return mask
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz: int,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        s = q.shape[0] // bsz
+
+        # [b, 1, s, s]
+        if attention_mask is None:
+            attention_mask = self.generate_patch_attention_mask(
+                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+            )
+        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
+        # [b, 1, s]
+        if self.use_full_precision_softmax:
+            scale = self.head_size**-0.5
+            k_transposed = rearrange(k, "b h s d -> b h d s")
+            attn_weights = torch.matmul(q, k_transposed) * scale
+            del k, k_transposed
+            attn_weights = attn_weights + attention_mask
+            del attention_mask
+            # full-precision
+            attn_weights = nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_weights = nn.functional.dropout(
+                attn_weights, p=self.dropout, training=False
+            )
+            output = torch.matmul(attn_weights, v)
+            del attn_weights, v
+        else:
+            # SDPA
+            # [b, h, s, head_size]
+            output = F.scaled_dot_product_attention(
+                q, k, v, attention_mask, dropout_p=self.dropout
+            )
+
+        # [b, h, s, head_size] --> [b * s, h, head_size]
+        output = rearrange(output, "b h s d -> (b s) h d")
+
+        return output
+
+
+class VisionTritonAttention(nn.Module):
+    """
+    Triton-implemented attention without a causal mask
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        _bsz: int,
+        cu_seqlens: Optional[torch.Tensor],
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        # [b * s, head, head_size]
+        output = torch.empty_like(q)
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = seq_lens.max().item()
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens.cuda(),
+            max_seqlen,
+            is_causal=False,
+        )
+
+        return output
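
For orientation, the following standalone sketch (not part of the diff) exercises the new VisionSdpaAttention backend added above. It assumes the module path sglang/srt/layers/attention/vision.py from the file list and the constructor/forward signatures shown in the hunks; cu_seqlens holds cumulative sequence lengths with a leading zero, so the per-sequence lengths are its consecutive differences.

    # Hypothetical usage sketch for the new SDPA vision-attention backend
    # (assumes sglang 0.4.2.post1 and the signatures shown in the hunks above).
    import torch
    from sglang.srt.layers.attention.vision import VisionSdpaAttention

    head_size, num_heads, bsz = 64, 4, 2
    cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32)  # two sequences: lengths 5 and 7
    s = 7                                                     # padded (max) length per sequence

    # q/k/v use the [b * s, h, head_size] layout that VisionAttention.forward passes in
    q = torch.randn(bsz * s, num_heads, head_size)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    attn = VisionSdpaAttention(head_size=head_size, use_full_precision_softmax=True)
    out = attn.forward(q, k, v, bsz, cu_seqlens=cu_seqlens)
    print(out.shape)  # torch.Size([14, 4, 64]), i.e. [b * s, h, head_size]

As the code above shows, the generated patch mask is cached per (s, bsz, flatten_batch, cu_seqlens) key, so repeated calls with the same image grid reuse the same mask tensor.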
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
 def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
+    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
         ],
         tp_rank,
         torch.distributed.get_backend(tp_group.device_group),
-        False,
+        SYNC_TOKEN_IDS_ACROSS_TP,
         False,
         False,
         False,
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda_available
 
-if is_flashinfer_available():
-    from flashinfer.norm import (
+if is_cuda_available():
+    from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
         gemma_rmsnorm,
@@ -121,8 +121,8 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not is_flashinfer_available():
+if not is_cuda_available():
     logger.info(
-        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
@@ -329,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         )
 
         self.gather_output = gather_output
+        self.use_presharded_weights = use_presharded_weights
 
         # Divide the weight matrix along the last dimension.
         if tp_rank is None:
@@ -402,7 +404,8 @@
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -418,7 +421,11 @@
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
+        param.load_column_parallel_weight(
+            loaded_weight,
+            tp_rank=self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
+        )
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -499,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
+        self.prefix = prefix
 
     def weight_loader(
         self,
@@ -743,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        load_presharded_attn: bool = False,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -772,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
             self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
+        self.use_presharded_weights = load_presharded_attn
 
         super().__init__(
             input_size=input_size,
@@ -784,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -842,9 +854,10 @@
                    shard_size=shard_size, shard_offset=shard_offset
                )
 
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
+            if not self.use_presharded_weights:
+                loaded_weight_shard = loaded_weight.narrow(
+                    param.output_dim, shard_offset, shard_size
+                )
             self.weight_loader_v2(param, loaded_weight_shard, shard_id)
 
     def weight_loader_v2(
@@ -882,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset=shard_offset,
             shard_size=shard_size,
             tp_rank=self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
     def weight_loader(
@@ -987,9 +1001,10 @@
                        param, orig_qkv_offsets, shard_id
                    )
 
-                loaded_weight_shard = loaded_weight.narrow(
-                    output_dim, shard_offset, shard_size
-                )
+                if not self.use_presharded_weights:
+                    loaded_weight_shard = loaded_weight.narrow(
+                        output_dim, shard_offset, shard_size
+                    )
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
@@ -1049,7 +1064,7 @@
 
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for for AQLM codebooks.
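
The use_presharded_weights / load_presharded_attn flags threaded through the hunks above let the loaders accept checkpoints that already store one tensor-parallel shard per rank. The toy sketch below (generic names, not sglang APIs) illustrates the pattern: the narrow() slice is applied only when the checkpoint holds the full, unsharded weight.

    # Illustrative only: what skipping narrow() for presharded weights means.
    import torch

    def load_column_shard(param: torch.Tensor, loaded: torch.Tensor,
                          tp_rank: int, use_presharded_weights: bool) -> None:
        output_dim = 0
        shard_size = param.shape[output_dim]
        if not use_presharded_weights:
            # Checkpoint stores the full matrix: slice out this rank's rows.
            loaded = loaded.narrow(output_dim, tp_rank * shard_size, shard_size)
        # A presharded checkpoint already stores exactly this rank's slice.
        assert loaded.shape[output_dim] == shard_size
        param.copy_(loaded)

    full_weight = torch.arange(8.0).reshape(4, 2)   # full [out=4, in=2] weight
    param = torch.empty(2, 2)                       # rank-local parameter, tp_size=2
    load_column_shard(param, full_weight, tp_rank=1, use_presharded_weights=False)
    print(param)                                    # rows 2..3 of the full weight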
@@ -296,7 +296,7 @@ def fused_softcap_kernel(
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
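
The single-line change above casts the Triton program id to 64-bit, presumably so that block offsets stay correct once the flattened tensor exceeds the 32-bit range; a quick back-of-the-envelope check:

    # With BLOCK_SIZE = 512, any tensor with more than 2**31 elements (~2.1e9)
    # yields block_start values that no longer fit in a signed 32-bit integer.
    BLOCK_SIZE = 512
    n_elements = 3 * 10**9
    last_pid = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE - 1
    print(last_pid * BLOCK_SIZE > 2**31 - 1)  # True -> offsets need 64-bit indexing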
@@ -114,6 +114,8 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
     ):
         super().__init__()
 
@@ -140,6 +142,8 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.custom_routing_function = custom_routing_function
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +170,7 @@
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"
 
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
@@ -181,6 +186,7 @@
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -254,16 +260,20 @@
             dtype=torch.float32,
             device=hidden_states.device,
         )
-        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-            gateup_output,
-            down_input,
-            gateup_output.shape[1],
-            reorder_topk_ids,
-            self.w2_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            BLOCK_SIZE=512,
-        )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
 
         # GroupGemm-1
         down_output = torch.empty(
@@ -309,7 +319,6 @@
         ckpt_up_proj_name: str,
         num_experts: int,
     ) -> List[Tuple[str, str, int, str]]:
-
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
@@ -354,7 +363,6 @@
             )
             return
 
-        expert_data = param.data[expert_id]
         if shard_id == "w2":
             param.data[expert_id] = loaded_weight
         elif shard_id == "w1":
@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 
-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -41,7 +42,12 @@
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
 
     topk_weights, topk_ids = select_experts(
@@ -84,6 +91,13 @@
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()
 
+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@
         layer_w2_weight = layer.w2_weight[i]
 
         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx
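
The changes above let the native MoE path select between SiluAndMul and GeluAndMul instead of hard-coding SiLU. A minimal sketch of that act-and-mul pattern in plain PyTorch (not the sglang classes, which may differ in GELU approximation):

    # Gate/up halves of the fused w13 projection; activation on the gate half only.
    import torch
    import torch.nn.functional as F

    def act_and_mul(gate_up: torch.Tensor, activation: str = "silu") -> torch.Tensor:
        gate, up = gate_up.chunk(2, dim=-1)
        if activation == "silu":
            return F.silu(gate) * up
        elif activation == "gelu":
            return F.gelu(gate) * up
        raise ValueError(f"Unsupported activation: {activation=}")

    x = torch.randn(4, 16)                 # 4 tokens, 2 * intermediate_size = 16
    print(act_and_mul(x, "gelu").shape)    # torch.Size([4, 8])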