sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py CHANGED
@@ -82,6 +82,14 @@ class StreamOptions(BaseModel):
     include_usage: Optional[bool] = False
 
 
+class JsonSchemaResponseFormat(BaseModel):
+    name: str
+    description: Optional[str] = None
+    # use alias to workaround pydantic conflict
+    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+    strict: Optional[bool] = False
+
+
 class FileRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/files/create
     file: bytes  # The File object (not file name) to be uploaded
@@ -213,6 +221,7 @@ class ChatCompletionMessageContentImageURL(BaseModel):
 class ChatCompletionMessageContentImagePart(BaseModel):
     type: Literal["image_url"]
     image_url: ChatCompletionMessageContentImageURL
+    modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
 
 
 ChatCompletionMessageContentPart = Union[
@@ -236,8 +245,8 @@ ChatCompletionMessageParam = Union[
 
 
 class ResponseFormat(BaseModel):
-    # type must be "json_object" or "text"
-    type: Literal["text", "json_object"]
+    type: Literal["text", "json_object", "json_schema"]
+    json_schema: Optional[JsonSchemaResponseFormat] = None
 
 
 class ChatCompletionRequest(BaseModel):
@@ -263,7 +272,6 @@ class ChatCompletionRequest(BaseModel):
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
-    json_schema: Optional[str] = None
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
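
Note on the protocol.py hunks above: ResponseFormat now accepts type "json_schema" together with a JsonSchemaResponseFormat payload, replacing the old top-level json_schema string on ChatCompletionRequest. A minimal client-side sketch follows, assuming a local sglang server exposing the OpenAI-compatible /v1/chat/completions route; the URL, model name, and schema are placeholders, not values taken from this diff.

    # Sketch only: exercises the new ResponseFormat(type="json_schema") fields.
    # The URL, model name, and schema are placeholder assumptions.
    import json

    import requests

    person_schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }

    payload = {
        "model": "default",
        "messages": [{"role": "user", "content": "Describe a person as JSON."}],
        "response_format": {
            "type": "json_schema",
            # Mirrors JsonSchemaResponseFormat; "schema" is the wire-level alias
            # for the schema_ field.
            "json_schema": {"name": "person", "schema": person_schema},
        },
    }

    resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
    print(json.loads(resp.json()["choices"][0]["message"]["content"]))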
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -34,66 +34,26 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
-    def can_run_in_cuda_graph(self):
-        # Vocab bias and min_ps are not supported in CUDA graph
-        return (
-            self.logit_bias is None
-            and self.vocab_mask is None
-            and self.linear_penalties is None
-            and self.scaling_penalties is None
-            and not self.need_min_p_sampling
-        )
-
-    @classmethod
-    def dummy_one(cls, max_bs: int, vocab_size: int):
-        ret = cls(vocab_size=vocab_size)
-        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
-        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
-        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
-        return ret
-
-    def __getitem__(self, key):
-        if isinstance(key, slice):
-            # NOTE:This method is only used in CUDA graph
-            assert self.can_run_in_cuda_graph()
-            return SamplingBatchInfo(
-                vocab_size=self.vocab_size,
-                temperatures=self.temperatures[key],
-                top_ps=self.top_ps[key],
-                top_ks=self.top_ks[key],
-            )
-        else:
-            raise NotImplementedError
-
-    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE:This method is only used in CUDA graph
-        assert self.can_run_in_cuda_graph()
-
-        self.vocab_size = other.vocab_size
-        self.temperatures[:bs] = other.temperatures
-        self.top_ps[:bs] = other.top_ps
-        self.top_ks[:bs] = other.top_ks
-
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
-        device = "cuda"
         reqs = batch.reqs
         ret = cls(vocab_size=vocab_size)
 
-        ret.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        ret.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        ret.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-        ret.min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
-        )
+        with torch.device("cuda"):
+            ret.temperatures = torch.tensor(
+                [r.sampling_params.temperature for r in reqs],
+                dtype=torch.float,
+            ).view(-1, 1)
+            ret.top_ps = torch.tensor(
+                [r.sampling_params.top_p for r in reqs], dtype=torch.float
+            )
+            ret.top_ks = torch.tensor(
+                [r.sampling_params.top_k for r in reqs], dtype=torch.int
+            )
+            ret.min_ps = torch.tensor(
+                [r.sampling_params.min_p for r in reqs], dtype=torch.float
+            )
+
         ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
@@ -106,7 +66,7 @@ class SamplingBatchInfo:
         ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
             vocab_size=vocab_size,
             batch=batch,
-            device=device,
+            device="cuda",
             Penalizers={
                 penaltylib.BatchedFrequencyPenalizer,
                 penaltylib.BatchedMinNewTokensPenalizer,
@@ -118,11 +78,12 @@ class SamplingBatchInfo:
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
 
-        ret.update_regex_vocab_mask(batch)
-
         return ret
 
-    def prepare_penalties(self):
+    def __len__(self):
+        return len(self.temperatures)
+
+    def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None
 
@@ -142,18 +103,16 @@ class SamplingBatchInfo:
                 self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self, batch: ScheduleBatch):
-        bs, reqs = batch.batch_size(), batch.reqs
-        device = "cuda"
-        has_regex = any(req.regex_fsm is not None for req in reqs)
+        has_regex = any(req.regex_fsm is not None for req in batch.reqs)
 
         # Reset the vocab mask
         self.vocab_mask = None
 
         if has_regex:
             self.vocab_mask = torch.zeros(
-                bs, self.vocab_size, dtype=torch.bool, device=device
+                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(reqs):
+            for i, req in enumerate(batch.reqs):
                 if req.regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
@@ -174,6 +133,26 @@ class SamplingBatchInfo:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
+    @staticmethod
+    def merge_bias_tensor(
+        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+    ):
+        # bias tensor can be None
+        if lhs is not None or rhs is not None:
+            shape, dtype = None, None
+            if lhs is not None:
+                shape, dtype = lhs.shape[1:], lhs.dtype
+            else:
+                shape, dtype = rhs.shape[1:], rhs.dtype
+            with torch.dtype(dtype):
+                if lhs is None:
+                    lhs = torch.empty((bs1, *shape), device="cuda").fill_(default)
+                if rhs is None:
+                    rhs = torch.empty((bs2, *shape), device="cuda").fill_(default)
+            return torch.cat([lhs, rhs])
+
+        return None
+
     def merge(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
@@ -187,19 +166,6 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))
 
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
+        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other)
+        )
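
Note on the from_schedule_batch hunk above: the refactor moves the tensor constructors under with torch.device("cuda"), which sets the default device for factory calls inside the block (a PyTorch 2.x feature), instead of passing device= to every call. A standalone sketch of that behavior, using the CPU so it runs without a GPU; the sample values are arbitrary.

    # Sketch of the default-device context manager used by from_schedule_batch.
    # "cpu" is used here so the snippet runs without a GPU; the diff uses "cuda".
    import torch

    temperatures = [0.7, 1.0, 0.0]
    top_ps = [0.9, 1.0, 1.0]

    with torch.device("cpu"):
        # Factory calls inside the block inherit the context device, so no
        # per-call device= argument is needed.
        t = torch.tensor(temperatures, dtype=torch.float).view(-1, 1)
        p = torch.tensor(top_ps, dtype=torch.float)

    assert t.device.type == "cpu" and t.shape == (3, 1)
    assert p.device.type == "cpu" and p.shape == (3,)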
sglang/srt/server.py CHANGED
@@ -37,6 +37,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
@@ -77,6 +78,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     enable_show_time_cost,
+    is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model,
@@ -93,6 +95,14 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None
 
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 
 @app.get("/health")
 async def health() -> Response:
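
Note on the CORS hunk above: registering CORSMiddleware with wildcard origins lets browser-based clients call the HTTP API directly. One way to observe the effect is to issue a preflight request and inspect the Access-Control-* response headers; a sketch assuming a server already running on localhost:30000.

    # Sketch: probe the preflight handling enabled by the CORS middleware above.
    # Host and port are assumptions; adjust to the running server.
    import requests

    resp = requests.options(
        "http://localhost:30000/health",
        headers={
            "Origin": "http://example.com",
            "Access-Control-Request-Method": "GET",
        },
    )
    # Expect Access-Control-Allow-* headers reflecting the permissive policy.
    print(resp.status_code)
    print(resp.headers.get("access-control-allow-origin"))
    print(resp.headers.get("access-control-allow-methods"))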
@@ -143,7 +153,7 @@ async def flush_cache():
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
 
     success, message = await tokenizer_manager.update_weights(obj, request)
-    content = {"message": message, "success": str(success)}
+    content = {"success": success, "message": message}
     if success:
         return JSONResponse(
             content,
@@ -272,7 +282,6 @@ async def retrieve_file_content(file_id: str):
 
 def launch_server(
     server_args: ServerArgs,
-    model_override_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
@@ -317,7 +326,6 @@ def launch_server(
            tp_rank_range,
            server_args,
            ports[3],
-           model_override_args,
        )
 
     try:
@@ -328,23 +336,19 @@ def launch_server(
         return
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
-    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     if server_args.dp_size == 1:
         start_controller_process = start_controller_process_single
     else:
         start_controller_process = start_controller_process_multi
-
     proc_controller = mp.Process(
         target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer, model_override_args),
+        args=(server_args, port_args, pipe_controller_writer),
     )
     proc_controller.start()
 
+    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -355,6 +359,10 @@ def launch_server(
    )
    proc_detoken.start()
 
+    tokenizer_manager = TokenizerManager(server_args, port_args)
+    if server_args.chat_template:
+        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
+
     # Wait for the model to finish loading
     controller_init_state = pipe_controller_reader.recv()
     detoken_init_state = pipe_detoken_reader.recv()
@@ -418,7 +426,7 @@ def _set_envs_and_config(server_args: ServerArgs):
         maybe_set_triton_cache_manager()
 
     # Check flashinfer version
-    if not server_args.disable_flashinfer:
+    if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer",
             "0.1.6",
@@ -427,6 +435,10 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
+    if is_hip():
+        # to figure out a better method of not using fork later
+        mp.set_start_method("spawn", force=True)
+
 
 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
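
Note on the flashinfer hunk above: the version check is now keyed to server_args.attention_backend rather than the removed disable_flashinfer flag (see the server_args.py changes in the file list). A hedged sketch of selecting the backend through ServerArgs; the model path is a placeholder, and constructing ServerArgs this way assumes its other fields keep their defaults.

    # Sketch: choose the attention backend through ServerArgs instead of the
    # old disable_flashinfer flag. model_path is a placeholder value.
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        attention_backend="triton",  # "flashinfer" would trigger the version check above
    )
    print(server_args.attention_backend)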
@@ -440,13 +452,12 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         time.sleep(1)
         try:
             res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
-            assert res.status_code == 200, f"{res}"
+            assert res.status_code == 200, f"{res=}, {res.text=}"
             success = True
             break
-        except (AssertionError, requests.exceptions.RequestException) as e:
+        except (AssertionError, requests.exceptions.RequestException):
             last_traceback = get_exception_traceback()
             pass
-    model_info = res.json()
 
     if not success:
         if pipe_finish_writer is not None:
@@ -455,6 +466,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
+    model_info = res.json()
+
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -501,7 +514,6 @@ class Runtime:
     def __init__(
         self,
         log_level: str = "error",
-        model_override_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
@@ -525,7 +537,7 @@ class Runtime:
 
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, model_override_args, pipe_writer),
+            args=(self.server_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
@@ -604,6 +616,7 @@ class Runtime:
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
     ):
         json_data = {
             "text": prompt,
@@ -611,7 +624,9 @@ class Runtime:
             "return_logprob": return_logprob,
             "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
+            "lora_path": lora_path,
         }
+        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
         response = requests.post(
             self.url + "/generate",
             json=json_data,
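
Note on the Runtime hunks above: generate() now forwards an optional per-prompt lora_path list and asserts that its length matches the prompt batch. A hedged usage sketch follows; the model path, adapter name, and the lora_paths launch-option format are assumptions, not values taken from this diff, and the feature depends on the new sglang/srt/lora modules listed above.

    # Sketch: per-prompt LoRA selection via the new lora_path argument.
    # Model path, adapter name, and the lora_paths launch option are assumptions.
    from sglang import Runtime

    runtime = Runtime(
        model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        lora_paths=["my-adapter=/path/to/adapter"],
    )

    prompts = ["Hello from the base weights.", "Hello through the adapter."]
    # One lora_path entry per prompt; None falls back to the base weights.
    out = runtime.generate(prompt=prompts, lora_path=[None, "my-adapter"])
    print(out)
    runtime.shutdown()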