sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +48 -20
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +71 -1
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/outlines_backend.py +15 -2
  8. sglang/srt/constrained/xgrammar_backend.py +22 -14
  9. sglang/srt/layers/activation.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  11. sglang/srt/layers/attention/triton_backend.py +9 -7
  12. sglang/srt/layers/custom_op_util.py +26 -0
  13. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  14. sglang/srt/layers/layernorm.py +4 -0
  15. sglang/srt/layers/logits_processor.py +10 -10
  16. sglang/srt/layers/sampler.py +4 -8
  17. sglang/srt/layers/torchao_utils.py +2 -0
  18. sglang/srt/managers/data_parallel_controller.py +74 -9
  19. sglang/srt/managers/detokenizer_manager.py +1 -0
  20. sglang/srt/managers/io_struct.py +27 -0
  21. sglang/srt/managers/schedule_batch.py +104 -38
  22. sglang/srt/managers/schedule_policy.py +5 -1
  23. sglang/srt/managers/scheduler.py +204 -54
  24. sglang/srt/managers/session_controller.py +62 -0
  25. sglang/srt/managers/tokenizer_manager.py +38 -0
  26. sglang/srt/managers/tp_worker.py +12 -1
  27. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  28. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  29. sglang/srt/model_executor/forward_batch_info.py +109 -15
  30. sglang/srt/model_executor/model_runner.py +99 -43
  31. sglang/srt/model_parallel.py +98 -0
  32. sglang/srt/models/deepseek_v2.py +147 -44
  33. sglang/srt/models/gemma2.py +9 -8
  34. sglang/srt/models/llava.py +1 -1
  35. sglang/srt/models/llavavid.py +1 -1
  36. sglang/srt/models/olmo.py +3 -3
  37. sglang/srt/models/phi3_small.py +447 -0
  38. sglang/srt/models/qwen2_vl.py +13 -6
  39. sglang/srt/models/torch_native_llama.py +94 -78
  40. sglang/srt/openai_api/adapter.py +6 -2
  41. sglang/srt/openai_api/protocol.py +1 -1
  42. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  43. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  44. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  45. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  47. sglang/srt/sampling/sampling_batch_info.py +58 -57
  48. sglang/srt/sampling/sampling_params.py +1 -1
  49. sglang/srt/server.py +27 -1
  50. sglang/srt/server_args.py +78 -62
  51. sglang/srt/utils.py +71 -52
  52. sglang/test/runners.py +25 -6
  53. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  54. sglang/test/test_utils.py +30 -19
  55. sglang/version.py +1 -1
  56. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  57. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
  58. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  59. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  60. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -1,12 +1,17 @@
  from __future__ import annotations

  import dataclasses
- from typing import TYPE_CHECKING, List, Optional
+ import logging
+ import threading
+ from typing import TYPE_CHECKING, Callable, List, Optional

  import torch

  import sglang.srt.sampling.penaltylib as penaltylib

+ logger = logging.getLogger(__name__)
+
+
  if TYPE_CHECKING:
  from sglang.srt.managers.schedule_batch import ScheduleBatch

@@ -27,10 +32,11 @@ class SamplingBatchInfo:

  # Bias Tensors
  vocab_size: int
+ grammars: Optional[List] = None
+ sampling_info_done: Optional[threading.Event] = None
  logit_bias: torch.Tensor = None
  vocab_mask: Optional[torch.Tensor] = None
-
- grammars: Optional[List] = None
+ apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None

  # Penalizer
  penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -42,10 +48,7 @@ class SamplingBatchInfo:

  @classmethod
  def from_schedule_batch(
- cls,
- batch: ScheduleBatch,
- vocab_size: int,
- disable_penalizer: bool,
+ cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
  ):
  reqs = batch.reqs
  device = batch.device
@@ -73,12 +76,39 @@ class SamplingBatchInfo:
  top_ks=top_ks,
  min_ps=min_ps,
  need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
- is_all_greedy=top_ks.max().item() <= 1,
+ is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
  vocab_size=vocab_size,
  device=device,
  )
  # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

+ if enable_overlap_schedule:
+ # TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
+ # so it is kind of tricky to make it work with overlap scheduler.
+ # It requires correcly updating the penalty logits before the sampling and syncing the events.
+ # We will support them later.
+ penalizers = {
+ penaltylib.BatchedMinNewTokensPenalizer,
+ }
+ if (
+ any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
+ or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
+ or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
+ ):
+ logger.warning(
+ "frequency_penalty, presence_penalty, and repetition_penalty are not supported "
+ "when using the default overlap scheduler. They will be ignored. "
+ "Please add `--disable-overlap` when launching the server if you need these features. "
+ "The speed will be slower in that case."
+ )
+ else:
+ penalizers = {
+ penaltylib.BatchedFrequencyPenalizer,
+ penaltylib.BatchedMinNewTokensPenalizer,
+ penaltylib.BatchedPresencePenalizer,
+ penaltylib.BatchedRepetitionPenalizer,
+ }
+
  # Each penalizers will do nothing if they evaluate themselves as not required by looking at
  # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
  # should not add hefty computation overhead other than simple checks.
@@ -86,20 +116,12 @@ class SamplingBatchInfo:
  # While we choose not to even create the class instances if they are not required, this
  # could add additional complexity to the {ScheduleBatch} class, especially we need to
  # handle {filter_batch()} and {merge_batch()} cases as well.
- if disable_penalizer:
- ret.penalizer_orchestrator = None
- else:
- ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
- vocab_size=vocab_size,
- batch=batch,
- device=batch.device,
- Penalizers={
- penaltylib.BatchedFrequencyPenalizer,
- penaltylib.BatchedMinNewTokensPenalizer,
- penaltylib.BatchedPresencePenalizer,
- penaltylib.BatchedRepetitionPenalizer,
- },
- )
+ ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+ vocab_size=vocab_size,
+ batch=batch,
+ device=batch.device,
+ Penalizers=penalizers,
+ )

  # Handle logit bias but only allocate when needed
  ret.logit_bias = None
@@ -110,9 +132,6 @@ class SamplingBatchInfo:
  return len(self.temperatures)

  def update_penalties(self):
- if not self.penalizer_orchestrator:
- return
-
  self.scaling_penalties = None
  self.linear_penalties = None

@@ -133,23 +152,28 @@ class SamplingBatchInfo:
  self.linear_penalties = penalizer.apply(self.linear_penalties)

  def update_regex_vocab_mask(self):
- if not self.grammars or not any(grammar for grammar in self.grammars):
+ if not self.grammars:
  self.vocab_mask = None
+ self.apply_mask = None
  return

- self.vocab_mask = torch.zeros(
- len(self.temperatures),
- self.vocab_size,
- dtype=torch.bool,
+ # find a grammar from the list
+ grammar = next(grammar for grammar in self.grammars if grammar)
+
+ # maybe we can reuse the existing mask?
+ self.vocab_mask = grammar.allocate_vocab_mask(
+ vocab_size=self.vocab_size,
+ batch_size=len(self.temperatures),
  device=self.device,
  )
+ self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
+
  for i, grammar in enumerate(self.grammars):
  if grammar is not None:
- grammar.fill_vocab_mask(self.vocab_mask[i])
+ grammar.fill_vocab_mask(self.vocab_mask, i)

  def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
- if self.penalizer_orchestrator:
- self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+ self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

  for item in [
  "temperatures",
@@ -188,8 +212,7 @@ class SamplingBatchInfo:
  return None

  def merge_batch(self, other: "SamplingBatchInfo"):
- if self.penalizer_orchestrator:
- self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+ self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

  for item in [
  "temperatures",
@@ -205,25 +228,3 @@ class SamplingBatchInfo:
  self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
  self.logit_bias, other.logit_bias, len(self), len(other), self.device
  )
-
- def copy(self):
- return SamplingBatchInfo(
- temperatures=self.temperatures,
- top_ps=self.top_ps,
- top_ks=self.top_ks,
- min_ps=self.min_ps,
- is_all_greedy=self.is_all_greedy,
- need_min_p_sampling=self.need_min_p_sampling,
- vocab_size=self.vocab_size,
- device=self.device,
- )
-
- def to(self, device: str):
- for item in [
- "temperatures",
- "top_ps",
- "top_ks",
- "min_ps",
- ]:
- value = getattr(self, item)
- setattr(self, item, value.to(device, non_blocking=True))
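Note on the update_regex_vocab_mask() change above: mask construction is no longer a hard-coded torch.zeros boolean tensor but is delegated to the grammar backend object. A minimal sketch of the interface the new code expects follows; the DummyGrammar class is hypothetical (the real implementations live in outlines_backend.py and xgrammar_backend.py, which are not shown in this excerpt), and the True-means-blocked convention is carried over from the pre-0.3.6 code as an assumption.

import torch

class DummyGrammar:
    """Hypothetical grammar object implementing the mask hooks used above."""

    def __init__(self, allowed_token_ids):
        self.allowed_token_ids = allowed_token_ids

    @staticmethod
    def allocate_vocab_mask(vocab_size: int, batch_size: int, device) -> torch.Tensor:
        # One boolean row per request in the batch.
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        # Block everything for this row, then re-allow this grammar's tokens.
        vocab_mask[idx] = True
        vocab_mask[idx, self.allowed_token_ids] = False

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        # Matches the apply_mask signature Callable[[Tensor, Tensor], None]:
        # applied in place on the logits right before sampling.
        logits.masked_fill_(vocab_mask, float("-inf"))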
sglang/srt/sampling/sampling_params.py CHANGED
@@ -24,7 +24,6 @@ class SamplingParams:
  def __init__(
  self,
  max_new_tokens: int = 128,
- min_new_tokens: int = 0,
  stop: Optional[Union[str, List[str]]] = None,
  stop_token_ids: Optional[List[int]] = None,
  temperature: float = 1.0,
@@ -34,6 +33,7 @@ class SamplingParams:
  frequency_penalty: float = 0.0,
  presence_penalty: float = 0.0,
  repetition_penalty: float = 1.0,
+ min_new_tokens: int = 0,
  spaces_between_special_tokens: bool = True,
  regex: Optional[str] = None,
  n: int = 1,
sglang/srt/server.py CHANGED
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
  )
  from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
  from sglang.srt.managers.io_struct import (
+ CloseSessionReqInput,
  EmbeddingReqInput,
  GenerateReqInput,
+ OpenSessionReqInput,
  UpdateWeightReqInput,
  )
  from sglang.srt.managers.scheduler import run_scheduler_process
@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
  )


+ @app.api_route("/open_session", methods=["GET", "POST"])
+ async def open_session(obj: OpenSessionReqInput, request: Request):
+ """Open a session, and return its unique session id."""
+ try:
+ session_id = await tokenizer_manager.open_session(obj, request)
+ return session_id
+ except Exception as e:
+ return ORJSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )
+
+
+ @app.api_route("/close_session", methods=["GET", "POST"])
+ async def close_session(obj: CloseSessionReqInput, request: Request):
+ """Close the session"""
+ try:
+ await tokenizer_manager.close_session(obj, request)
+ return Response(status_code=200)
+ except Exception as e:
+ return ORJSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )
+
+
  @time_func_latency
  async def generate_request(obj: GenerateReqInput, request: Request):
  """Handle a generate request."""
@@ -392,7 +418,7 @@ def launch_engine(
  )
  for tp_rank in tp_rank_range:
  reader, writer = mp.Pipe(duplex=False)
- gpu_id = tp_rank % tp_size_per_node
+ gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
  proc = mp.Process(
  target=run_scheduler_process,
  args=(server_args, port_args, gpu_id, tp_rank, None, writer),
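The one-line change above is what makes the new --base-gpu-id option take effect: every scheduler process offsets its GPU index by server_args.base_gpu_id. A quick trace of the formula with made-up values:

# Illustration only; the concrete numbers are hypothetical.
base_gpu_id = 4        # launched with --base-gpu-id 4
tp_size_per_node = 4   # four tensor-parallel ranks on this node

for tp_rank in range(tp_size_per_node):
    gpu_id = base_gpu_id + tp_rank % tp_size_per_node
    print(f"tp_rank {tp_rank} -> GPU {gpu_id}")  # ranks 0..3 land on GPUs 4..7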
sglang/srt/server_args.py CHANGED
@@ -23,8 +23,10 @@ import tempfile
  from typing import List, Optional

  from sglang.srt.utils import (
- get_gpu_memory_capacity,
+ get_amdgpu_memory_capacity,
+ get_nvgpu_memory_capacity,
  is_flashinfer_available,
+ is_hip,
  is_ipv6,
  is_port_available,
  )
@@ -70,6 +72,7 @@ class ServerArgs:
  constrained_json_whitespace_pattern: Optional[str] = None
  watchdog_timeout: float = 300
  download_dir: Optional[str] = None
+ base_gpu_id: int = 0

  # Logging
  log_level: str = "info"
@@ -114,8 +117,6 @@ class ServerArgs:
  grammar_backend: Optional[str] = "outlines"

  # Optimization/debug options
- disable_flashinfer: bool = False
- disable_flashinfer_sampling: bool = False
  disable_radix_cache: bool = False
  disable_jump_forward: bool = False
  disable_cuda_graph: bool = False
@@ -123,14 +124,14 @@ class ServerArgs:
  disable_disk_cache: bool = False
  disable_custom_all_reduce: bool = False
  disable_mla: bool = False
- disable_penalizer: bool = False
- disable_nan_detection: bool = False
- enable_overlap_schedule: bool = False
+ disable_overlap_schedule: bool = False
  enable_mixed_chunk: bool = False
+ enable_dp_attention: bool = False
  enable_torch_compile: bool = False
  torch_compile_max_bs: int = 32
  cuda_graph_max_bs: int = 160
  torchao_config: str = ""
+ enable_nan_detection: bool = False
  enable_p2p_check: bool = False
  triton_attention_reduce_in_fp32: bool = False
  num_continuous_decode_steps: int = 1
@@ -156,7 +157,7 @@ class ServerArgs:
  if self.tp_size >= 16:
  self.mem_fraction_static = 0.79
  elif self.tp_size >= 8:
- self.mem_fraction_static = 0.83
+ self.mem_fraction_static = 0.82
  elif self.tp_size >= 4:
  self.mem_fraction_static = 0.85
  elif self.tp_size >= 2:
@@ -165,59 +166,45 @@ class ServerArgs:
  self.mem_fraction_static = 0.88

  # Adjust for GPUs with small memory capacities
- gpu_mem = get_gpu_memory_capacity()
+ if is_hip():
+ gpu_mem = get_amdgpu_memory_capacity()
+ else:
+ gpu_mem = get_nvgpu_memory_capacity()
  if gpu_mem < 25000:
- logger.warning(
- "Automatically adjust --chunked-prefill-size for small GPUs."
- )
  self.chunked_prefill_size //= 4 # make it 2048
  self.cuda_graph_max_bs = 4
+ logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

- # Deprecation warnings
- if self.disable_flashinfer:
- logger.warning(
- "The option '--disable-flashinfer' will be deprecated in the next release. "
- "Please use '--attention-backend triton' instead."
- )
- self.attention_backend = "triton"
- if self.disable_flashinfer_sampling:
- logger.warning(
- "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
- "Please use '--sampling-backend pytorch' instead. "
- )
- self.sampling_backend = "pytorch"
-
+ # Choose kernel backends
  if not is_flashinfer_available():
  self.attention_backend = "triton"
  self.sampling_backend = "pytorch"

- # Default kernel backends
  if self.attention_backend is None:
  self.attention_backend = "flashinfer"
-
  if self.sampling_backend is None:
  self.sampling_backend = "flashinfer"

- if self.enable_overlap_schedule:
- logger.warning(
- "Overlap scheduler mode is enabled. This is an experimental feature. "
- "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
- "and embedding APIs are not supported and will lead to wrong results. "
- "The NaN detection is also disabled."
+ # Others
+ if self.enable_dp_attention:
+ self.dp_size = self.tp_size
+ self.chunked_prefill_size = self.chunked_prefill_size // 2
+ self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
+ self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+ self.disable_overlap_schedule = True
+ logger.info(
+ f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
+ f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
+ f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
+ "Data parallel size is adjusted to be the same as tensor parallel size. "
+ "Overlap schedule is disabled."
  )
- self.disable_penalizer = True
- self.disable_nan_detection = True

- # Model-specific patches
- if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+ if self.enable_mixed_chunk:
  logger.info(
- "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+ "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
  )
- self.trust_remote_code = False
-
- if "gemma-2" in self.model_path.lower():
- logger.info("When using sliding window in gemma-2, turn on flashinfer.")
- self.attention_backend = "flashinfer"
+ self.disable_overlap_schedule = True

  @staticmethod
  def add_cli_args(parser: argparse.ArgumentParser):
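The new --enable-dp-attention branch in __post_init__ above rewrites several other knobs at once. A trace of those assignments with example values (cuda_graph_max_bs=160 is the default shown in this diff; the chunked prefill size and schedule conservativeness starting values are assumptions):

# Illustration only; starting values other than cuda_graph_max_bs are assumed.
tp_size = 8
chunked_prefill_size = 8192      # assumed starting value
cuda_graph_max_bs = 160          # ServerArgs default shown above
schedule_conservativeness = 1.0  # assumed starting value

# Effects when enable_dp_attention is set:
dp_size = tp_size                                             # 8
chunked_prefill_size = chunked_prefill_size // 2              # 4096
cuda_graph_max_bs = min(cuda_graph_max_bs, 96)                # 96
schedule_conservativeness = schedule_conservativeness * 0.3   # 0.3
disable_overlap_schedule = True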
@@ -426,6 +413,12 @@
  default=ServerArgs.download_dir,
  help="Model download directory.",
  )
+ parser.add_argument(
+ "--base-gpu-id",
+ type=int,
+ default=ServerArgs.base_gpu_id,
+ help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
+ )

  # Logging
  parser.add_argument(
@@ -599,16 +592,6 @@
  )

  # Optimization/debug options
- parser.add_argument(
- "--disable-flashinfer",
- action="store_true",
- help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
- )
- parser.add_argument(
- "--disable-flashinfer-sampling",
- action="store_true",
- help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
- )
  parser.add_argument(
  "--disable-radix-cache",
  action="store_true",
@@ -644,26 +627,26 @@
  action="store_true",
  help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
  )
- parser.add_argument(
- "--disable-penalizer",
- action="store_true",
- help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
- )
  parser.add_argument(
  "--disable-nan-detection",
  action="store_true",
  help="Disable the NaN detection for better performance.",
  )
  parser.add_argument(
- "--enable-overlap-schedule",
+ "--disable-overlap-schedule",
  action="store_true",
- help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+ help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
  )
  parser.add_argument(
  "--enable-mixed-chunk",
  action="store_true",
  help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
  )
+ parser.add_argument(
+ "--enable-dp-attention",
+ action="store_true",
+ help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
+ )
  parser.add_argument(
  "--enable-torch-compile",
  action="store_true",
@@ -685,7 +668,12 @@
  "--torchao-config",
  type=str,
  default=ServerArgs.torchao_config,
- help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
+ help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
+ )
+ parser.add_argument(
+ "--enable-nan-detection",
+ action="store_true",
+ help="Enable the NaN detection for debugging purposes.",
  )
  parser.add_argument(
  "--enable-p2p-check",
@@ -712,6 +700,23 @@
  help="Delete the model checkpoint after loading the model.",
  )

+ # Deprecated arguments
+ parser.add_argument(
+ "--enable-overlap-schedule",
+ action=DeprecatedAction,
+ help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
+ )
+ parser.add_argument(
+ "--disable-flashinfer",
+ action=DeprecatedAction,
+ help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
+ )
+ parser.add_argument(
+ "--disable-flashinfer-sampling",
+ action=DeprecatedAction,
+ help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
+ )
+
  @classmethod
  def from_cli_args(cls, args: argparse.Namespace):
  args.tp_size = args.tensor_parallel_size
@@ -738,6 +743,7 @@
  and (self.lora_paths is None or self.disable_cuda_graph)
  and (self.lora_paths is None or self.disable_radix_cache)
  ), "compatibility of lora and cuda graph and radix attention is in progress"
+ assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"

  if isinstance(self.lora_paths, list):
  lora_paths = self.lora_paths
@@ -782,7 +788,7 @@

  @staticmethod
  def init_new(server_args) -> "PortArgs":
- port = server_args.port + 42
+ port = server_args.port + random.randint(100, 1000)
  while True:
  if is_port_available(port):
  break
@@ -805,3 +811,13 @@ class LoRAPathAction(argparse.Action):
  getattr(namespace, self.dest)[name] = path
  else:
  getattr(namespace, self.dest)[lora_path] = lora_path
+
+
+ class DeprecatedAction(argparse.Action):
+ def __init__(self, option_strings, dest, nargs=0, **kwargs):
+ super(DeprecatedAction, self).__init__(
+ option_strings, dest, nargs=nargs, **kwargs
+ )
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ raise ValueError(self.help)
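The DeprecatedAction class added at the end of server_args.py backs the new "Deprecated arguments" block in add_cli_args: the flags still parse as options, but using one of them raises immediately. A standalone sketch of that behavior (the parser wiring below mirrors the registration shown above):

import argparse

class DeprecatedAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Triggered only when the deprecated flag is actually passed.
        raise ValueError(self.help)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-flashinfer",
    action=DeprecatedAction,
    help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
)

parser.parse_args([])                          # fine: the deprecated flag is absent
# parser.parse_args(["--disable-flashinfer"])  # raises ValueError with the help text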
sglang/srt/utils.py CHANGED
@@ -71,6 +71,8 @@ def is_flashinfer_available():
  Check whether flashinfer is available.
  As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
  """
+ if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
+ return False
  return torch.cuda.is_available() and not is_hip()


@@ -330,6 +332,7 @@ def suppress_other_loggers():
  )
  logging.getLogger("vllm.selector").setLevel(logging.WARN)
  logging.getLogger("vllm.utils").setLevel(logging.ERROR)
+ logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)

  warnings.filterwarnings(
  "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -394,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
  pass


+ def monkey_patch_vllm_model_config():
+ from vllm.config import ModelConfig
+
+ if not hasattr(ModelConfig, "_resolve_task"):
+ return
+
+ def _resolve_task(
+ self,
+ task_option,
+ hf_config,
+ ):
+ supported_tasks = {
+ "generate": True,
+ "embedding": False,
+ }
+ selected_task = "generate"
+ return supported_tasks, selected_task
+
+ setattr(ModelConfig, "_resolve_task", _resolve_task)
+
+
  def monkey_patch_vllm_p2p_access_check(gpu_id: int):
  """
  Monkey patch the slow p2p access check in vllm.
@@ -405,57 +429,6 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
  setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)


- def monkey_patch_vllm_dummy_weight_loader():
- """
- Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
- """
-
- from vllm.model_executor.model_loader.loader import (
- CacheConfig,
- DeviceConfig,
- DummyModelLoader,
- LoRAConfig,
- ModelConfig,
- ParallelConfig,
- SchedulerConfig,
- _initialize_model,
- initialize_dummy_weights,
- nn,
- set_default_torch_dtype,
- )
-
- def load_model(
- self,
- *,
- model_config: ModelConfig,
- device_config: DeviceConfig,
- lora_config: Optional[LoRAConfig],
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- cache_config: CacheConfig,
- ) -> nn.Module:
- with set_default_torch_dtype(model_config.dtype):
- with torch.device(device_config.device):
- model = _initialize_model(
- model_config,
- self.load_config,
- lora_config,
- cache_config,
- )
-
- for _, module in model.named_modules():
- quant_method = getattr(module, "quant_method", None)
- if quant_method is not None:
- quant_method.process_weights_after_loading(module)
-
- # NOTE(woosuk): For accurate performance evaluation, we assign
- # random values to the weights.
- initialize_dummy_weights(model)
- return model.eval()
-
- setattr(DummyModelLoader, "load_model", load_model)
-
-
  vllm_all_gather_backup = None


@@ -794,7 +767,48 @@ def add_prometheus_middleware(app):
  app.routes.append(metrics_route)


- def get_gpu_memory_capacity():
+ def bind_port(port):
+ """Bind to a specific port, assuming it's available."""
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Allows address reuse
+ sock.bind(("", port))
+ sock.listen(1)
+ return sock
+
+
+ def get_amdgpu_memory_capacity():
+ try:
+ # Run rocm-smi and capture the output
+ result = subprocess.run(
+ ["rocm-smi --showmeminfo vram | grep 'Total Memory' | awk '{print $NF}'"],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True,
+ text=True,
+ )
+ if result.returncode != 0:
+ raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}")
+
+ # Parse the output to extract memory values in MiB
+ memory_values = [
+ float(mem) / 1024 / 1024
+ for mem in result.stdout.strip().split("\n")
+ if re.match(r"^\d+(\.\d+)?$", mem.strip())
+ ]
+
+ if not memory_values:
+ raise ValueError("No GPU memory values found.")
+
+ # Return the minimum memory value
+ return min(memory_values)
+
+ except FileNotFoundError:
+ raise RuntimeError(
+ "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible."
+ )
+
+
+ def get_nvgpu_memory_capacity():
  try:
  # Run nvidia-smi and capture the output
  result = subprocess.run(
@@ -824,3 +838,8 @@ def get_gpu_memory_capacity():
  raise RuntimeError(
  "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
  )
+
+
+ def crash_on_warnings():
+ # Crash on warning if we are running CI tests
+ return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
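This release adds two environment switches to utils.py: SGLANG_IS_FLASHINFER_AVAILABLE (a manual kill switch in is_flashinfer_available(), shown in the first hunk of this file) and SGLANG_IS_IN_CI (which drives crash_on_warnings()). A sketch of the flashinfer switch; setting the variable from Python before import is just one way to do it:

import os

os.environ["SGLANG_IS_FLASHINFER_AVAILABLE"] = "false"

from sglang.srt.utils import is_flashinfer_available

# With the variable set to "false", the check short-circuits to False, and
# ServerArgs.__post_init__ then falls back to attention_backend="triton"
# and sampling_backend="pytorch" (see the server_args.py hunk above).
assert is_flashinfer_available() is False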