sglang 0.5.0rc1__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. sglang/bench_one_batch.py +0 -1
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/decode.py +0 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/entrypoints/http_server.py +64 -0
  6. sglang/srt/entrypoints/openai/protocol.py +2 -0
  7. sglang/srt/entrypoints/openai/serving_chat.py +1 -0
  8. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  9. sglang/srt/layers/attention/flashinfer_backend.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  11. sglang/srt/layers/attention/triton_backend.py +24 -27
  12. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  13. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -3
  14. sglang/srt/layers/communicator.py +7 -7
  15. sglang/srt/layers/dp_attention.py +118 -27
  16. sglang/srt/layers/logits_processor.py +12 -18
  17. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/multimodal.py +156 -40
  29. sglang/srt/layers/quantization/__init__.py +5 -32
  30. sglang/srt/layers/quantization/awq.py +15 -16
  31. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  32. sglang/srt/layers/quantization/gptq.py +12 -17
  33. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  34. sglang/srt/layers/quantization/modelopt_quant.py +52 -30
  35. sglang/srt/layers/quantization/mxfp4.py +16 -2
  36. sglang/srt/layers/quantization/utils.py +52 -2
  37. sglang/srt/layers/sampler.py +5 -2
  38. sglang/srt/lora/layers.py +6 -2
  39. sglang/srt/managers/cache_controller.py +4 -1
  40. sglang/srt/managers/io_struct.py +14 -0
  41. sglang/srt/managers/schedule_batch.py +18 -39
  42. sglang/srt/managers/scheduler.py +3 -4
  43. sglang/srt/managers/tokenizer_manager.py +28 -18
  44. sglang/srt/mem_cache/allocator.py +8 -157
  45. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  46. sglang/srt/mem_cache/chunk_cache.py +1 -1
  47. sglang/srt/model_executor/cuda_graph_runner.py +8 -21
  48. sglang/srt/model_executor/forward_batch_info.py +8 -10
  49. sglang/srt/model_executor/model_runner.py +57 -53
  50. sglang/srt/models/deepseek_nextn.py +2 -1
  51. sglang/srt/models/deepseek_v2.py +5 -3
  52. sglang/srt/models/glm4_moe.py +2 -2
  53. sglang/srt/models/glm4_moe_nextn.py +2 -1
  54. sglang/srt/models/gpt_oss.py +7 -2
  55. sglang/srt/models/llama.py +10 -2
  56. sglang/srt/models/llama4.py +18 -5
  57. sglang/srt/models/qwen2.py +2 -2
  58. sglang/srt/models/qwen2_moe.py +20 -5
  59. sglang/srt/models/qwen3_classification.py +78 -0
  60. sglang/srt/models/qwen3_moe.py +18 -5
  61. sglang/srt/models/step3_vl.py +6 -2
  62. sglang/srt/operations.py +17 -2
  63. sglang/srt/sampling/sampling_batch_info.py +7 -4
  64. sglang/srt/server_args.py +33 -7
  65. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  66. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  67. sglang/srt/two_batch_overlap.py +4 -8
  68. sglang/test/test_marlin_moe.py +1 -1
  69. sglang/test/test_marlin_utils.py +1 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +5 -5
  72. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +75 -63
  73. sglang/srt/layers/quantization/scalar_type.py +0 -352
  74. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  75. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  76. {sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.dp_attention import (
  get_attention_tp_rank,
  get_attention_tp_size,
  get_local_attention_dp_size,
+ is_dp_attention_enabled,
  )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -107,10 +108,14 @@ class Qwen2MoeMLP(nn.Module):
  )
  self.act_fn = SiluAndMul()
 
- def forward(self, x):
+ def forward(
+ self,
+ x,
+ use_reduce_scatter: bool = False,
+ ):
  gate_up, _ = self.gate_up_proj(x)
  x = self.act_fn(gate_up)
- x, _ = self.down_proj(x)
+ x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
  return x
 
 
@@ -175,7 +180,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
  self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
  def forward(
- self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+ self,
+ hidden_states: torch.Tensor,
+ forward_batch: Optional[ForwardBatch] = None,
+ use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
  num_tokens, hidden_dim = hidden_states.shape
  hidden_states = hidden_states.view(-1, hidden_dim)
@@ -193,6 +201,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
  final_hidden_states = self.experts(hidden_states, topk_output)
  if shared_output is not None:
  final_hidden_states = final_hidden_states + shared_output
+ if self.tp_size > 1 and not use_reduce_scatter:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
  return final_hidden_states.view(num_tokens, hidden_dim)
@@ -367,6 +376,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
  layer_scatter_modes=self.layer_scatter_modes,
  input_layernorm=self.input_layernorm,
  post_attention_layernorm=self.post_attention_layernorm,
+ allow_reduce_scatter=True,
  )
 
  def forward(
@@ -392,7 +402,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
  hidden_states, residual, forward_batch
  )
 
- hidden_states = self.mlp(hidden_states, forward_batch)
+ # For DP with padding, reduce scatter can be used instead of all-reduce.
+ use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+ forward_batch
+ )
+
+ hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
  hidden_states, residual = self.layer_communicator.postprocess_layer(
  hidden_states, residual, forward_batch
@@ -420,7 +435,7 @@ class Qwen2MoeModel(nn.Module):
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
- enable_tp=not global_server_args_dict["enable_dp_attention"],
+ enable_tp=not is_dp_attention_enabled(),
  prefix=add_prefix("embed_tokens", prefix),
  )
  else:
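The qwen2_moe.py changes (and the matching qwen3_moe.py changes further down) follow one pattern: when the layer communicator reports that reduce-scatter is safe — DP attention with padded, equal-length per-rank batches — the MoE block skips its tensor-parallel all-reduce and lets `down_proj(..., skip_all_reduce=True)` leave partial sums for a later reduce-scatter. A minimal standalone sketch of why the two paths agree under that padding assumption (the names and shapes below are illustrative, not SGLang's actual API):

```python
# Local simulation (no real process group): when every rank holds the same
# padded number of tokens, reduce-scattering the per-rank partial outputs
# gives each rank exactly its slice of the all-reduce result.
import torch

tp_size, tokens_per_rank, hidden = 4, 8, 16
partials = [torch.randn(tp_size * tokens_per_rank, hidden) for _ in range(tp_size)]

# All-reduce path: every rank gets the full summed tensor, then keeps its slice.
full = torch.stack(partials).sum(dim=0)
slices_after_all_reduce = [
    full[r * tokens_per_rank : (r + 1) * tokens_per_rank] for r in range(tp_size)
]

# Reduce-scatter path: each rank directly receives only its summed slice.
slices_after_reduce_scatter = [
    torch.stack(
        [p[r * tokens_per_rank : (r + 1) * tokens_per_rank] for p in partials]
    ).sum(dim=0)
    for r in range(tp_size)
]

for a, b in zip(slices_after_all_reduce, slices_after_reduce_scatter):
    assert torch.allclose(a, b)
print("reduce-scatter matches the sliced all-reduce result")
```

The saving is in communication volume: each rank receives only its own slice instead of the whole padded buffer, which is why the decoder layer opts in via `allow_reduce_scatter=True` only when the communicator can guarantee equal padding.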
sglang/srt/models/qwen3_classification.py ADDED
@@ -0,0 +1,78 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import Qwen2Config  # Qwen3 uses Qwen2Config
+
+ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
+ from sglang.srt.utils import add_prefix
+
+
+ class Qwen3ForSequenceClassification(nn.Module):
+ def __init__(
+ self,
+ config: Qwen2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.quant_config = quant_config
+ self.model = Qwen3Model(
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+ )
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+ # Use normalize=True for qwen3 embedding based on official implementation
+ # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
+ # Official code: output = F.normalize(output, p=2, dim=1)
+ self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+ self.eos_token_id = config.eos_token_id
+
+ @torch.no_grad()
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ input_embeds: Optional[torch.Tensor] = None,
+ get_embedding: bool = True,
+ ) -> EmbeddingPoolerOutput:
+ assert (
+ get_embedding
+ ), "Qwen3ForSequenceClassification is only used for embedding"
+
+ hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+ logits = self.score(hidden_states)
+ pooled_logits = self.pooler(logits, forward_batch).embeddings
+
+ return EmbeddingPoolerOutput(pooled_logits)
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ # Filter out lm_head weights of Qwen3ForCausalLM
+ filtered_weights = [
+ (name, w) for name, w in weights if not name.startswith("lm_head")
+ ]
+ return Qwen3ForCausalLM.load_weights(self, filtered_weights)
+
+
+ EntryClass = [
+ Qwen3ForSequenceClassification,
+ ]
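The new Qwen3ForSequenceClassification head is a thin wrapper: run the Qwen3 backbone, project with a linear score head, pool the last token of each sequence, and L2-normalize, matching the referenced Qwen3-Embedding code. A self-contained sketch of that pooling step, assuming `Pooler(PoolingType.LAST, normalize=True)` reduces to last-token selection followed by `F.normalize` over a token-packed batch (shapes here are illustrative, with no SGLang dependencies):

```python
# Minimal sketch (assumption): last-token pooling + L2 normalization,
# as described by the comments in the new classification head.
import torch
import torch.nn.functional as F

hidden_size, num_labels = 8, 3
score = torch.nn.Linear(hidden_size, num_labels)

# Two packed sequences of lengths 4 and 2 (SGLang batches are token-packed).
hidden_states = torch.randn(6, hidden_size)
seq_lens = torch.tensor([4, 2])
last_token_idx = torch.cumsum(seq_lens, dim=0) - 1  # indices 3 and 5

logits = score(hidden_states)             # (num_tokens, num_labels)
pooled = logits[last_token_idx]           # (num_seqs, num_labels)
pooled = F.normalize(pooled, p=2, dim=1)  # output = F.normalize(output, p=2, dim=1)
print(pooled.shape)                       # torch.Size([2, 3])
```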
sglang/srt/models/qwen3_moe.py CHANGED
@@ -144,11 +144,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  self.top_k = config.num_experts_per_tok
 
  def forward(
- self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+ self,
+ hidden_states: torch.Tensor,
+ forward_batch: Optional[ForwardBatch] = None,
+ use_reduce_scatter: bool = False,
  ) -> torch.Tensor:
 
  if not global_server_args_dict["moe_a2a_backend"].is_deepep():
- return self.forward_normal(hidden_states)
+ return self.forward_normal(hidden_states, use_reduce_scatter)
  else:
  return self.forward_deepep(hidden_states, forward_batch)
 
@@ -159,7 +162,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  if name not in ["correction_bias"]
  ]
 
- def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ def forward_normal(
+ self,
+ hidden_states: torch.Tensor,
+ use_reduce_scatter: bool = False,
+ ) -> torch.Tensor:
  num_tokens, hidden_dim = hidden_states.shape
  hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -167,7 +174,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
  router_logits, _ = self.gate(hidden_states)
  topk_output = self.topk(hidden_states, router_logits)
  final_hidden_states = self.experts(hidden_states, topk_output)
- if self.tp_size > 1:
+ if self.tp_size > 1 and not use_reduce_scatter:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
  return final_hidden_states.view(num_tokens, hidden_dim)
@@ -521,6 +528,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
  layer_scatter_modes=self.layer_scatter_modes,
  input_layernorm=self.input_layernorm,
  post_attention_layernorm=self.post_attention_layernorm,
+ allow_reduce_scatter=True,
  )
 
  def forward(
@@ -546,7 +554,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
  hidden_states, residual, forward_batch
  )
 
- hidden_states = self.mlp(hidden_states, forward_batch)
+ # For DP with padding, reduce scatter can be used instead of all-reduce.
+ use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+ forward_batch
+ )
+
+ hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
  hidden_states, residual = self.layer_communicator.postprocess_layer(
  hidden_states, residual, forward_batch
sglang/srt/models/step3_vl.py CHANGED
@@ -25,7 +25,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.attention.vision import VisionAttention
  from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
- from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
+ from sglang.srt.layers.dp_attention import (
+ get_attention_tp_rank,
+ get_attention_tp_size,
+ is_dp_attention_enabled,
+ )
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
  ColumnParallelLinear,
@@ -437,7 +441,7 @@ class Step3TextModel(nn.Module):
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
- enable_tp=not global_server_args_dict["enable_dp_attention"],
+ enable_tp=not is_dp_attention_enabled(),
  prefix=add_prefix("embed_tokens", prefix),
  )
 
sglang/srt/operations.py CHANGED
@@ -1,10 +1,17 @@
+ from __future__ import annotations
+
  import os
  from contextlib import contextmanager
  from dataclasses import dataclass
- from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Sequence, Union
 
  import torch
 
+ from sglang.srt.layers.dp_attention import set_dp_buffer_len
+
+ if TYPE_CHECKING:
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
  _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
 
  if _ENABLE_PROFILE:
@@ -66,18 +73,26 @@ Stage = List[ExecutionOperation]
 
 
  class _StageExecutor:
- def __init__(self, debug_name: str, stages: List[Stage], inputs):
+ def __init__(self, debug_name: str, stages: List[Stage], inputs: dict):
  self._debug_name = debug_name
  self._stages = stages
  self._index = 0
  self._stage_state = _StateDict()
  self._stage_output = inputs
 
+ # handling DP attention
+ forward_batch: ForwardBatch = inputs["forward_batch"]
+ self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
+ self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+
  def next(self):
  assert not self.done
 
  stage = self._stages[self._index]
 
+ if self._global_dp_buffer_len is not None:
+ set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+
  with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
  for op in stage:
  with _annotate_region(debug_name=op.debug_name):
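Several changes in this release (operations.py above, the EAGLE CUDA graph runners, and two_batch_overlap.py below) replace a preallocated `gathered_buffer` tensor with a plain `global_dp_buffer_len` plus a `set_dp_buffer_len()` call, so the DP-attention gather buffer no longer has to be pinned per runner and per capture size. A rough standalone sketch of that pattern; the helper names below are illustrative stand-ins, not SGLang's actual internals:

```python
# Illustrative sketch: keep only the lengths around and allocate the gather
# buffer lazily where it is needed, instead of holding one tensor per runner.
from typing import Optional

import torch

_global_dp_buffer_len: Optional[int] = None
_local_dp_buffer_len: Optional[int] = None


def set_dp_buffer_len(global_len: Optional[int], local_len: int) -> None:
    # Record the sizes for the current forward pass (hypothetical stand-in
    # for sglang.srt.layers.dp_attention.set_dp_buffer_len).
    global _global_dp_buffer_len, _local_dp_buffer_len
    _global_dp_buffer_len, _local_dp_buffer_len = global_len, local_len


def get_dp_gather_buffer(hidden_size: int, dtype=torch.float16) -> torch.Tensor:
    # Allocate the buffer only when a layer actually gathers across DP ranks.
    assert _global_dp_buffer_len is not None
    return torch.empty(_global_dp_buffer_len, hidden_size, dtype=dtype)


# Usage: a runner that used to slice `self.gathered_buffer` now just records
# `num_tokens * dp_size` and lets the consumer allocate on demand.
set_dp_buffer_len(global_len=4 * 128, local_len=128)
buf = get_dp_gather_buffer(hidden_size=64)
print(buf.shape)  # torch.Size([512, 64])
```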
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -68,6 +68,8 @@ class SamplingBatchInfo:
 
  @classmethod
  def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+
  reqs = batch.reqs
  device = batch.device
  temperatures = (
@@ -97,10 +99,11 @@
  logit_bias[i, int(key)] = value
 
  # Check if any request has custom logit processor
- has_custom_logit_processor = (
- batch.enable_custom_logit_processor # check the flag first.
- and any(r.custom_logit_processor for r in reqs) # then check the requests.
- )
+ has_custom_logit_processor = global_server_args_dict[
+ "enable_custom_logit_processor"
+ ] and any( # check the flag first.
+ r.custom_logit_processor for r in reqs
+ ) # then check the requests.
 
  if has_custom_logit_processor:
  # Merge the same type of custom logit processors together
sglang/srt/server_args.py CHANGED
@@ -24,7 +24,7 @@ import tempfile
  from typing import List, Literal, Optional, Union
 
  from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
- from sglang.srt.layers.utils import is_sm100_supported
+ from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
  from sglang.srt.lora.lora_registry import LoRARef
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.utils import (
@@ -124,6 +124,7 @@ class ServerArgs:
  # API related
  api_key: Optional[str] = None
  served_model_name: Optional[str] = None
+ weight_version: str = "default"
  chat_template: Optional[str] = None
  completion_template: Optional[str] = None
  file_storage_path: str = "sglang_storage"
@@ -575,6 +576,7 @@ class ServerArgs:
  "Pipeline parallelism is incompatible with overlap schedule."
  )
 
+ # Hicache
  if self.hicache_storage_backend == "mooncake":
  # to use mooncake storage backend, the following conditions must be met:
  self.hicache_io_backend = "kernel"
@@ -1162,6 +1164,12 @@
  default=ServerArgs.served_model_name,
  help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
  )
+ parser.add_argument(
+ "--weight-version",
+ type=str,
+ default=ServerArgs.weight_version,
+ help="Version identifier for the model weights. Defaults to 'default' if not specified.",
+ )
  parser.add_argument(
  "--chat-template",
  type=str,
@@ -1316,19 +1324,23 @@
 
  # Kernel backend
  ATTN_BACKENDS = [
- "aiter",
+ # Common
+ "triton",
+ "torch_native",
+ # NVIDIA specific
  "cutlass_mla",
  "fa3",
  "flashinfer",
  "flashmla",
- "intel_amx",
- "torch_native",
- "ascend",
- "triton",
  "trtllm_mla",
  "trtllm_mha",
  "dual_chunk_flash_attn",
+ # AMD specific
+ "aiter",
  "wave",
+ # Other platforms
+ "intel_amx",
+ "ascend",
  ]
  parser.add_argument(
  "--attention-backend",
@@ -2105,11 +2117,25 @@
  model_arch = hf_config.architectures[0]
  if model_arch in ["GptOssForCausalLM"]:
  if self.attention_backend is None:
- self.attention_backend = "triton"
+ if is_sm100_supported():
+ self.attention_backend = "trtllm_mha"
+ elif is_sm90_supported():
+ self.attention_backend = "fa3"
+ else:
+ self.attention_backend = "triton"
  supported_backends = ["triton", "trtllm_mha", "fa3"]
+ logger.info(
+ f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
+ )
  assert (
  self.attention_backend in supported_backends
  ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+
+ if is_sm100_supported():
+ self.enable_flashinfer_allreduce_fusion = True
+ logger.info(
+ "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
+ )
  quantization_config = getattr(hf_config, "quantization_config", None)
  is_mxfp4_quant_format = (
  quantization_config is not None
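The new GptOssForCausalLM defaults pick the attention backend from the GPU's compute capability: trtllm_mha on sm100 (Blackwell), fa3 on sm90 (Hopper), triton otherwise. A minimal sketch of how such a capability check can be expressed with plain PyTorch; `is_sm90_supported` / `is_sm100_supported` are assumed to do something equivalent, and this is not their actual implementation:

```python
# Sketch only: choose a backend by compute capability, mirroring the new
# GptOssForCausalLM defaults (sm100 -> trtllm_mha, sm90 -> fa3, else triton).
import torch


def compute_capability() -> tuple[int, int]:
    # (major, minor), e.g. (9, 0) for H100/H200, (10, 0) for B200.
    return torch.cuda.get_device_capability()


def default_gpt_oss_attention_backend() -> str:
    major, _minor = compute_capability()
    if major >= 10:
        return "trtllm_mha"
    if major == 9:
        return "fa3"
    return "triton"


if torch.cuda.is_available():
    print(default_gpt_oss_attention_backend())
```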
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
  import torch
 
- from sglang.srt.layers.dp_attention import DPPaddingMode
+ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
  from sglang.srt.model_executor.cuda_graph_runner import (
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
  CudaGraphRunner,
@@ -105,30 +105,15 @@ class EAGLEDraftCudaGraphRunner:
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
  (self.dp_size,), dtype=torch.int32
  )
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token * self.dp_size,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  else:
  assert self.require_attn_tp_gather
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
  (1,), dtype=torch.int32
  )
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  else:
  self.global_num_tokens_gpu = None
  self.global_num_tokens_for_logprob_gpu = None
- self.gathered_buffer = None
 
  # Capture
  try:
@@ -193,7 +178,7 @@ class EAGLEDraftCudaGraphRunner:
  )
  )
  global_num_tokens = self.global_num_tokens_gpu
- gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+ global_dp_buffer_len = num_tokens * self.dp_size
  global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
  elif self.require_attn_tp_gather:
  self.global_num_tokens_gpu.copy_(
@@ -211,11 +196,11 @@
  )
  )
  global_num_tokens = self.global_num_tokens_gpu
- gathered_buffer = self.gathered_buffer[:num_tokens]
+ global_dp_buffer_len = num_tokens
  global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
  else:
  global_num_tokens = None
- gathered_buffer = None
+ global_dp_buffer_len = None
  global_num_tokens_for_logprob = None
 
  spec_info = EagleDraftInput(
@@ -239,8 +224,8 @@
  return_logprob=False,
  positions=positions,
  global_num_tokens_gpu=global_num_tokens,
- dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
- gathered_buffer=gathered_buffer,
+ dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+ global_dp_buffer_len=global_dp_buffer_len,
  spec_algorithm=self.model_runner.spec_algorithm,
  spec_info=spec_info,
  capture_hidden_mode=(
@@ -258,6 +243,7 @@
  def run_once():
  # Clean intermediate result cache for DP attention
  forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+ set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
  # Backup two fields, which will be modified in-place in `draft_forward`.
  output_cache_loc_backup = forward_batch.out_cache_loc
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py CHANGED
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
  import torch
 
- from sglang.srt.layers.dp_attention import DPPaddingMode
+ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
  from sglang.srt.model_executor.cuda_graph_runner import (
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
  CudaGraphRunner,
@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
  (self.dp_size,), dtype=torch.int32
  )
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token * self.dp_size,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  else:
  assert self.require_attn_tp_gather
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
  (1,), dtype=torch.int32
  )
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  else:
  self.global_num_tokens_gpu = None
  self.global_num_tokens_for_logprob_gpu = None
- self.gathered_buffer = None
 
  if hasattr(
  self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
  device=self.input_ids.device,
  )
  )
- gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+ global_dp_buffer_len = num_tokens * self.dp_size
  elif self.require_attn_tp_gather:
  self.global_num_tokens_gpu.copy_(
  torch.tensor(
@@ -238,9 +223,9 @@
  device=self.input_ids.device,
  )
  )
- gathered_buffer = self.gathered_buffer[:num_tokens]
+ global_dp_buffer_len = num_tokens
  else:
- gathered_buffer = None
+ global_dp_buffer_len = None
 
  spec_info = EagleDraftInput(
  hidden_states=hidden_states,
@@ -264,8 +249,8 @@
  positions=positions,
  global_num_tokens_gpu=self.global_num_tokens_gpu,
  global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
- dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
- gathered_buffer=gathered_buffer,
+ dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+ global_dp_buffer_len=global_dp_buffer_len,
  spec_algorithm=self.model_runner.spec_algorithm,
  spec_info=spec_info,
  capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -288,6 +273,7 @@
  def run_once():
  # Clean intermediate result cache for DP attention
  forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+ set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
  # Backup two fields, which will be modified in-place in `draft_forward`.
  output_cache_loc_backup = forward_batch.out_cache_loc
sglang/srt/two_batch_overlap.py CHANGED
@@ -678,16 +678,12 @@ class TboForwardBatchPreparer:
  # TODO improve, e.g. unify w/ `init_raw`
  if (
  global_server_args_dict["moe_dense_tp_size"] == 1
- and batch.gathered_buffer is not None
+ and batch.global_dp_buffer_len is not None
  ):
  sum_len = end_token_index - start_token_index
- gathered_buffer = torch.zeros(
- (sum_len, batch.gathered_buffer.shape[1]),
- dtype=batch.gathered_buffer.dtype,
- device=batch.gathered_buffer.device,
- )
+ global_dp_buffer_len = sum_len
  else:
- gathered_buffer = None
+ global_dp_buffer_len = None
 
  output_dict.update(
  dict(
@@ -706,7 +702,7 @@
  global_num_tokens_gpu=None,
  global_num_tokens_cpu=None,
  dp_padding_mode=None,
- gathered_buffer=gathered_buffer,
+ global_dp_buffer_len=global_dp_buffer_len,
  global_num_tokens_for_logprob_gpu=None,
  global_num_tokens_for_logprob_cpu=None,
  sampling_info=None,
sglang/test/test_marlin_moe.py CHANGED
@@ -4,9 +4,9 @@ from typing import Optional
  import pytest
  import torch
  from sgl_kernel import fused_marlin_moe
+ from sgl_kernel.scalar_type import ScalarType, scalar_types
 
  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
  from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
 
 
sglang/test/test_marlin_utils.py CHANGED
@@ -10,13 +10,13 @@ from typing import Optional
 
  import numpy as np
  import torch
+ from sgl_kernel.scalar_type import ScalarType
 
  from sglang.srt.layers.quantization.marlin_utils import (
  GPTQ_MARLIN_TILE,
  marlin_permute_scales,
  marlin_zero_points,
  )
- from sglang.srt.layers.quantization.scalar_type import ScalarType
  from sglang.srt.layers.quantization.utils import (
  get_pack_factor,
  gptq_quantize_weights,
sglang/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.5.0rc1"
+ __version__ = "0.5.0rc2"
{sglang-0.5.0rc1.dist-info → sglang-0.5.0rc2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sglang
- Version: 0.5.0rc1
+ Version: 0.5.0rc2
  Summary: SGLang is yet another fast serving framework for large language models and vision language models.
  License: Apache License
  Version 2.0, January 2004
@@ -251,18 +251,18 @@ Requires-Dist: scipy; extra == "runtime-common"
  Requires-Dist: timm==1.0.16; extra == "runtime-common"
  Requires-Dist: tiktoken; extra == "runtime-common"
  Requires-Dist: torchao==0.9.0; extra == "runtime-common"
- Requires-Dist: transformers==4.55.0; extra == "runtime-common"
+ Requires-Dist: transformers==4.55.2; extra == "runtime-common"
  Requires-Dist: uvicorn; extra == "runtime-common"
  Requires-Dist: uvloop; extra == "runtime-common"
  Requires-Dist: xgrammar==0.1.22; extra == "runtime-common"
  Provides-Extra: srt
  Requires-Dist: sglang[runtime_common]; extra == "srt"
- Requires-Dist: sgl-kernel==0.3.4.post1; extra == "srt"
+ Requires-Dist: sgl-kernel==0.3.5; extra == "srt"
  Requires-Dist: torch==2.8.0; extra == "srt"
  Requires-Dist: torchaudio==2.8.0; extra == "srt"
  Requires-Dist: torchvision; extra == "srt"
  Requires-Dist: cuda-python; extra == "srt"
- Requires-Dist: flashinfer_python==0.2.11.post1; extra == "srt"
+ Requires-Dist: flashinfer_python==0.2.11.post3; extra == "srt"
  Provides-Extra: blackwell
  Requires-Dist: sglang[runtime_common]; extra == "blackwell"
  Requires-Dist: sgl-kernel; extra == "blackwell"
@@ -270,7 +270,7 @@ Requires-Dist: torch==2.8.0; extra == "blackwell"
  Requires-Dist: torchaudio==2.8.0; extra == "blackwell"
  Requires-Dist: torchvision; extra == "blackwell"
  Requires-Dist: cuda-python; extra == "blackwell"
- Requires-Dist: flashinfer_python==0.2.11.post1; extra == "blackwell"
+ Requires-Dist: flashinfer_python==0.2.11.post3; extra == "blackwell"
  Provides-Extra: srt-hip
  Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
  Requires-Dist: torch; extra == "srt-hip"