sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
 
 import torch
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -52,11 +54,14 @@ from sglang.srt.managers.data_parallel_controller import (
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -127,14 +132,12 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
 
+    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
     else:
-        gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
 
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
@@ -159,12 +162,68 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **dataclasses.asdict(tokenizer_manager.server_args),
         **scheduler_info,
         "version": __version__,
     }
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+@time_func_latency
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+@time_func_latency
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
@@ -176,8 +235,7 @@ async def flush_cache():
     )
 
 
-@app.get("/start_profile")
-@app.post("/start_profile")
+@app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
@@ -187,8 +245,7 @@ async def start_profile_async():
     )
 
 
-@app.get("/stop_profile")
-@app.post("/stop_profile")
+@app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
@@ -257,6 +314,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+    obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+    """Release GPU occupation temporarily"""
+    try:
+        await tokenizer_manager.release_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+    obj: ResumeMemoryOccupationReqInput, request: Request
+):
+    """Resume GPU occupation"""
+    try:
+        await tokenizer_manager.resume_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -281,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         return _create_error_response(e)
 
 
-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
-    """Handle a generate request."""
-    if obj.stream:
-
-        async def stream_results() -> AsyncIterator[bytes]:
-            try:
-                async for out in tokenizer_manager.generate_request(obj, request):
-                    yield b"data: " + orjson.dumps(
-                        out, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-            except ValueError as e:
-                out = {"error": {"message": str(e)}}
-                yield b"data: " + orjson.dumps(
-                    out, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            yield b"data: [DONE]\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(obj),
-        )
-    else:
-        try:
-            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-            return ret
-        except ValueError as e:
-            logger.error(f"Error: {e}")
-            return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
-    """Handle an embedding request."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Close the session"""
+    tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)
 
 
 ##### OpenAI-compatible API endpoints #####
@@ -440,6 +470,10 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
 
+    memory_saver_adapter = TorchMemorySaverAdapter.create(
+        enable=server_args.enable_memory_saver
+    )
+
     if server_args.dp_size == 1:
         # Launch tensor parallel scheduler processes
         scheduler_procs = []
@@ -456,7 +490,8 @@ def launch_engine(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, None, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
@@ -473,7 +508,8 @@ def launch_engine(
             target=run_data_parallel_controller_process,
             args=(server_args, port_args, writer),
         )
-        proc.start()
+        with memory_saver_adapter.configure_subprocess():
+            proc.start()
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -546,7 +582,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            tokenizer_manager.image_token_id,
+        ),
     )
     t.start()
 
@@ -608,6 +649,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     # The child processes will send SIGQUIT to this process when any error happens
    # This process then clean up the whole process tree
     def sigquit_handler(signum, frame):
+        logger.error(
+            "Received sigquit from a child proces. It usually means the child failed."
+        )
        kill_process_tree(os.getpid())
 
     signal.signal(signal.SIGQUIT, sigquit_handler)
@@ -616,7 +660,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -891,6 +935,18 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
 
+    def release_memory_occupation(self):
+        """Release GPU occupation temporarily"""
+        obj = ReleaseMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
+
+    def resume_memory_occupation(self):
+        """Resume GPU occupation"""
+        obj = ResumeMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
+
 
 class Runtime:
     """
sglang/srt/server_args.py CHANGED
@@ -23,7 +23,6 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.hf_transformers_utils import check_gguf_file
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -32,6 +31,7 @@ from sglang.srt.utils import (
     is_hip,
     is_ipv6,
     is_port_available,
+    nullable_str,
 )
 
 logger = logging.getLogger(__name__)
@@ -47,6 +47,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: nullable_str = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -55,7 +56,6 @@ class ServerArgs:
     is_embedding: bool = False
     revision: Optional[str] = None
     skip_tokenizer_init: bool = False
-    return_token_ids: bool = False
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -91,7 +91,7 @@ class ServerArgs:
 
     # API related
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGLang_storage"
+    file_storage_pth: str = "sglang_storage"
     enable_cache_report: bool = False
 
     # Data parallelism
@@ -148,6 +148,7 @@ class ServerArgs:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
+    cuda_graph_bs: Optional[List[int]] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -155,6 +156,7 @@ class ServerArgs:
     triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
+    enable_memory_saver: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -295,6 +297,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -345,8 +352,17 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
+        )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied, when "
+            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+            "default to 1.0, which may cause accuracy issues. ",
         )
         parser.add_argument(
             "--quantization",
@@ -361,6 +377,8 @@ class ServerArgs:
                 "awq_marlin",
                 "bitsandbytes",
                 "gguf",
+                "modelopt",
+                "w8a8_int8",
             ],
             help="The quantization method.",
         )
@@ -402,18 +420,6 @@ class ServerArgs:
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
-        parser.add_argument(
-            "--return-token-ids",
-            action="store_true",
-            default=ServerArgs.return_token_ids,
-            help="Whether to return token IDs in the output, this may introduce additional overhead.",
-        )
-
         # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
@@ -549,7 +555,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
+            help="The log interval of decode batch.",
        )
 
         # API related
@@ -802,6 +808,12 @@ class ServerArgs:
             default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
         )
+        parser.add_argument(
+            "--cuda-graph-bs",
+            type=int,
+            nargs="+",
+            help="Set the list of batch sizes for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
@@ -843,6 +855,11 @@ class ServerArgs:
             action="store_true",
             help="Delete the model checkpoint after loading the model.",
         )
+        parser.add_argument(
+            "--enable-memory-saver",
+            action="store_true",
+            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -920,7 +937,10 @@ class PortArgs:
         while True:
             if is_port_available(port):
                 break
-            port += 42
+            if port < 60000:
+                port += 42
+            else:
+                port -= 43
 
         return PortArgs(
             tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
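
Usage note (not part of the diff): the server_args.py changes add --quantization-param-path, --cuda-graph-bs, --enable-memory-saver, the fp8_e4m3 kv-cache dtype choice, and the modelopt / w8a8_int8 quantization choices. Below is a minimal sketch of how the new flags parse through the usual ServerArgs.add_cli_args / ServerArgs.from_cli_args path referenced in the hunks above; the model path and kv_scales.json are placeholders.

# Illustrative sketch only (not part of the diff). Placeholders: the model path
# and "kv_scales.json". Requires a working sglang 0.4.1.post6 installation.
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # registers all server flags, including the new ones
args = parser.parse_args(
    [
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--kv-cache-dtype", "fp8_e4m3",                      # newly accepted choice
        "--quantization-param-path", "kv_scales.json",       # new flag (placeholder file)
        "--cuda-graph-bs", "1", "2", "4", "8",               # new flag, nargs="+"
        "--enable-memory-saver",                             # new flag
    ]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.kv_cache_dtype, server_args.cuda_graph_bs, server_args.enable_memory_saver)
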