sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py CHANGED
@@ -82,6 +82,14 @@ class StreamOptions(BaseModel):
     include_usage: Optional[bool] = False
 
 
+class JsonSchemaResponseFormat(BaseModel):
+    name: str
+    description: Optional[str] = None
+    # use alias to workaround pydantic conflict
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    strict: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
@@ -213,6 +221,7 @@ class ChatCompletionMessageContentImageURL(BaseModel):
 class ChatCompletionMessageContentImagePart(BaseModel):
     type: Literal["image_url"]
     image_url: ChatCompletionMessageContentImageURL
+    modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
 
 
 ChatCompletionMessageContentPart = Union[
@@ -236,8 +245,8 @@ ChatCompletionMessageParam = Union[
 
 
 class ResponseFormat(BaseModel):
-    # type must be "json_object" or "text"
-    type: Literal["text", "json_object"]
+    type: Literal["text", "json_object", "json_schema"]
+    json_schema: Optional[JsonSchemaResponseFormat] = None
 
 
 class ChatCompletionRequest(BaseModel):
@@ -263,7 +272,6 @@ class ChatCompletionRequest(BaseModel):
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
-    json_schema: Optional[str] = None
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
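The structured-output request format moves from the old SRT-only `json_schema` string to the OpenAI-style `response_format` with `type: "json_schema"`. A minimal sketch of a request body matching the new models above; the endpoint, port 30000, model name, and the `city_info` schema are assumptions for a default local deployment:

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/chat/completions",
        json={
            "model": "default",
            "messages": [{"role": "user", "content": "Give me a JSON object describing Paris."}],
            "response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "city_info",  # hypothetical schema name
                    # the "schema" key maps to JsonSchemaResponseFormat.schema_ via the alias
                    "schema": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "population": {"type": "integer"},
                        },
                        "required": ["name", "population"],
                    },
                },
            },
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])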
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -34,70 +34,76 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
-    def has_bias(self):
+    def __len__(self):
+        return len(self.temperatures)
+
+    def can_run_in_cuda_graph(self):
+        # Vocab bias and min_ps are not supported in CUDA graph
         return (
-            self.logit_bias is not None
-            or self.vocab_mask is not None
-            or self.linear_penalties is not None
-            or self.scaling_penalties is not None
+            self.logit_bias is None
+            and self.linear_penalties is None
+            and self.scaling_penalties is None
+            and not self.need_min_p_sampling
         )
 
     @classmethod
     def dummy_one(cls, max_bs: int, vocab_size: int):
         ret = cls(vocab_size=vocab_size)
-        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
-        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
-        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
-        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
+        with torch.device("cuda"):
+            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
+            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
+            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
+            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
         return ret
 
     def __getitem__(self, key):
         if isinstance(key, slice):
-            # NOTE: We do not use cuda graph when there is bias tensors
-            assert not self.has_bias()
+            # NOTE: This method is only used in CUDA graph
+            assert self.can_run_in_cuda_graph()
             return SamplingBatchInfo(
                 vocab_size=self.vocab_size,
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
-                min_ps=self.min_ps[key],
-                need_min_p_sampling=self.need_min_p_sampling,
+                vocab_mask=self.vocab_mask[key],
             )
         else:
             raise NotImplementedError
 
     def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE: We do not use cuda graph when there is bias tensors
-        assert not self.has_bias()
+        # NOTE: This method is only used in CUDA graph
+        assert self.can_run_in_cuda_graph()
 
         self.vocab_size = other.vocab_size
-        self.need_min_p_sampling = other.need_min_p_sampling
-
         self.temperatures[:bs] = other.temperatures
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks
-        self.min_ps[:bs] = other.min_ps
+
+        if other.vocab_mask is None:
+            self.vocab_mask[:bs].fill_(False)
+        else:
+            self.vocab_mask[:bs] = other.vocab_mask
 
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
-        device = "cuda"
         reqs = batch.reqs
         ret = cls(vocab_size=vocab_size)
 
-        ret.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        ret.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        ret.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-        ret.min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
-        )
+        with torch.device("cuda"):
+            ret.temperatures = torch.tensor(
+                [r.sampling_params.temperature for r in reqs],
+                dtype=torch.float,
+            ).view(-1, 1)
+            ret.top_ps = torch.tensor(
+                [r.sampling_params.top_p for r in reqs], dtype=torch.float
+            )
+            ret.top_ks = torch.tensor(
+                [r.sampling_params.top_k for r in reqs], dtype=torch.int
+            )
+            ret.min_ps = torch.tensor(
+                [r.sampling_params.min_p for r in reqs], dtype=torch.float
+            )
+
         ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
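The rewritten constructors rely on `torch.device` being usable as a context manager (PyTorch 2.0+), which sets the default device for all factory calls inside the block instead of repeating `device=` on every call. A minimal sketch of the pattern, using CPU so it runs anywhere (the code above uses "cuda"):

    import torch

    with torch.device("cpu"):  # the diff uses torch.device("cuda")
        temperatures = torch.ones((8, 1), dtype=torch.float)
        top_ps = torch.ones((8,), dtype=torch.float)

    # Both tensors were created on the context's device without an explicit device= argument.
    print(temperatures.device, top_ps.device)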
@@ -110,7 +116,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=device,
+            device="cuda",
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
@@ -122,11 +128,9 @@ class SamplingBatchInfo:
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
 
-        ret.update_regex_vocab_mask(batch)
-
         return ret
 
-    def prepare_penalties(self):
+    def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None
 
@@ -146,18 +150,16 @@ class SamplingBatchInfo:
                 self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self, batch: ScheduleBatch):
-        bs, reqs = batch.batch_size(), batch.reqs
-        device = "cuda"
-        has_regex = any(req.regex_fsm is not None for req in reqs)
+        has_regex = any(req.regex_fsm is not None for req in batch.reqs)
 
         # Reset the vocab mask
         self.vocab_mask = None
 
         if has_regex:
             self.vocab_mask = torch.zeros(
-                bs, self.vocab_size, dtype=torch.bool, device=device
+                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(reqs):
+            for i, req in enumerate(batch.reqs):
                 if req.regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
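In this masking scheme a row of `vocab_mask` is first filled with 1 (everything banned) and the tokens allowed by the regex FSM are then cleared, so `True` means "disallowed". A standalone sketch of how such a boolean mask is typically applied to logits before sampling (an illustration of the idea, not the library's sampler code; the allowed token ids are made up):

    import torch

    vocab_size = 10
    logits = torch.randn(vocab_size)

    # Ban everything, then re-allow the tokens the FSM permits (hypothetical ids).
    vocab_mask = torch.ones(vocab_size, dtype=torch.bool)
    vocab_mask[[2, 5, 7]] = False

    constrained = logits.masked_fill(vocab_mask, float("-inf"))
    print(constrained.argmax())  # always one of the allowed ids: 2, 5, or 7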
@@ -178,6 +180,26 @@ class SamplingBatchInfo:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
+    @staticmethod
+    def merge_bias_tensor(
+        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+    ):
+        # bias tensor can be None
+        if lhs is not None or rhs is not None:
+            shape, dtype = None, None
+            if lhs is not None:
+                shape, dtype = lhs.shape[1:], lhs.dtype
+            else:
+                shape, dtype = rhs.shape[1:], rhs.dtype
+            with torch.dtype(dtype):
+                if lhs is None:
+                    lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
+                if rhs is None:
+                    rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
+            return torch.cat([lhs, rhs])
+
+        return None
+
     def merge(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
@@ -191,19 +213,6 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))
 
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
+        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other)
+        )
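The refactored `merge` delegates the None handling to `merge_bias_tensor`: if only one side of the merge carries a per-request bias tensor, the other side is padded with a default value before concatenation along the batch dimension. A standalone toy sketch of that behaviour (explicit `dtype=` arguments and CPU tensors are used here so it runs anywhere; `merge_bias_sketch` is not the library function):

    import torch

    def merge_bias_sketch(lhs, rhs, bs1, bs2, default=0):
        # Mirror of the merge logic: pad the missing side, then concatenate.
        if lhs is None and rhs is None:
            return None
        ref = lhs if lhs is not None else rhs
        shape, dtype = ref.shape[1:], ref.dtype
        if lhs is None:
            lhs = torch.full((bs1, *shape), default, dtype=dtype)
        if rhs is None:
            rhs = torch.full((bs2, *shape), default, dtype=dtype)
        return torch.cat([lhs, rhs])

    a = torch.tensor([[0.5, -1.0], [2.0, 0.0]])  # batch of 2 with a logit bias
    merged = merge_bias_sketch(a, None, bs1=2, bs2=3)
    print(merged.shape)  # torch.Size([5, 2]); rows 2-4 hold the default 0 bias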
sglang/srt/server.py CHANGED
@@ -37,6 +37,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
@@ -93,6 +94,14 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None
 
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 
 @app.get("/health")
 async def health() -> Response:
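With the CORS middleware registered, browser clients on other origins can call the HTTP API directly. One quick way to inspect this is to send a preflight request from Python; a minimal sketch, assuming a server running on the default port 30000 and using http://example.com as an arbitrary origin:

    import requests

    r = requests.options(
        "http://localhost:30000/health",
        headers={
            "Origin": "http://example.com",
            "Access-Control-Request-Method": "GET",
        },
    )
    # Inspect the CORS headers granted by the middleware.
    print(r.status_code)
    print({k: v for k, v in r.headers.items() if k.lower().startswith("access-control")})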
@@ -272,7 +281,6 @@ async def retrieve_file_content(file_id: str):
 
 def launch_server(
     server_args: ServerArgs,
-    model_override_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
@@ -317,7 +325,6 @@
         tp_rank_range,
         server_args,
         ports[3],
-        model_override_args,
     )
 
     try:
@@ -328,23 +335,19 @@
         return
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
-    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
         start_controller_process = start_controller_process_single
     else:
         start_controller_process = start_controller_process_multi
-
     proc_controller = mp.Process(
         target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer, model_override_args),
+        args=(server_args, port_args, pipe_controller_writer),
     )
     proc_controller.start()
 
+    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -355,6 +358,10 @@
         )
     proc_detoken.start()
 
+    tokenizer_manager = TokenizerManager(server_args, port_args)
+    if server_args.chat_template:
+        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
+
     # Wait for the model to finish loading
     controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
@@ -418,7 +425,7 @@ def _set_envs_and_config(server_args: ServerArgs):
         maybe_set_triton_cache_manager()
 
     # Check flashinfer version
-    if not server_args.disable_flashinfer:
+    if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer",
             "0.1.6",
@@ -440,13 +447,12 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         time.sleep(1)
         try:
             res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
-            assert res.status_code == 200, f"{res}"
+            assert res.status_code == 200, f"{res=}, {res.text=}"
             success = True
             break
-        except (AssertionError, requests.exceptions.RequestException) as e:
+        except (AssertionError, requests.exceptions.RequestException):
             last_traceback = get_exception_traceback()
             pass
-    model_info = res.json()
 
     if not success:
         if pipe_finish_writer is not None:
@@ -455,6 +461,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
+    model_info = res.json()
+
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -501,7 +509,6 @@ class Runtime:
     def __init__(
        self,
        log_level: str = "error",
-       model_override_args: Optional[dict] = None,
        *args,
        **kwargs,
    ):
@@ -525,7 +532,7 @@
 
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, model_override_args, pipe_writer),
+            args=(self.server_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
@@ -604,6 +611,7 @@ class Runtime:
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
     ):
         json_data = {
             "text": prompt,
@@ -611,7 +619,9 @@
             "return_logprob": return_logprob,
             "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
+            "lora_path": lora_path,
         }
+        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
         response = requests.post(
             self.url + "/generate",
             json=json_data,
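Together with the new LoRA server options (see server_args.py below), `Runtime.generate` now forwards a per-prompt `lora_path`. A heavily hedged sketch of how this might be used; the model and adapter paths are placeholders, the pairing of batched prompts with per-request adapters only follows from the length check added above, and the disable flags reflect the 0.3.1 restriction that LoRA is not yet compatible with CUDA graph or the radix cache:

    import sglang as sgl

    # Hypothetical paths; replace with a real base model and LoRA adapter.
    runtime = sgl.Runtime(
        model_path="meta-llama/Llama-2-7b-hf",
        lora_paths=["/path/to/my-lora-adapter"],
        disable_cuda_graph=True,
        disable_radix_cache=True,
    )

    out = runtime.generate(
        prompt=["Hello from the base model.", "Hello through the adapter."],
        lora_path=[None, "/path/to/my-lora-adapter"],
    )
    print(out)
    runtime.shutdown()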
sglang/srt/server_args.py CHANGED
@@ -49,7 +49,6 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
-    max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
     chunked_prefill_size: int = 8192
     max_prefill_tokens: int = 16384
@@ -75,7 +74,18 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
+    # Model override args in JSON
+    json_model_override_args: str = "{}"
+
     # Optimization/debug options
+    attention_backend: Optional[str] = None
+    sampling_backend: Optional[str] = None
+
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -86,16 +96,17 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
+    torchao_config: str = ""
     enable_p2p_check: bool = False
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # Distributed args
-    nccl_init_addr: Optional[str] = None
-    nnodes: int = 1
-    node_rank: Optional[int] = None
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
 
     def __post_init__(self):
+        # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
 
@@ -106,6 +117,7 @@
             # Disable chunked prefill
             self.chunked_prefill_size = None
 
+        # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -126,6 +138,42 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
+        # Deprecation warnings
+        if self.disable_flashinfer:
+            logger.warning(
+                "The option '--disable-flashinfer' will be deprecated in the next release. "
+                "Please use '--attention-backend triton' instead."
+            )
+            self.attention_backend = "triton"
+        if self.disable_flashinfer_sampling:
+            logger.warning(
+                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
+                "Please use '--sampling-backend pytorch' instead. "
+            )
+            self.sampling_backend = "pytorch"
+
+        # Default kernel backends
+        if self.enable_mla:
+            logger.info("MLA optimization is tunred on. Use triton backend.")
+            self.attention_backend = "triton"
+
+        if self.attention_backend is None:
+            self.attention_backend = "flashinfer"
+
+        if self.sampling_backend is None:
+            self.sampling_backend = "flashinfer"
+
+        # Model-specific patches
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
+
+        if "gemma-2" in self.model_path.lower():
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
+            self.attention_backend = "flashinfer"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
@@ -209,11 +257,6 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
-        parser.add_argument(
-            "--is-embedding",
-            action="store_true",
-            help="Whether to use a CausalLM as an embedding model.",
-        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -248,6 +291,11 @@ class ServerArgs:
             default=ServerArgs.chat_template,
             help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -260,17 +308,12 @@ class ServerArgs:
             default=ServerArgs.max_running_requests,
             help="The maximum number of running requests.",
         )
-        parser.add_argument(
-            "--max-num-reqs",
-            type=int,
-            default=ServerArgs.max_num_reqs,
-            help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
-        )
         parser.add_argument(
             "--max-total-tokens",
             type=int,
             default=ServerArgs.max_total_tokens,
-            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
+            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
+            "This option is typically used for development and debugging purposes.",
         )
         parser.add_argument(
             "--chunked-prefill-size",
@@ -381,16 +424,38 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+            default=ServerArgs.json_model_override_args,
+        )
+
         # Optimization/debug options
+        parser.add_argument(
+            "--attention-backend",
+            type=str,
+            choices=["flashinfer", "triton"],
+            default=ServerArgs.attention_backend,
+            help="Choose the kernels for attention layers.",
+        )
+        parser.add_argument(
+            "--sampling-backend",
+            type=str,
+            choices=["flashinfer", "pytorch"],
+            default=ServerArgs.sampling_backend,
+            help="Choose the kernels for sampling layers.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer attention kernels.",
+            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
         )
         parser.add_argument(
             "--disable-flashinfer-sampling",
             action="store_true",
-            help="Disable flashinfer sampling kernels.",
+            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
         )
         parser.add_argument(
             "--disable-radix-cache",
@@ -431,7 +496,13 @@ class ServerArgs:
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
-            help="Optimize the model with torch.compile, experimental feature.",
+            help="Optimize the model with torch.compile. Experimental feature.",
+        )
+        parser.add_argument(
+            "--torchao-config",
+            type=str,
+            default=ServerArgs.torchao_config,
+            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
         )
         parser.add_argument(
             "--enable-p2p-check",
@@ -455,6 +526,21 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
 
+        # LoRA options
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            help="The list of LoRA adapters.",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
@@ -472,14 +558,30 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
-        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
-            logger.info(
-                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
-            )
-            self.trust_remote_code = False
-        if "gemma-2" in self.model_path.lower():
-            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
-            self.disable_flashinfer = False
+        assert (
+            self.max_loras_per_batch > 0
+            # FIXME
+            and (self.lora_paths is None or self.disable_cuda_graph)
+            and (self.lora_paths is None or self.disable_radix_cache)
+        ), "compatibility of lora and cuda graph and radix attention is in progress"
+
+
+def prepare_server_args(argv: List[str]) -> ServerArgs:
+    """
+    Prepare the server arguments from the command line arguments.
+
+    Args:
+        args: The command line arguments. Typically, it should be `sys.argv[1:]`
+            to ensure compatibility with `parse_args` when no arguments are passed.
+
+    Returns:
+        The server arguments.
+    """
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    raw_args = parser.parse_args(argv)
+    server_args = ServerArgs.from_cli_args(raw_args)
+    return server_args
 
 
 @dataclasses.dataclass
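The new module-level `prepare_server_args` helper builds a `ServerArgs` programmatically from CLI-style arguments, including the new backend, torchao, and model-override options. A minimal sketch, assuming sglang 0.3.1 is installed; the model path and the override JSON are placeholders:

    from sglang.srt.server_args import prepare_server_args

    server_args = prepare_server_args([
        "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
        "--attention-backend", "triton",
        "--sampling-backend", "pytorch",
        "--torchao-config", "int8wo",
        "--json-model-override-args", '{"max_position_embeddings": 8192}',  # hypothetical override
    ])
    print(server_args.attention_backend, server_args.sampling_backend, server_args.torchao_config)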
sglang/srt/utils.py CHANGED
@@ -35,6 +35,7 @@ import torch
 import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
+from torch import nn
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
@@ -714,3 +715,14 @@ def configure_logger(server_args, prefix: str = ""):
         datefmt="%H:%M:%S",
         force=True,
     )
+
+
+# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
+def replace_submodule(
+    model: nn.Module, module_name: str, new_module: nn.Module
+) -> nn.Module:
+    """Replace a submodule in a model with a new module."""
+    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
+    target_name = module_name.split(".")[-1]
+    setattr(parent, target_name, new_module)
+    return new_module
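`replace_submodule`, presumably used by the new LoRA code to swap plain layers for LoRA-wrapped ones, resolves a dotted module path and rebinds the attribute on its parent module. A small usage sketch with a toy model; `Block` and `Toy` are made up for illustration:

    from torch import nn
    from sglang.srt.utils import replace_submodule  # assuming sglang 0.3.1 is installed

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.block = Block()

    model = Toy()
    # Swap block.proj for a bias-free linear layer in place.
    replace_submodule(model, "block.proj", nn.Linear(4, 4, bias=False))
    print(model.block.proj)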