sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (66)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +234 -74
  4. sglang/check_env.py +25 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -40
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/hf_transformers_utils.py +2 -2
  13. sglang/srt/layers/extend_attention.py +59 -7
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/radix_attention.py +24 -14
  16. sglang/srt/layers/token_attention.py +28 -2
  17. sglang/srt/managers/io_struct.py +9 -4
  18. sglang/srt/managers/schedule_batch.py +98 -323
  19. sglang/srt/managers/tokenizer_manager.py +34 -16
  20. sglang/srt/managers/tp_worker.py +20 -22
  21. sglang/srt/mem_cache/memory_pool.py +74 -38
  22. sglang/srt/model_config.py +11 -0
  23. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  24. sglang/srt/model_executor/forward_batch_info.py +256 -0
  25. sglang/srt/model_executor/model_runner.py +51 -26
  26. sglang/srt/models/chatglm.py +1 -1
  27. sglang/srt/models/commandr.py +1 -1
  28. sglang/srt/models/dbrx.py +1 -1
  29. sglang/srt/models/deepseek.py +1 -1
  30. sglang/srt/models/deepseek_v2.py +199 -17
  31. sglang/srt/models/gemma.py +1 -1
  32. sglang/srt/models/gemma2.py +1 -1
  33. sglang/srt/models/gpt_bigcode.py +1 -1
  34. sglang/srt/models/grok.py +1 -1
  35. sglang/srt/models/internlm2.py +1 -1
  36. sglang/srt/models/llama2.py +1 -1
  37. sglang/srt/models/llama_classification.py +1 -1
  38. sglang/srt/models/llava.py +1 -2
  39. sglang/srt/models/llavavid.py +1 -2
  40. sglang/srt/models/minicpm.py +1 -1
  41. sglang/srt/models/mixtral.py +1 -1
  42. sglang/srt/models/mixtral_quant.py +1 -1
  43. sglang/srt/models/qwen.py +1 -1
  44. sglang/srt/models/qwen2.py +1 -1
  45. sglang/srt/models/qwen2_moe.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/openai_api/adapter.py +151 -29
  48. sglang/srt/openai_api/protocol.py +7 -1
  49. sglang/srt/server.py +111 -84
  50. sglang/srt/server_args.py +12 -2
  51. sglang/srt/utils.py +25 -20
  52. sglang/test/run_eval.py +21 -10
  53. sglang/test/runners.py +237 -0
  54. sglang/test/simple_eval_common.py +12 -12
  55. sglang/test/simple_eval_gpqa.py +92 -0
  56. sglang/test/simple_eval_humaneval.py +5 -5
  57. sglang/test/simple_eval_math.py +72 -0
  58. sglang/test/test_utils.py +95 -14
  59. sglang/utils.py +15 -37
  60. sglang/version.py +1 -1
  61. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
  62. sglang-0.2.11.dist-info/RECORD +102 -0
  63. sglang-0.2.9.post1.dist-info/RECORD +0 -97
  64. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  65. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  66. {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -28,7 +28,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -59,6 +59,7 @@ from sglang.srt.openai_api.adapter import (
     v1_batches,
     v1_chat_completions,
     v1_completions,
+    v1_delete_file,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -67,13 +68,13 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware,
+    add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
+    set_torch_compile_config,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -158,6 +159,16 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    served_model_names = [tokenizer_manager.served_model_name]
+    model_cards = []
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+    return ModelList(data=model_cards)
+
+
 @app.post("/v1/files")
 async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
     return await v1_files_create(
@@ -165,6 +176,12 @@ async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("bat
     )
 
 
+@app.delete("/v1/files/{file_id}")
+async def delete_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/delete
+    return await v1_delete_file(file_id)
+
+
 @app.post("/v1/batches")
 async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)
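Taken together, the hunks above move the GET /v1/models route ahead of the file and batch routes and add a new DELETE /v1/files/{file_id} route backed by v1_delete_file. A minimal client sketch against these routes; the base URL, port, file id, and key are illustrative, and the Bearer header is only needed when the server was started with --api-key:

    import requests

    base_url = "http://127.0.0.1:30000"                  # assumed local server address
    headers = {"Authorization": "Bearer my-secret-key"}  # only if --api-key was set

    # List the served model via GET /v1/models.
    print(requests.get(f"{base_url}/v1/models", headers=headers).json())

    # Remove a previously uploaded batch file via the new DELETE route.
    resp = requests.delete(f"{base_url}/v1/files/file-abc123", headers=headers)
    print(resp.status_code, resp.text)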
@@ -187,69 +204,11 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)
 
 
-@app.get("/v1/models")
-def available_models():
-    """Show available models."""
-    served_model_names = [tokenizer_manager.served_model_name]
-    model_cards = []
-    for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
-    return ModelList(data=model_cards)
-
-
-def _set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
-def set_envs_and_config(server_args: ServerArgs):
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-    # Set ulimit
-    set_ulimit()
-
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    server_args.check_server_args()
-
     """Launch an HTTP server."""
     global tokenizer_manager
 
@@ -258,16 +217,8 @@ def launch_server(
         format="%(message)s",
     )
 
-    if not server_args.disable_flashinfer:
-        assert_pkg_version(
-            "flashinfer",
-            "0.1.3",
-            "Please uninstall the old version and "
-            "reinstall the latest version by following the instructions "
-            "at https://docs.flashinfer.ai/installation.html.",
-        )
-
-    set_envs_and_config(server_args)
+    server_args.check_server_args()
+    _set_envs_and_config(server_args)
 
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -284,7 +235,7 @@ def launch_server(
     )
     logger.info(f"{server_args=}")
 
-    # Handle multi-node tensor parallelism
+    # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
@@ -349,8 +300,9 @@ def launch_server(
             sys.exit(1)
         assert proc_controller.is_alive() and proc_detoken.is_alive()
 
-    if server_args.api_key and server_args.api_key != "":
-        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)
 
     # Send a warmup request
     t = threading.Thread(
@@ -372,21 +324,74 @@ def launch_server(
     t.join()
 
 
+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Check flashinfer version
+    if not server_args.disable_flashinfer:
+        assert_pkg_version(
+            "flashinfer",
+            "0.1.3",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+
+
 def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
-        headers[API_KEY_HEADER_NAME] = server_args.api_key
+        headers["Authorization"] = f"Bearer {server_args.api_key}"
 
     # Wait until the server is launched
+    success = False
    for _ in range(120):
-        time.sleep(0.5)
+        time.sleep(1)
         try:
-            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            assert res.status_code == 200, f"{res}"
+            success = True
             break
-        except requests.exceptions.RequestException:
+        except (AssertionError, requests.exceptions.RequestException) as e:
+            last_traceback = get_exception_traceback()
             pass
 
+    if not success:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(last_traceback)
+        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
+        sys.exit(1)
+
     # Send a warmup request
     try:
         for _ in range(server_args.dp_size):
@@ -402,12 +407,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
                 headers=headers,
                 timeout=600,
             )
-            assert res.status_code == 200
+            assert res.status_code == 200, f"{res}"
     except Exception as e:
+        last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
-            pipe_finish_writer.send(get_exception_traceback())
-        print(f"Initialization failed. warmup error: {e}", flush=True)
-        raise e
+            pipe_finish_writer.send(last_traceback)
+        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
+        sys.exit(1)
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
@@ -481,10 +487,10 @@ class Runtime:
             trust_remote_code=self.server_args.trust_remote_code,
         )
 
-    async def add_request(
+    async def async_generate(
         self,
         prompt: str,
-        sampling_params: Dict,
+        sampling_params: Optional[Dict] = None,
     ):
         json_data = {
             "text": prompt,
@@ -507,5 +513,26 @@ class Runtime:
                 yield cur
                 pos += len(cur)
 
+    add_request = async_generate
+
+    def generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+        }
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
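Runtime.add_request is renamed to async_generate (add_request is kept as an alias) and a blocking generate() method is added that posts to /generate and returns the JSON response serialized as a string. A minimal usage sketch, assuming the Runtime constructor still accepts model_path as in earlier releases; the model name is illustrative:

    import asyncio
    from sglang.srt.server import Runtime

    runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")  # hypothetical model

    # Blocking call: returns the /generate JSON payload as a string.
    print(runtime.generate("The capital of France is", {"max_new_tokens": 8}))

    # Streaming call: async_generate yields incremental text chunks.
    async def stream():
        async for chunk in runtime.async_generate("Hello", {"max_new_tokens": 8}):
            print(chunk, end="", flush=True)

    asyncio.run(stream())
    runtime.shutdown()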
sglang/srt/server_args.py CHANGED
@@ -61,7 +61,7 @@ class ServerArgs:
     show_time_cost: bool = False
 
     # Other
-    api_key: str = ""
+    api_key: Optional[str] = None
     file_storage_pth: str = "SGlang_storage"
 
     # Data parallelism
@@ -80,6 +80,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
+    enable_mla: bool = False
     attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
 
@@ -263,6 +264,7 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
+            "--tensor-parallel-size",
             "--tp-size",
             type=int,
             default=ServerArgs.tp_size,
@@ -306,7 +308,7 @@ class ServerArgs:
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API key of the server.",
+            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
         )
         parser.add_argument(
             "--file-storage-pth",
@@ -317,6 +319,7 @@ class ServerArgs:
 
         # Data parallelism
         parser.add_argument(
+            "--data-parallel-size",
             "--dp-size",
             type=int,
             default=ServerArgs.dp_size,
@@ -393,6 +396,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--enable-mla",
+            action="store_true",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",
@@ -407,6 +415,8 @@ class ServerArgs:
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
+        args.tp_size = args.tensor_parallel_size
+        args.dp_size = args.data_parallel_size
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
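In summary: --tp-size and --dp-size gain the long-form aliases --tensor-parallel-size and --data-parallel-size (from_cli_args copies them back onto tp_size and dp_size), api_key now defaults to None instead of an empty string, and the new --enable-mla flag turns on Multi-head Latent Attention for DeepSeek-V2. A minimal sketch at the dataclass level; the model path and values are illustrative:

    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="deepseek-ai/DeepSeek-V2-Lite",  # hypothetical model path
        tp_size=2,                # CLI: --tensor-parallel-size / --tp-size
        enable_mla=True,          # new in 0.2.11, meant for DeepSeek-V2 models
        api_key="my-secret-key",  # now Optional[str] = None by default
    )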
 
sglang/srt/utils.py CHANGED
@@ -539,26 +539,6 @@ class CustomCacheManager(FileCacheManager):
                 raise RuntimeError("Could not create or locate cache dir")
 
 
-API_KEY_HEADER_NAME = "X-API-Key"
-
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
-
 def get_ip_address(ifname):
     """
     Get the IP address of a network interface.
@@ -642,6 +622,19 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()
 
 
+def set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+
+
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -700,3 +693,15 @@ def monkey_patch_vllm_qvk_linear_loader():
         origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
 
     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+
+
+def add_api_key_middleware(app, api_key):
+    @app.middleware("http")
+    async def authentication(request, call_next):
+        if request.method == "OPTIONS":
+            return await call_next(request)
+        if request.url.path.startswith("/health"):
+            return await call_next(request)
+        if request.headers.get("Authorization") != "Bearer " + api_key:
+            return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+        return await call_next(request)
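The header-based APIKeyValidatorMiddleware is replaced by add_api_key_middleware, which checks a standard Authorization: Bearer header and lets OPTIONS requests and /health paths through unauthenticated. A minimal sketch of wiring it into a FastAPI app; the app and key below are illustrative, not part of the diff:

    from fastapi import FastAPI
    from sglang.srt.utils import add_api_key_middleware

    app = FastAPI()
    add_api_key_middleware(app, "my-secret-key")

    @app.get("/health")
    async def health():
        # Reachable without a key: the middleware skips paths under /health.
        return {"ok": True}

    @app.get("/protected")
    async def protected():
        # Requires the header "Authorization: Bearer my-secret-key"; otherwise the
        # middleware answers {"error": "Unauthorized"} with status 401.
        return {"ok": True}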
sglang/test/run_eval.py CHANGED
@@ -10,7 +10,6 @@ import time
 
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
-    download_dataset,
     make_report,
     set_ulimit,
 )
@@ -27,14 +26,26 @@ def run_eval(args):
     if args.eval_name == "mmlu":
         from sglang.test.simple_eval_mmlu import MMLUEval
 
-        dataset_path = "mmlu.csv"
-
-        if not os.path.exists(dataset_path):
-            download_dataset(
-                dataset_path,
-                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
-            )
-        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+        filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
+        eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
+    elif args.eval_name == "math":
+        from sglang.test.simple_eval_math import MathEval
+
+        equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
+        )
+        eval_obj = MathEval(
+            filename, equality_checker, args.num_examples, args.num_threads
+        )
+    elif args.eval_name == "gpqa":
+        from sglang.test.simple_eval_gpqa import GPQAEval
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
+        )
+        eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
     elif args.eval_name == "humaneval":
         from sglang.test.simple_eval_humaneval import HumanEval
 
@@ -97,7 +108,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("--eval-name", type=str, default="mmlu")
    parser.add_argument("--num-examples", type=int)
-    parser.add_argument("--num-threads", type=int, default=64)
+    parser.add_argument("--num-threads", type=int, default=512)
     set_ulimit()
     args = parser.parse_args()
 
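run_eval now streams the simple-evals datasets directly from their public URLs (the local download_dataset step is gone) and adds math and gpqa suites alongside mmlu and humaneval. A minimal sketch mirroring the constructor calls above; the driver in run_eval() then applies the eval object to a ChatCompletionSampler pointed at the server, which is not shown in this hunk:

    from sglang.test.simple_eval_common import ChatCompletionSampler
    from sglang.test.simple_eval_gpqa import GPQAEval
    from sglang.test.simple_eval_math import MathEval

    gpqa = GPQAEval(
        "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv",
        20,  # num_examples
        64,  # num_threads
    )

    # The MATH suite additionally takes an equality-checker model, as above.
    math_eval = MathEval(
        "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv",
        ChatCompletionSampler(model="gpt-4-turbo"),
        20,
        64,
    )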