sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +1 -1
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +4 -2
  11. sglang/srt/layers/linear.py +159 -55
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -3
  15. sglang/srt/layers/parameter.py +431 -0
  16. sglang/srt/layers/quantization/__init__.py +3 -2
  17. sglang/srt/layers/quantization/fp8.py +1 -1
  18. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  20. sglang/srt/managers/cache_controller.py +307 -0
  21. sglang/srt/managers/data_parallel_controller.py +2 -0
  22. sglang/srt/managers/schedule_batch.py +7 -1
  23. sglang/srt/managers/scheduler.py +10 -6
  24. sglang/srt/managers/session_controller.py +1 -1
  25. sglang/srt/managers/tokenizer_manager.py +6 -2
  26. sglang/srt/mem_cache/memory_pool.py +206 -1
  27. sglang/srt/metrics/collector.py +22 -30
  28. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  29. sglang/srt/model_executor/forward_batch_info.py +20 -15
  30. sglang/srt/model_executor/model_runner.py +10 -4
  31. sglang/srt/models/chatglm.py +1 -1
  32. sglang/srt/models/dbrx.py +1 -1
  33. sglang/srt/models/grok.py +25 -16
  34. sglang/srt/models/llama.py +9 -2
  35. sglang/srt/sampling/sampling_batch_info.py +1 -0
  36. sglang/srt/server.py +11 -8
  37. sglang/srt/server_args.py +12 -1
  38. sglang/srt/speculative/eagle_utils.py +93 -85
  39. sglang/srt/speculative/eagle_worker.py +47 -33
  40. sglang/srt/utils.py +32 -5
  41. sglang/test/test_programs.py +23 -1
  42. sglang/test/test_utils.py +36 -7
  43. sglang/version.py +1 -1
  44. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +6 -7
  45. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +48 -43
  46. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  47. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  48. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/metrics/collector.py CHANGED
@@ -114,26 +114,20 @@ class TokenizerMetricsCollector:
             documentation="Histogram of time to first token in seconds.",
             labelnames=labels.keys(),
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
                 0.1,
                 0.25,
                 0.5,
                 0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-                15.0,
-                20.0,
-                25.0,
-                30.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
 
@@ -168,21 +162,19 @@ class TokenizerMetricsCollector:
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
             buckets=[
-                0.3,
+                0.1,
+                0.25,
                 0.5,
-                0.8,
-                1.0,
-                1.5,
-                2.0,
-                2.5,
-                5.0,
-                10.0,
-                15.0,
-                20.0,
-                30.0,
-                40.0,
-                50.0,
-                60.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
 
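Both histograms trade fine sub-second resolution for a wider upper range (out to 160 s). For reference, this is how bucket lists like these are passed to prometheus_client; the metric and label names below are illustrative placeholders, not the package's actual ones:

from prometheus_client import Histogram

# Illustrative declaration mirroring the new, coarser TTFT buckets.
ttft_seconds = Histogram(
    name="demo_time_to_first_token_seconds",
    documentation="Histogram of time to first token in seconds.",
    labelnames=["model_name"],
    buckets=[0.1, 0.25, 0.5, 0.75, 1, 2, 5, 10, 20, 40, 60, 80, 120, 160],
)

ttft_seconds.labels(model_name="demo").observe(0.42)  # counted in the <= 0.5 bucket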
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -124,10 +124,12 @@ class CudaGraphRunner:
         self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
-        if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1, 33)) + [64, 128]
-        else:
-            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
+        if self.capture_bs is None:
+            if model_runner.server_args.disable_cuda_graph_padding:
+                self.capture_bs = list(range(1, 33)) + [64, 128]
+            else:
+                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
 
         if max(self.capture_bs) > model_runner.req_to_token_pool.size:
             # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
@@ -322,6 +324,8 @@ class CudaGraphRunner:
         global_num_tokens = None
         gathered_buffer = None
 
+        spec_info = self.get_spec_info(num_tokens, positions)
+
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
             batch_size=bs,
@@ -338,10 +342,13 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
-            mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
+            mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
-            spec_info=self.get_spec_info(num_tokens, positions),
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
         )
 
         # Attention backend
@@ -446,10 +453,10 @@ class CudaGraphRunner:
 
         if self.model_runner.is_draft_worker:
             spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.model_runner.server_args)
             spec_info.hidden_states = self.hidden_states[:num_tokens]
             spec_info.positions = positions
             spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
-            spec_info.init(self.model_runner.server_args)
         else:
             spec_info = EagleVerifyInput(
                 None,
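The first hunk adds an override path: an explicit list from the new --cuda-graph-bs server argument is used as-is, and the previous defaults only apply when it is unset. A standalone sketch of that selection order (the helper name is hypothetical; the real logic lives inline in CudaGraphRunner.__init__):

def resolve_capture_bs(cuda_graph_bs, disable_cuda_graph_padding):
    # An explicit --cuda-graph-bs list wins.
    if cuda_graph_bs is not None:
        return cuda_graph_bs
    # Otherwise fall back to the pre-existing defaults.
    if disable_cuda_graph_padding:
        return list(range(1, 33)) + [64, 128]
    return [1, 2, 4] + [i * 8 for i in range(1, 21)]

print(resolve_capture_bs(None, False))        # [1, 2, 4, 8, 16, ..., 160]
print(resolve_capture_bs([1, 8, 32], False))  # user-provided list used verbatim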
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -106,6 +106,24 @@ class ForwardMode(IntEnum):
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
+
+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
+
 
 @dataclass
 class ForwardBatch:
@@ -174,6 +192,7 @@ class ForwardBatch:
     # Speculative decoding
     spec_info: SpecInfo = None
     spec_algorithm: SpeculativeAlgorithm = None
+    capture_hidden_mode: CaptureHiddenMode = None
 
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -265,6 +284,7 @@ class ForwardBatch:
             sampling_info=batch.sampling_info,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
+            capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
         )
 
@@ -400,18 +420,3 @@ def compute_position_torch(
 @maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
-class CaptureHiddenMode(IntEnum):
-    NULL = auto()
-    FULL = auto()
-    LAST = auto()
-
-    def need_capture(self):
-        return self != CaptureHiddenMode.NULL
-
-    def is_full(self):
-        return self == CaptureHiddenMode.FULL
-
-    def is_last(self):
-        return self == CaptureHiddenMode.LAST
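The CaptureHiddenMode enum itself is unchanged; it only moves above ForwardBatch so the new capture_hidden_mode field can use it as a type. A self-contained illustration of how the flag is meant to be queried (the usage at the bottom is a sketch, not code from the package):

from enum import IntEnum, auto

class CaptureHiddenMode(IntEnum):
    NULL = auto()
    FULL = auto()
    LAST = auto()

    def need_capture(self):
        # False only for NULL; FULL and LAST both request hidden-state capture.
        return self != CaptureHiddenMode.NULL

mode = CaptureHiddenMode.LAST
if mode.need_capture():
    print("capture hidden states:", mode.name)  # -> capture hidden states: LAST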
sglang/srt/model_executor/model_runner.py CHANGED
@@ -89,6 +89,7 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -117,15 +118,21 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
             if self.model_config.hf_config.architectures == [
                 "MllamaForConditionalGeneration"
             ]:
                 logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                 server_args.chunked_prefill_size = -1
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                 logger.info(
                     "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                 )
@@ -198,7 +205,7 @@ class ModelRunner:
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-            # TODO(liangan1):Just use gloo to bypass the initilization fail
+            # TODO(liangan1): Just use gloo to bypass the initilization fail
             # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
@@ -627,7 +634,6 @@ class ModelRunner:
         )
 
     def init_double_sparsity_channel_config(self, selected_channel):
-
         selected_channel = "." + selected_channel + "_proj"
         self.sorted_channels = []
         # load channel config
@@ -718,7 +724,7 @@ class ModelRunner:
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
+            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
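The multimodal branch now also reports the adjusted value. A toy calculation showing what the new log line prints (the 0.88 starting value is arbitrary, chosen only for the example):

mem_fraction_static = 0.88  # example configured value
is_multimodal = True

if is_multimodal:
    mem_fraction_static *= 0.95
    # Prints: Automatically reduce --mem-fraction-static to 0.836 because this is a multimodal model.
    print(
        f"Automatically reduce --mem-fraction-static to {mem_fraction_static:.3f} "
        f"because this is a multimodal model."
    )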
sglang/srt/models/chatglm.py CHANGED
@@ -23,8 +23,8 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.transformers_utils.configs import ChatGLMConfig
 
+from sglang.srt.configs import ChatGLMConfig
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
sglang/srt/models/dbrx.py CHANGED
@@ -25,8 +25,8 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
+from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
sglang/srt/models/grok.py CHANGED
@@ -57,6 +57,7 @@ class Grok1MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,6 +66,7 @@ class Grok1MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            use_presharded_weights=use_presharded_weights,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -73,6 +75,7 @@ class Grok1MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
             reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
         )
         self.act_fn = GeluAndMul(approximate="tanh")
 
@@ -103,6 +106,7 @@ class Grok1MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -156,6 +161,7 @@ class Grok1Attention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.config = config
@@ -194,6 +200,7 @@ class Grok1Attention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            reduce_results=reduce_results,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -234,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.layer_id = layer_id
 
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
@@ -262,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
             ),
             quant_config=quant_config,
             reduce_results=True,
+            use_presharded_weights=use_presharded_weights,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -299,6 +309,7 @@ class Grok1Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -311,7 +322,12 @@ class Grok1Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Grok1DecoderLayer(config, i, quant_config=quant_config)
+                Grok1DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    use_presharded_weights=use_presharded_weights,
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -347,11 +363,7 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
 
-        # Monkey patch _prepare_weights to load pre-sharded weights
         if (
             self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
@@ -361,6 +373,14 @@ class Grok1ForCausalLM(nn.Module):
         else:
             self.use_presharded_weights = False
 
+        self.model = Grok1Model(
+            config,
+            quant_config=quant_config,
+            use_presharded_weights=self.use_presharded_weights,
+        )
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -376,10 +396,7 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        use_presharded_weights: bool | None = None,
     ):
-        if use_presharded_weights is None:
-            use_presharded_weights = self.use_presharded_weights
         num_experts = self.config.num_local_experts
 
         stacked_params_mapping = [
@@ -435,20 +452,12 @@ class Grok1ForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
 
-                if use_presharded_weights:
-                    extra_kwargs = {
-                        "use_presharded_weights": use_presharded_weights
-                    }
-                else:
-                    extra_kwargs = {}
-
                 load_weight_wrapper(
                     name,
                     loaded_weight,
                     name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    **extra_kwargs,
                 )
                 break
             else:
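Most of these hunks are the same mechanical change: use_presharded_weights is decided once in Grok1ForCausalLM.__init__ and passed down explicitly through Grok1Model, Grok1DecoderLayer, Grok1MoE, and Grok1MLP, instead of being injected into load_weights per call. A simplified sketch of that constructor-plumbing pattern (the classes below are stand-ins, not the real Grok modules):

import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, layer_id: int, use_presharded_weights: bool = False):
        super().__init__()
        self.layer_id = layer_id
        # The flag would be forwarded to the parallel linear layers here.
        self.use_presharded_weights = use_presharded_weights

class Model(nn.Module):
    def __init__(self, num_layers: int, use_presharded_weights: bool = False):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayer(i, use_presharded_weights=use_presharded_weights)
            for i in range(num_layers)
        )

# Resolved once at the top level, then threaded down explicitly.
model = Model(num_layers=2, use_presharded_weights=True)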
sglang/srt/models/llama.py CHANGED
@@ -100,6 +100,7 @@ class LlamaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -132,14 +133,14 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
@@ -194,6 +195,11 @@ class LlamaDecoderLayer(nn.Module):
         )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -206,6 +212,7 @@ class LlamaDecoderLayer(nn.Module):
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            bias=attention_bias,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
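The new bias argument is derived from either config.attention_bias (llamafied Qwen2.5 checkpoints) or config.bias (internlm-7b). A toy check of that fallback chain using stand-in config objects:

from types import SimpleNamespace

qwen_like = SimpleNamespace(attention_bias=True)  # llamafied Qwen2.5 style
internlm_like = SimpleNamespace(bias=True)        # internlm-7b style
plain_llama = SimpleNamespace()                   # neither attribute present

for config in (qwen_like, internlm_like, plain_llama):
    attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False)
    print(attention_bias)  # True, True, False -> passed as bias= to qkv_proj and o_proj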
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -232,6 +232,7 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
 
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
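The added line makes the merge sticky: if either batch needs min-p sampling, the merged batch keeps it enabled. A toy merge showing the OR semantics (the class is a stand-in for SamplingBatchInfo):

class BatchInfo:
    def __init__(self, need_min_p_sampling: bool):
        self.need_min_p_sampling = need_min_p_sampling

    def merge_batch(self, other: "BatchInfo"):
        # Mirrors the added line: the flag survives the merge if either side set it.
        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

a, b = BatchInfo(False), BatchInfo(True)
a.merge_batch(b)
print(a.need_min_p_sampling)  # True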
sglang/srt/server.py CHANGED
@@ -127,14 +127,12 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
 
+    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
     else:
-        gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
 
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
@@ -546,7 +544,12 @@
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            tokenizer_manager.image_token_id,
+        ),
     )
     t.start()
 
@@ -616,7 +619,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
sglang/srt/server_args.py CHANGED
@@ -148,6 +148,7 @@ class ServerArgs:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
+    cuda_graph_bs: Optional[List[int]] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -361,6 +362,7 @@ class ServerArgs:
                 "awq_marlin",
                 "bitsandbytes",
                 "gguf",
+                "modelopt",
             ],
             help="The quantization method.",
         )
@@ -802,6 +804,12 @@ class ServerArgs:
             default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
        )
+        parser.add_argument(
+            "--cuda-graph-bs",
+            type=int,
+            nargs="+",
+            help="Set the list of batch sizes for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
@@ -920,7 +928,10 @@ class PortArgs:
         while True:
             if is_port_available(port):
                 break
-            port += 42
+            if port < 60000:
+                port += 42
+            else:
+                port -= 43
 
         return PortArgs(
             tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
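The port-probing change keeps the search inside the valid range: the candidate still steps up by 42, but once it reaches 60000 or above it steps down by 43, so the probe oscillates below 60000 instead of drifting past 65535. A standalone sketch of that walk (is_port_available here is a simplified local bind check, not the package's helper):

import socket

def is_port_available(port: int) -> bool:
    # Simplified availability check: try to bind the port on localhost.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("127.0.0.1", port))
            return True
        except OSError:
            return False

def pick_port(port: int) -> int:
    # Mirrors the new stepping rule from PortArgs.
    while not is_port_available(port):
        if port < 60000:
            port += 42
        else:
            port -= 43
    return port

print(pick_port(30000))

The new --cuda-graph-bs option is declared with type=int and nargs="+", so it accepts a space-separated list on the command line, e.g. --cuda-graph-bs 1 2 4 8.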