sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
@@ -18,10 +18,12 @@ import copy
18
18
  import dataclasses
19
19
  import logging
20
20
  import os
21
+ import pickle
21
22
  import signal
22
23
  import sys
23
24
  import time
24
25
  import uuid
26
+ from datetime import datetime
25
27
  from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
26
28
 
27
29
  import fastapi
@@ -43,6 +45,7 @@ from sglang.srt.managers.io_struct import (
43
45
  BatchStrOut,
44
46
  BatchTokenIDOut,
45
47
  CloseSessionReqInput,
48
+ ConfigureLoggingReq,
46
49
  EmbeddingReqInput,
47
50
  FlushCacheReq,
48
51
  GenerateReqInput,
@@ -53,6 +56,10 @@ from sglang.srt.managers.io_struct import (
53
56
  OpenSessionReqInput,
54
57
  OpenSessionReqOutput,
55
58
  ProfileReq,
59
+ ReleaseMemoryOccupationReqInput,
60
+ ReleaseMemoryOccupationReqOutput,
61
+ ResumeMemoryOccupationReqInput,
62
+ ResumeMemoryOccupationReqOutput,
56
63
  SessionParams,
57
64
  TokenizedEmbeddingReqInput,
58
65
  TokenizedGenerateReqInput,
@@ -105,6 +112,7 @@ class TokenizerManager:
105
112
  # Parse args
106
113
  self.server_args = server_args
107
114
  self.enable_metrics = server_args.enable_metrics
115
+ self.log_requests = server_args.log_requests
108
116
 
109
117
  # Init inter-process communication
110
118
  context = zmq.asyncio.Context(2)
@@ -163,6 +171,9 @@ class TokenizerManager:
163
171
  # Store states
164
172
  self.to_create_loop = True
165
173
  self.rid_to_state: Dict[str, ReqState] = {}
174
+ self.dump_requests_folder = "" # By default do not dump
175
+ self.dump_requests_threshold = 1000
176
+ self.dump_request_list: List[Tuple] = []
166
177
 
167
178
  # The event to notify the weight sync is finished.
168
179
  self.model_update_lock = RWLock()
@@ -188,6 +199,12 @@ class TokenizerManager:
188
199
  self.get_weights_by_name_communicator = _Communicator(
189
200
  self.send_to_scheduler, server_args.dp_size
190
201
  )
202
+ self.release_memory_occupation_communicator = _Communicator(
203
+ self.send_to_scheduler, server_args.dp_size
204
+ )
205
+ self.resume_memory_occupation_communicator = _Communicator(
206
+ self.send_to_scheduler, server_args.dp_size
207
+ )
191
208
 
192
209
  # Metrics
193
210
  if self.enable_metrics:
@@ -215,7 +232,7 @@ class TokenizerManager:
215
232
 
216
233
  obj.normalize_batch_and_arguments()
217
234
 
218
- if self.server_args.log_requests:
235
+ if self.log_requests:
219
236
  logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
220
237
 
221
238
  async with self.model_update_lock.reader_lock:
@@ -336,7 +353,7 @@ class TokenizerManager:
336
353
 
337
354
  state.out_list = []
338
355
  if state.finished:
339
- if self.server_args.log_requests:
356
+ if self.log_requests:
340
357
  msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
341
358
  logger.info(msg)
342
359
  del self.rid_to_state[obj.rid]
@@ -548,6 +565,22 @@ class TokenizerManager:
548
565
  else:
549
566
  return all_parameters
550
567
 
568
+ async def release_memory_occupation(
569
+ self,
570
+ obj: ReleaseMemoryOccupationReqInput,
571
+ request: Optional[fastapi.Request] = None,
572
+ ):
573
+ self.auto_create_handle_loop()
574
+ await self.release_memory_occupation_communicator(obj)
575
+
576
+ async def resume_memory_occupation(
577
+ self,
578
+ obj: ResumeMemoryOccupationReqInput,
579
+ request: Optional[fastapi.Request] = None,
580
+ ):
581
+ self.auto_create_handle_loop()
582
+ await self.resume_memory_occupation_communicator(obj)
583
+
551
584
  async def open_session(
552
585
  self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
553
586
  ):
@@ -571,6 +604,15 @@ class TokenizerManager:
571
604
  assert not self.to_create_loop, "close session should not be the first request"
572
605
  await self.send_to_scheduler.send_pyobj(obj)
573
606
 
607
+ def configure_logging(self, obj: ConfigureLoggingReq):
608
+ if obj.log_requests is not None:
609
+ self.log_requests = obj.log_requests
610
+ if obj.dump_requests_folder is not None:
611
+ self.dump_requests_folder = obj.dump_requests_folder
612
+ if obj.dump_requests_threshold is not None:
613
+ self.dump_requests_threshold = obj.dump_requests_threshold
614
+ logging.info(f"Config logging: {obj=}")
615
+
574
616
  def create_abort_task(self, obj: GenerateReqInput):
575
617
  # Abort the request if the client is disconnected.
576
618
  async def abort_request():
@@ -601,7 +643,7 @@ class TokenizerManager:
601
643
  while not self.gracefully_exit:
602
644
  await asyncio.sleep(5)
603
645
 
604
- # drain requests
646
+ # Drain requests
605
647
  while True:
606
648
  remain_num_req = len(self.rid_to_state)
607
649
  logger.info(
@@ -627,6 +669,8 @@ class TokenizerManager:
627
669
  UpdateWeightsFromDistributedReqOutput,
628
670
  GetWeightsByNameReqOutput,
629
671
  InitWeightsUpdateGroupReqOutput,
672
+ ReleaseMemoryOccupationReqOutput,
673
+ ResumeMemoryOccupationReqOutput,
630
674
  ] = await self.recv_from_detokenizer.recv_pyobj()
631
675
 
632
676
  if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
@@ -663,13 +707,6 @@ class TokenizerManager:
663
707
  "text": recv_obj.output_strs[i],
664
708
  "meta_info": meta_info,
665
709
  }
666
- if self.server_args.return_token_ids:
667
- out_dict.update(
668
- {
669
- "input_ids": recv_obj.origin_input_ids[i],
670
- "output_ids": recv_obj.output_ids[i],
671
- }
672
- )
673
710
  elif isinstance(recv_obj, BatchTokenIDOut):
674
711
  out_dict = {
675
712
  "token_ids": recv_obj.output_ids[i],
@@ -686,41 +723,9 @@ class TokenizerManager:
686
723
  state.event.set()
687
724
 
688
725
  if self.enable_metrics:
689
- completion_tokens = (
690
- recv_obj.completion_tokens[i]
691
- if recv_obj.completion_tokens
692
- else 0
693
- )
694
-
695
- if state.first_token_time is None:
696
- state.first_token_time = time.time()
697
- self.metrics_collector.observe_time_to_first_token(
698
- state.first_token_time - state.created_time
699
- )
700
- else:
701
- if completion_tokens >= 2:
702
- # Compute time_per_output_token for the streaming case
703
- self.metrics_collector.observe_time_per_output_token(
704
- (time.time() - state.first_token_time)
705
- / (completion_tokens - 1)
706
- )
707
-
708
- if state.finished:
709
- self.metrics_collector.inc_prompt_tokens(
710
- recv_obj.prompt_tokens[i]
711
- )
712
- self.metrics_collector.inc_generation_tokens(
713
- completion_tokens
714
- )
715
- self.metrics_collector.observe_e2e_request_latency(
716
- time.time() - state.created_time
717
- )
718
- # Compute time_per_output_token for the non-streaming case
719
- if not state.obj.stream and completion_tokens >= 1:
720
- self.metrics_collector.observe_time_per_output_token(
721
- (time.time() - state.created_time)
722
- / completion_tokens
723
- )
726
+ self.collect_metrics(state, recv_obj, i)
727
+ if self.dump_requests_folder and state.finished:
728
+ self.dump_requests(state, out_dict)
724
729
  elif isinstance(recv_obj, OpenSessionReqOutput):
725
730
  self.session_futures[recv_obj.session_id].set_result(
726
731
  recv_obj.session_id if recv_obj.success else None
@@ -750,6 +755,10 @@ class TokenizerManager:
750
755
  self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
751
756
  elif isinstance(recv_obj, GetWeightsByNameReqOutput):
752
757
  self.get_weights_by_name_communicator.handle_recv(recv_obj)
758
+ elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
759
+ self.release_memory_occupation_communicator.handle_recv(recv_obj)
760
+ elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
761
+ self.resume_memory_occupation_communicator.handle_recv(recv_obj)
753
762
  else:
754
763
  raise ValueError(f"Invalid object: {recv_obj=}")
755
764
 
@@ -823,6 +832,61 @@ class TokenizerManager:
823
832
  ret.append(None)
824
833
  return ret
825
834
 
835
+ def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
836
+ completion_tokens = (
837
+ recv_obj.completion_tokens[i]
838
+ if getattr(recv_obj, "completion_tokens", None)
839
+ else 0
840
+ )
841
+
842
+ if state.first_token_time is None:
843
+ state.first_token_time = time.time()
844
+ self.metrics_collector.observe_time_to_first_token(
845
+ state.first_token_time - state.created_time
846
+ )
847
+ else:
848
+ if completion_tokens >= 2:
849
+ # Compute time_per_output_token for the streaming case
850
+ self.metrics_collector.observe_time_per_output_token(
851
+ (time.time() - state.first_token_time) / (completion_tokens - 1)
852
+ )
853
+
854
+ if state.finished:
855
+ self.metrics_collector.observe_one_finished_request(
856
+ recv_obj.prompt_tokens[i], completion_tokens
857
+ )
858
+ self.metrics_collector.observe_e2e_request_latency(
859
+ time.time() - state.created_time
860
+ )
861
+ # Compute time_per_output_token for the non-streaming case
862
+ if (
863
+ hasattr(state.obj, "stream")
864
+ and not state.obj.stream
865
+ and completion_tokens >= 1
866
+ ):
867
+ self.metrics_collector.observe_time_per_output_token(
868
+ (time.time() - state.created_time) / completion_tokens
869
+ )
870
+
871
+ def dump_requests(self, state: ReqState, out_dict: dict):
872
+ self.dump_request_list.append(
873
+ (state.obj, out_dict, state.created_time, time.time())
874
+ )
875
+
876
+ if len(self.dump_request_list) >= self.dump_requests_threshold:
877
+ to_dump = self.dump_request_list
878
+ self.dump_request_list = []
879
+
880
+ def background_task():
881
+ os.makedirs(self.dump_requests_folder, exist_ok=True)
882
+ current_time = datetime.now()
883
+ filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
884
+ with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
885
+ pickle.dump(to_dump, f)
886
+
887
+ # Schedule the task to run in the background without awaiting it
888
+ asyncio.create_task(asyncio.to_thread(background_task))
889
+
826
890
 
827
891
  class SignalHandler:
828
892
  def __init__(self, tokenizer_manager):