sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_latency.py +28 -10
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/layers/attention/__init__.py +27 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  7. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  8. sglang/srt/layers/attention/triton_backend.py +6 -4
  9. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  12. sglang/srt/layers/sampler.py +6 -2
  13. sglang/srt/managers/detokenizer_manager.py +31 -10
  14. sglang/srt/managers/io_struct.py +4 -0
  15. sglang/srt/managers/schedule_batch.py +120 -43
  16. sglang/srt/managers/schedule_policy.py +2 -1
  17. sglang/srt/managers/scheduler.py +202 -140
  18. sglang/srt/managers/tokenizer_manager.py +5 -1
  19. sglang/srt/managers/tp_worker.py +111 -1
  20. sglang/srt/mem_cache/chunk_cache.py +8 -4
  21. sglang/srt/mem_cache/memory_pool.py +77 -4
  22. sglang/srt/mem_cache/radix_cache.py +15 -7
  23. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  24. sglang/srt/model_executor/forward_batch_info.py +16 -21
  25. sglang/srt/model_executor/model_runner.py +60 -1
  26. sglang/srt/models/baichuan.py +2 -3
  27. sglang/srt/models/chatglm.py +5 -6
  28. sglang/srt/models/commandr.py +1 -2
  29. sglang/srt/models/dbrx.py +1 -2
  30. sglang/srt/models/deepseek.py +4 -5
  31. sglang/srt/models/deepseek_v2.py +5 -6
  32. sglang/srt/models/exaone.py +1 -2
  33. sglang/srt/models/gemma.py +2 -2
  34. sglang/srt/models/gemma2.py +5 -5
  35. sglang/srt/models/gpt_bigcode.py +5 -5
  36. sglang/srt/models/grok.py +1 -2
  37. sglang/srt/models/internlm2.py +1 -2
  38. sglang/srt/models/llama.py +1 -2
  39. sglang/srt/models/llama_classification.py +1 -2
  40. sglang/srt/models/llama_reward.py +2 -3
  41. sglang/srt/models/llava.py +4 -8
  42. sglang/srt/models/llavavid.py +1 -2
  43. sglang/srt/models/minicpm.py +1 -2
  44. sglang/srt/models/minicpm3.py +5 -6
  45. sglang/srt/models/mixtral.py +1 -2
  46. sglang/srt/models/mixtral_quant.py +1 -2
  47. sglang/srt/models/olmo.py +352 -0
  48. sglang/srt/models/olmoe.py +1 -2
  49. sglang/srt/models/qwen.py +1 -2
  50. sglang/srt/models/qwen2.py +1 -2
  51. sglang/srt/models/qwen2_moe.py +4 -5
  52. sglang/srt/models/stablelm.py +1 -2
  53. sglang/srt/models/torch_native_llama.py +1 -2
  54. sglang/srt/models/xverse.py +1 -2
  55. sglang/srt/models/xverse_moe.py +4 -5
  56. sglang/srt/models/yivl.py +1 -2
  57. sglang/srt/openai_api/adapter.py +92 -49
  58. sglang/srt/openai_api/protocol.py +10 -2
  59. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  60. sglang/srt/sampling/sampling_batch_info.py +92 -58
  61. sglang/srt/sampling/sampling_params.py +2 -0
  62. sglang/srt/server.py +116 -17
  63. sglang/srt/server_args.py +121 -45
  64. sglang/srt/utils.py +11 -3
  65. sglang/test/few_shot_gsm8k.py +4 -1
  66. sglang/test/few_shot_gsm8k_engine.py +144 -0
  67. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  68. sglang/version.py +1 -1
  69. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
  70. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
  71. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  72. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  73. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  74. {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -25,11 +25,12 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import random
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Union
+
+import orjson
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -40,7 +41,8 @@ import uvicorn
 import uvloop
 from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse, Response, StreamingResponse
+from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+from uvicorn.config import LOGGING_CONFIG
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -176,12 +178,12 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
-        return JSONResponse(
+        return ORJSONResponse(
             content,
             status_code=HTTPStatus.OK,
         )
     else:
-        return JSONResponse(
+        return ORJSONResponse(
             content,
             status_code=HTTPStatus.BAD_REQUEST,
         )
@@ -192,14 +194,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
 
-        async def stream_results():
+        async def stream_results() -> AsyncIterator[bytes]:
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
-                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
             except ValueError as e:
                 out = {"error": {"message": str(e)}}
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
-            yield "data: [DONE]\n\n"
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
 
         return StreamingResponse(
             stream_results(),
@@ -211,7 +217,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
             ret = await tokenizer_manager.generate_request(obj, request).__anext__()
             return ret
         except ValueError as e:
-            return JSONResponse(
+            return ORJSONResponse(
                 {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
             )
 
@@ -226,7 +232,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return JSONResponse(
+        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )
 
@@ -241,7 +247,7 @@ async def judge_request(obj: RewardReqInput, request: Request):
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
-        return JSONResponse(
+        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )
 
@@ -260,13 +266,13 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-@app.post("/v1/embeddings")
+@app.post("/v1/embeddings", response_class=ORJSONResponse)
 async def openai_v1_embeddings(raw_request: Request):
     response = await v1_embeddings(tokenizer_manager, raw_request)
     return response
 
 
-@app.get("/v1/models")
+@app.get("/v1/models", response_class=ORJSONResponse)
 def available_models():
     """Show available models."""
     served_model_names = [tokenizer_manager.served_model_name]
@@ -429,6 +435,14 @@ def launch_server(
 
     try:
         # Listen for HTTP requests
+        LOGGING_CONFIG["formatters"]["default"][
+            "fmt"
+        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+        LOGGING_CONFIG["formatters"]["access"][
+            "fmt"
+        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
         uvicorn.run(
             app,
             host=server_args.host,
@@ -447,7 +461,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
     # Set ulimit
     set_ulimit()
@@ -528,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
+    # logger.info(f"{res.json()=}")
+
    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")
@@ -692,6 +708,10 @@ class Runtime:
        self.shutdown()
 
 
+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
 class Engine:
    """
    SRT Engine without an HTTP server layer.
@@ -716,7 +736,10 @@ class Engine:
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
    ):
+        # TODO (ByronHsu): refactor to reduce the duplicated code
+
        obj = GenerateReqInput(
            text=prompt,
            sampling_params=sampling_params,
@@ -724,13 +747,89 @@ class Engine:
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
+            stream=stream,
        )
 
        # get the current event loop
        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(generate_request(obj, None))
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
 
    def shutdown(self):
        kill_child_process(os.getpid(), including_parent=False)
 
-    # TODO (ByronHsu): encode and async generate
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    # TODO (ByronHsu): encode
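
Note on usage: the diff above adds a `stream` flag to `Engine.generate`, an `async_generate` method, and a `get_tokenizer` helper, giving the HTTP-free engine a streaming interface. A minimal sketch of how this might be driven, under stated assumptions: the model path is a placeholder, and it is assumed that the `Engine` constructor forwards keyword arguments to `ServerArgs` (the async variant is omitted here, but it mirrors the synchronous API).

    # Sketch only: assumes sglang 0.3.4 is installed and a GPU is available.
    # The model path below is a placeholder, not part of this diff.
    from sglang.srt.server import Engine

    engine = Engine(model_path="path/to/your/model")

    # Non-streaming call: returns the full result dict.
    result = engine.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 16})
    print(result["text"])

    # Streaming call: generate(..., stream=True) returns a generator that yields
    # incremental {"text": ...} chunks until the b"data: [DONE]" sentinel is seen.
    for chunk in engine.generate("Count to five:", {"max_new_tokens": 32}, stream=True):
        print(chunk["text"], end="", flush=True)
    print()

    engine.shutdown()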
sglang/srt/server_args.py CHANGED
@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    trust_remote_code: bool = True
     dtype: str = "auto"
-    device: str = "cuda"
     kv_cache_dtype: str = "auto"
-    trust_remote_code: bool = True
-    context_length: Optional[int] = None
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
@@ -73,6 +73,7 @@ class ServerArgs:
     # Other
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
+    enable_cache_report: bool = False
 
     # Data parallelism
     dp_size: int = 1
@@ -86,10 +87,23 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Optimization/debug options
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
 
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -99,16 +113,16 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
+    disable_nan_detection: bool = False
+    enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
-
-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
+    num_continuous_decode_steps: int = 1
 
     def __post_init__(self):
         # Set missing default values
@@ -224,6 +238,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,
@@ -238,13 +257,6 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
@@ -252,17 +264,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,
@@ -278,6 +279,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,
@@ -398,6 +412,11 @@ class ServerArgs:
             default=ServerArgs.file_storage_pth,
             help="The path of the file storage in backend.",
        )
+        parser.add_argument(
+            "--enable-cache-report",
+            action="store_true",
+            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
+        )
 
        # Data parallelism
        parser.add_argument(
@@ -440,7 +459,60 @@ class ServerArgs:
            default=ServerArgs.json_model_override_args,
        )
 
-        # Optimization/debug options
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
        parser.add_argument(
            "--attention-backend",
            type=str,
@@ -455,6 +527,8 @@ class ServerArgs:
            default=ServerArgs.sampling_backend,
            help="Choose the kernels for sampling layers.",
        )
+
+        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
@@ -501,6 +575,21 @@ class ServerArgs:
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
+        )
+        parser.add_argument(
+            "--disable-nan-detection",
+            action="store_true",
+            help="Disable the NaN detection for better performance.",
+        )
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action="store_true",
+            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+        )
        parser.add_argument(
            "--enable-mixed-chunk",
            action="store_true",
@@ -535,25 +624,12 @@ class ServerArgs:
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
-            "--efficient-weight-load",
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
+            "--num-continuous-decode-steps",
            type=int,
-            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
        )
 
    @classmethod
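
Note on usage: the regrouped flags above map one-to-one onto `ServerArgs` dataclass fields, so the new options can also be set programmatically. A small sketch under stated assumptions: only fields visible in this diff plus the pre-existing `model_path` are used, and the model path itself is a placeholder.

    # Sketch only: the model path is a placeholder; the other fields come from
    # the ServerArgs fields added or regrouped in this diff.
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="path/to/your/model",
        device="cuda",                    # the CLI now also accepts "xpu"
        enable_cache_report=True,         # report cached tokens in OpenAI usage
        enable_overlap_schedule=True,     # experimental CPU/GPU overlap
        num_continuous_decode_steps=2,    # >1 trades time-to-first-token for throughput
    )
    print(server_args.num_continuous_decode_steps)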
sglang/srt/utils.py CHANGED
@@ -35,7 +35,7 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
-from fastapi.responses import JSONResponse
+from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
@@ -566,7 +566,7 @@ def add_api_key_middleware(app, api_key: str):
        if request.url.path.startswith("/health"):
            return await call_next(request)
        if request.headers.get("Authorization") != "Bearer " + api_key:
-            return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
        return await call_next(request)
 
 
@@ -584,10 +584,11 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
 
 def configure_logger(server_args, prefix: str = ""):
    format = f"[%(asctime)s{prefix}] %(message)s"
+    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format=format,
-        datefmt="%H:%M:%S",
+        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
 
@@ -690,3 +691,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
        prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
    step_counter += 1
    return result
+
+
+def first_rank_print(*args, **kwargs):
+    if torch.cuda.current_device() == 0:
+        print(*args, **kwargs)
+    else:
+        pass
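
Note on usage: the new `first_rank_print` helper simply swallows output on non-zero CUDA devices so multi-GPU logs stay readable. A one-line usage sketch, assuming a CUDA-capable environment (the function queries `torch.cuda.current_device()`):

    # Sketch only: requires CUDA, since first_rank_print checks the current device.
    from sglang.srt.utils import first_rank_print

    first_rank_print("finished loading weights")  # printed only on device 0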
sglang/test/few_shot_gsm8k.py CHANGED
@@ -76,7 +76,9 @@ def run_eval(args):
    def few_shot_gsm8k(s, question):
        s += few_shot_examples + question
        s += sgl.gen(
-            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+            "answer",
+            max_tokens=args.max_new_tokens,
+            stop=["Question", "Assistant:", "<|separator|>"],
        )
 
    #####################################
@@ -131,6 +133,7 @@ if __name__ == "__main__":
    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--parallel", type=int, default=128)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)