sglang 0.4.0.post1__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/bench_offline_throughput.py +18 -6
  2. sglang/bench_one_batch.py +13 -0
  3. sglang/bench_serving.py +8 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/constrained/xgrammar_backend.py +4 -1
  9. sglang/srt/layers/attention/flashinfer_backend.py +2 -0
  10. sglang/srt/layers/attention/triton_backend.py +16 -25
  11. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  12. sglang/srt/layers/ep_moe/layer.py +4 -0
  13. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  14. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/quantization/__init__.py +2 -47
  17. sglang/srt/layers/quantization/fp8.py +58 -10
  18. sglang/srt/layers/radix_attention.py +8 -1
  19. sglang/srt/layers/sampler.py +27 -5
  20. sglang/srt/layers/torchao_utils.py +35 -0
  21. sglang/srt/managers/detokenizer_manager.py +37 -17
  22. sglang/srt/managers/io_struct.py +39 -10
  23. sglang/srt/managers/schedule_batch.py +38 -24
  24. sglang/srt/managers/schedule_policy.py +64 -5
  25. sglang/srt/managers/scheduler.py +169 -134
  26. sglang/srt/managers/tokenizer_manager.py +99 -58
  27. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  28. sglang/srt/mem_cache/chunk_cache.py +2 -2
  29. sglang/srt/mem_cache/radix_cache.py +12 -2
  30. sglang/srt/model_executor/cuda_graph_runner.py +24 -10
  31. sglang/srt/model_executor/model_runner.py +22 -14
  32. sglang/srt/model_parallel.py +66 -5
  33. sglang/srt/models/gemma2.py +34 -0
  34. sglang/srt/models/gemma2_reward.py +0 -1
  35. sglang/srt/models/granite.py +517 -0
  36. sglang/srt/models/grok.py +72 -8
  37. sglang/srt/models/llama.py +22 -0
  38. sglang/srt/models/llama_classification.py +11 -23
  39. sglang/srt/models/llama_reward.py +0 -2
  40. sglang/srt/models/llava.py +37 -14
  41. sglang/srt/models/qwen2.py +20 -0
  42. sglang/srt/openai_api/adapter.py +4 -0
  43. sglang/srt/openai_api/protocol.py +9 -4
  44. sglang/srt/server.py +1 -1
  45. sglang/srt/server_args.py +19 -9
  46. sglang/srt/utils.py +7 -10
  47. sglang/test/test_utils.py +3 -2
  48. sglang/utils.py +10 -3
  49. sglang/version.py +1 -1
  50. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +11 -6
  51. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +54 -52
  52. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  53. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  54. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED
@@ -25,9 +25,11 @@ from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope

+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
@@ -40,10 +42,43 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader


+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+            reduce_results=reduce_results,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
     across all ranks.
@@ -55,6 +90,7 @@ class Grok1MoE(nn.Module):

     def __init__(
         self,
+        config: PretrainedConfig,
         num_experts: int,
         top_k: int,
         hidden_size: int,
@@ -62,6 +98,7 @@ class Grok1MoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
+        reduce_results=True,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -75,13 +112,16 @@ class Grok1MoE(nn.Module):
             quant_config=None,
         )

+        self.router_logit_softcapping = getattr(
+            config, "router_logit_softcapping", 30.0
+        )
         self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
+            reduce_results=reduce_results,
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -91,9 +131,12 @@ class Grok1MoE(nn.Module):
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         router_logits = 30.0 * F.tanh(router_logits / 30.0)
+
+        # need to assert self.gate.quant_method is unquantized
         final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(orig_shape)

@@ -101,16 +144,18 @@ class Grok1MoE(nn.Module):
 class Grok1Attention(nn.Module):
     def __init__(
         self,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-        logit_cap: float = 30,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.config = config
+        self.layer_id = layer_id
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
@@ -126,7 +171,7 @@ class Grok1Attention(nn.Module):
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = 128
+        self.head_dim = getattr(config, "head_dim", 128)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -140,7 +185,6 @@ class Grok1Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -154,6 +198,9 @@ class Grok1Attention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
+
+        logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -162,7 +209,6 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
         )
-        # TODO(lianmin): load logit cap from config

     def forward(
         self,
@@ -186,10 +232,12 @@ class Grok1DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
+        self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size

         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
@@ -199,11 +247,17 @@ class Grok1DecoderLayer(nn.Module):
             quant_config=quant_config,
         )
         self.block_sparse_moe = Grok1MoE(
+            config=config,
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
+            intermediate_size=getattr(
+                config,
+                "moe_intermediate_size",
+                getattr(config, "intermediate_size", None),
+            ),
             quant_config=quant_config,
+            reduce_results=True,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -284,6 +338,7 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -310,6 +365,8 @@ class Grok1ForCausalLM(nn.Module):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]

         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -345,6 +402,11 @@ class Grok1ForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)

+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
@@ -357,7 +419,9 @@ class Grok1ForCausalLM(nn.Module):
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
                     continue
                 # Skip loading kv_scale from ckpts towards new design.
                 if name.endswith(".kv_scale") and name not in params_dict:
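Note: the Grok1MoE router above keeps the tanh-based logit softcapping (router_logits = 30.0 * F.tanh(router_logits / 30.0)), even though a router_logit_softcapping attribute is now read from the config. A standalone illustration of that transform with toy values (plain torch, not sglang code):

    # Softcapping squashes large logits smoothly into (-30, 30) while leaving
    # small ones almost unchanged. Values below are illustrative only.
    import torch

    logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
    capped = 30.0 * torch.tanh(logits / 30.0)
    print(capped)  # roughly [-29.9, -4.95, 0.0, 4.95, 29.9]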
sglang/srt/models/llama.py CHANGED
@@ -294,6 +294,28 @@ class LlamaModel(nn.Module):


 class LlamaForCausalLM(nn.Module):
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
sglang/srt/models/llama_classification.py CHANGED
@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -33,14 +33,13 @@ class LlamaForClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)

         self.classification_head = nn.Linear(
             config.hidden_size, config.classification_out_size, bias=False
         )
-        self.eos_token_id = config.eos_token_id
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

     @torch.no_grad()
     def forward(
@@ -49,28 +48,17 @@ class LlamaForClassification(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        is_eos_token = input_ids == self.eos_token_id
-        hidden_states = hidden_states[is_eos_token]
-        scores = self.classification_head(hidden_states)
-
-        if scores.shape[0] != forward_batch.batch_size:
-            print("Warning: the EOS tokens are missing in some sentences.")
-            scores = torch.ones(
-                (forward_batch.batch_size, self.config.classification_out_size)
-            ).to(input_ids.device)
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."

-        logits_output = LogitsProcessorOutput(
-            next_token_logits=scores,
-            next_token_logprobs=scores,
-            normalized_prompt_logprobs=scores,
-            input_token_logprobs=torch.ones_like(input_ids),
-            input_top_logprobs=None,
-            output_top_logprobs=None,
-        )
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.classification_head(last_token_hidden)

-        return logits_output
+        return EmbeddingPoolerOutput(scores)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
sglang/srt/models/llama_reward.py CHANGED
@@ -21,7 +21,6 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -33,7 +32,6 @@ class LlamaForSequenceClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.torchao_config = None
         self.quant_config = quant_config
         self.num_labels = config.num_labels
         self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llava.py CHANGED
@@ -57,6 +57,7 @@ class LlavaBaseForCausalLM(nn.Module):
         else:
             image_aspect_ratio = "anyres"
         offset_list = []
+        image_inputs.image_pad_len = []
         for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
@@ -103,6 +104,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
+            image_inputs.image_pad_len.append(new_image_feature_len)

         image_inputs.image_offsets = offset_list
         return input_ids
@@ -134,6 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs

         if forward_batch.forward_mode.is_extend():
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+            # Embed text inputs
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -142,18 +152,12 @@ class LlavaBaseForCausalLM(nn.Module):
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
                 if im and im.image_offsets:
-                    max_image_offset.append(max(im.image_offsets))
+                    max_image_offset.append(
+                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
+                    )
                 else:
                     max_image_offset.append(-1)

-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            # Embed text inputs
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
             start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)

@@ -350,6 +354,7 @@ class LlavaBaseForCausalLM(nn.Module):

             # Fill in the placeholder for the image
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+            extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
             prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
             pt = 0
             for i in range(bs):
@@ -357,18 +362,36 @@ class LlavaBaseForCausalLM(nn.Module):
                     continue

                 start_idx = extend_start_loc_cpu[i]
+                seq_len = extend_seq_lens[i]
                 prefix_len = prefix_lens_cpu[i]

                 # Multiple images
-                for j, image_offset in enumerate(image_inputs[i].image_offsets):
-                    if image_offset < prefix_len:
+                for image_idx, image_offset in enumerate(
+                    image_inputs[i].image_offsets
+                ):
+                    if (
+                        image_offset + image_inputs[i].image_pad_len[image_idx]
+                        <= prefix_len
+                    ):
                         continue
+                    if image_offset >= prefix_len + seq_len:
+                        break

-                    tmp_image_feature = image_features[pt][j]
+                    tmp_image_feature = image_features[pt][image_idx]
                     pad_len = tmp_image_feature.shape[0]

-                    left_idx = start_idx + (image_offset - prefix_len)
-                    right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                    input_offset = image_offset - prefix_len
+                    left_idx = start_idx + input_offset
+                    right_idx = left_idx + pad_len
+                    assert right_idx > start_idx
+                    if input_offset < 0:
+                        left_idx = start_idx
+                        tmp_image_feature = tmp_image_feature[-input_offset:]
+                    if right_idx > start_idx + seq_len:
+                        tmp_image_feature = tmp_image_feature[
+                            : start_idx + seq_len - right_idx
+                        ]
+                        right_idx = start_idx + seq_len
                     try:
                         input_embeds[left_idx:right_idx] = tmp_image_feature
                     except RuntimeError as e:
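Note: the last llava.py hunk above crops an image's feature rows to the current chunked-prefill window before writing them into input_embeds. A self-contained toy sketch of that index arithmetic (made-up sizes, plain torch, not sglang API):

    import torch

    hidden = 8
    seq_len, prefix_len, start_idx = 6, 4, 0     # this chunk covers sequence positions [4, 10)
    image_offset, pad_len = 2, 5                 # image tokens occupy positions [2, 7)
    input_embeds = torch.zeros(seq_len, hidden)  # embeddings for this chunk only
    image_feature = torch.randn(pad_len, hidden)

    input_offset = image_offset - prefix_len     # -2: the image starts before this chunk
    left_idx = start_idx + input_offset
    right_idx = left_idx + pad_len
    if input_offset < 0:                         # drop the rows already covered by the cached prefix
        left_idx = start_idx
        image_feature = image_feature[-input_offset:]
    if right_idx > start_idx + seq_len:          # drop the rows belonging to a later chunk
        image_feature = image_feature[: start_idx + seq_len - right_idx]
        right_idx = start_idx + seq_len
    input_embeds[left_idx:right_idx] = image_feature
    print(left_idx, right_idx, image_feature.shape)  # 0 3 torch.Size([3, 8])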
sglang/srt/models/qwen2.py CHANGED
@@ -267,6 +267,26 @@ class Qwen2Model(nn.Module):


 class Qwen2ForCausalLM(nn.Module):
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: Qwen2Config,
sglang/srt/openai_api/adapter.py CHANGED
@@ -510,6 +510,8 @@ def v1_generate_request(
                 "stop": request.stop,
                 "stop_token_ids": request.stop_token_ids,
                 "top_p": request.top_p,
+                "top_k": request.top_k,
+                "min_p": request.min_p,
                 "presence_penalty": request.presence_penalty,
                 "frequency_penalty": request.frequency_penalty,
                 "repetition_penalty": request.repetition_penalty,
@@ -926,6 +928,8 @@ def v1_chat_generate_request(
             "stop": stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
+            "top_k": request.top_k,
+            "min_p": request.min_p,
             "presence_penalty": request.presence_penalty,
             "frequency_penalty": request.frequency_penalty,
             "repetition_penalty": request.repetition_penalty,
sglang/srt/openai_api/protocol.py CHANGED
@@ -166,17 +166,19 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
-    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
-    json_schema: Optional[str] = None
-    regex: Optional[str] = None
+    top_k: int = -1
+    min_p: float = 0.0
     min_tokens: int = 0
+    regex: Optional[str] = None
+    json_schema: Optional[str] = None
     repetition_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
     skip_special_tokens: bool = True
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None


 class CompletionResponseChoice(BaseModel):
@@ -276,13 +278,16 @@ class ChatCompletionRequest(BaseModel):
     user: Optional[str] = None

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
-    regex: Optional[str] = None
+    top_k: int = -1
+    min_p: float = 0.0
     min_tokens: int = 0
+    regex: Optional[str] = None
     repetition_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
     skip_special_tokens: bool = True
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None


 class ChatMessage(BaseModel):
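Note: with top_k and min_p now declared on CompletionRequest/ChatCompletionRequest and forwarded by the adapter, an OpenAI-style client can pass them as extra JSON fields. A hedged example against a locally running server (the default port 30000 and the model name are placeholders):

    import requests

    payload = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",   # placeholder; use your served model
        "messages": [{"role": "user", "content": "Say hello."}],
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,    # SRT-only extra field added in this release
        "min_p": 0.05,  # SRT-only extra field added in this release
    }
    resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])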
sglang/srt/server.py CHANGED
@@ -196,7 +196,7 @@ async def stop_profile_async():
 @app.post("/update_weights_from_disk")
 @time_func_latency
 async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
-    """Update the weights from disk inplace without re-launching the server."""
+    """Update the weights from disk in-place without re-launching the server."""
     success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
     content = {"success": success, "message": message}
     if success:
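Note: a minimal sketch of calling the endpoint touched above; it assumes a server on the default port 30000 and that UpdateWeightFromDiskReqInput accepts a model_path field (the path shown is a placeholder):

    import requests

    resp = requests.post(
        "http://localhost:30000/update_weights_from_disk",
        json={"model_path": "/path/to/updated/checkpoint"},  # placeholder path
    )
    print(resp.json())  # {"success": ..., "message": ...} as returned by the handler above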
sglang/srt/server_args.py CHANGED
@@ -141,6 +141,7 @@ class ServerArgs:
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False

@@ -220,12 +221,10 @@ class ServerArgs:
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
-            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             self.disable_overlap_schedule = True
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
-                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
                 "Overlap scheduler is disabled."
@@ -282,7 +281,15 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
+            choices=[
+                "auto",
+                "pt",
+                "safetensors",
+                "npcache",
+                "dummy",
+                "gguf",
+                "bitsandbytes",
+            ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -293,7 +300,9 @@ class ServerArgs:
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling."
-            '"gguf" will load the weights in the gguf format. ',
+            '"gguf" will load the weights in the gguf format. '
+            '"bitsandbytes" will load the weights using bitsandbytes '
+            "quantization.",
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -689,11 +698,6 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
-        parser.add_argument(
-            "--disable-nan-detection",
-            action="store_true",
-            help="Disable the NaN detection for better performance.",
-        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
@@ -753,6 +757,12 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--triton-attention-num-kv-splits",
+            type=int,
+            default=ServerArgs.triton_attention_num_kv_splits,
+            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
+        )
         parser.add_argument(
             "--num-continuous-decode-steps",
             type=int,
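Note: a sketch of launching the server with the options added above (the "bitsandbytes" load format and --triton-attention-num-kv-splits), assuming the usual python -m sglang.launch_server entry point; the model path is a placeholder and the flag values are illustrative:

    import subprocess

    cmd = [
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--load-format", "bitsandbytes",                     # new choice for --load-format
        "--attention-backend", "triton",
        "--triton-attention-num-kv-splits", "16",            # new flag; default is 8
    ]
    subprocess.run(cmd, check=True)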
sglang/srt/utils.py CHANGED
@@ -92,7 +92,7 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return torch.cuda.is_available() and not is_hip()
+    return torch.cuda.is_available() and torch.version.cuda


 def is_ipv6(address):
@@ -169,7 +169,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(device, gpu_id, distributed=False):
+def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -184,7 +184,8 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
                 "which may cause useless memory allocation for torch CUDA context.",
             )

-        torch.cuda.empty_cache()
+        if empty_cache:
+            torch.cuda.empty_cache()
         free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

     elif device == "xpu":
@@ -196,7 +197,9 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
                 f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
                 "which may cause useless memory allocation for torch XPU context.",
             )
-        torch.xpu.empty_cache()
+
+        if empty_cache:
+            torch.xpu.empty_cache()
         used_memory = torch.xpu.memory_allocated()
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory
@@ -1068,9 +1071,6 @@ def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)

-    if hasattr(torch, "hip") and torch.hip.is_available():
-        return torch.hip.get_device_name(device_id)
-
     if hasattr(torch, "xpu") and torch.xpu.is_available():
         return torch.xpu.get_device_name(device_id)

@@ -1083,9 +1083,6 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         major, minor = torch.cuda.get_device_capability(device_id)

-    if hasattr(torch, "hip") and torch.hip.is_available():
-        major, minor = torch.cuda.get_device_capability(device_id)
-
     if hasattr(torch, "xpu") and torch.xpu.is_available():
         major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
             "."