sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -41,14 +41,17 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
 )
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
+    prepare_abort,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import (
@@ -58,7 +61,10 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -97,6 +103,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import init_embedding_cache
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     MultimodalInputs,
@@ -129,7 +136,6 @@ from sglang.srt.utils import (
     DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
-    crash_on_warnings,
     disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
@@ -142,8 +148,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
-expert_distribution_recorder = ExpertDistributionRecorder()
-
 logger = logging.getLogger(__name__)
 
 # Test retract decode for debugging purposes
@@ -198,6 +202,7 @@ class Scheduler(
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.enable_kv_cache_events = server_args.kv_events_config is not None
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
@@ -205,7 +210,6 @@ class Scheduler(
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
         self.page_size = server_args.page_size
-
         # Distributed rank info
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
@@ -349,8 +353,8 @@ class Scheduler(
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
         self.num_prefill_tokens = 0
-        self.last_decode_stats_tic = time.time()
-        self.last_prefill_stats_tic = time.time()
+        self.last_decode_stats_tic = time.perf_counter()
+        self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
@@ -423,6 +427,7 @@ class Scheduler(
 
         # Init metrics stats
         self.init_metrics()
+        self.init_kv_events(server_args.kv_events_config)
 
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -516,6 +521,7 @@ class Scheduler(
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
+                enable_kv_cache_events=self.enable_kv_cache_events,
             )
 
         self.decode_mem_cache_buf_multiplier = (
@@ -548,6 +554,10 @@ class Scheduler(
                 },
             )
 
+    def init_kv_events(self, kv_events_config: Optional[str]):
+        if self.enable_kv_cache_events:
+            self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
+
     def init_disaggregation(self):
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
@@ -560,29 +570,28 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
 
             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
+                metadata_buffers=self.disagg_metadata_buffers,
+                scheduler=self,
+                tree_cache=self.tree_cache,
             )
 
             # The decode requests pending for pre-allocation
             self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                draft_token_to_kv_pool=(
+                    None
+                    if self.draft_worker is None
+                    else self.draft_worker.model_runner.token_to_kv_pool
+                ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
@@ -602,20 +611,17 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
 
             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                draft_token_to_kv_pool=(
+                    None
+                    if self.draft_worker is None
+                    else self.draft_worker.model_runner.token_to_kv_pool
+                ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -928,6 +934,18 @@ class Scheduler(
             )
         req.tokenizer = self.tokenizer
 
+        if self.disaggregation_mode != DisaggregationMode.NULL:
+            # Invalid request for disaggregated mode
+            if recv_req.bootstrap_room is None:
+                error_message = (
+                    f"Invalid request: Disaggregated request received without "
+                    f"boostrap room id. {req.rid=}"
+                )
+                logger.error(error_message)
+                prepare_abort(req, error_message)
+                self.stream_output([req], req.return_logprob)
+                return
+
         if (
             recv_req.session_params is not None
             and recv_req.session_params.id is not None
@@ -1033,13 +1051,13 @@ class Scheduler(
                 add_to_grammar_queue = True
 
         if add_to_grammar_queue:
-            req.queue_time_start = time.time()
+            req.queue_time_start = time.perf_counter()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
 
     def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.time()
+        req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.add(req)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -1047,8 +1065,11 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
-        if self.disaggregation_mode == DisaggregationMode.DECODE:
+    def _extend_requests_to_queue(self, reqs: List[Req]):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_bootstrap_queue.extend(reqs)
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # If this is a decode server, we put the request to the decode pending prealloc queue
             self.disagg_decode_prealloc_queue.extend(reqs)
         else:
             self.waiting_queue.extend(reqs)
@@ -1086,7 +1107,7 @@ class Scheduler(
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
             )
-            req.queue_time_start = time.time()
+            req.queue_time_start = time.perf_counter()
             self.waiting_queue.append(req)
             return
 
@@ -1110,8 +1131,8 @@ class Scheduler(
         can_run_list: List[Req],
         running_bs: int,
     ):
-        gap_latency = time.time() - self.last_prefill_stats_tic
-        self.last_prefill_stats_tic = time.time()
+        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.num_prefill_tokens / gap_latency
         self.num_prefill_tokens = 0
 
@@ -1155,14 +1176,15 @@ class Scheduler(
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
 
             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()
 
     def log_decode_stats(
         self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch
 
-        gap_latency = time.time() - self.last_decode_stats_tic
-        self.last_decode_stats_tic = time.time()
+        gap_latency = time.perf_counter() - self.last_decode_stats_tic
+        self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
@@ -1214,6 +1236,7 @@ class Scheduler(
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()
 
     def check_memory(self):
         available_size = (
@@ -1246,7 +1269,7 @@ class Scheduler(
         if (
             self.enable_metrics
             and self.attn_tp_rank == 0
-            and time.time() > self.metrics_collector.last_log_time + 30
+            and time.perf_counter() > self.metrics_collector.last_log_time + 30
         ):
             # During idle time, also collect metrics every 30 seconds.
             num_used = self.max_total_num_tokens - (
@@ -1261,6 +1284,7 @@ class Scheduler(
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)
+            self._publish_kv_events()
 
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
@@ -1383,6 +1407,13 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                # In prefill mode, prealloc queue and transfer queue can also take memory,
+                # so we need to check if the available size for the actual available size.
+                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
+                    self.running_batch.batch_is_full = True
+                    break
+
             req.init_next_round_input(
                 None if prefix_computed else self.tree_cache,
                 self.enable_hierarchical_cache,
@@ -1411,7 +1442,7 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.queue_time_end = time.time()
+                req.queue_time_end = time.perf_counter()
 
         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
@@ -1513,7 +1544,7 @@ class Scheduler(
             self.profiler_target_forward_ct
             and self.profiler_target_forward_ct <= self.forward_ct
         ):
-            self.stop_profile()
+            self.send_to_tokenizer.send_pyobj(self.stop_profile())
 
         if self.forward_sleep_time is not None:
             logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
@@ -1784,10 +1815,10 @@ class Scheduler(
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.time()
+        self.watchdog_last_time = time.perf_counter()
 
         while True:
-            current = time.time()
+            current = time.perf_counter()
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
@@ -2115,7 +2146,10 @@ class Scheduler(
 
     def stop_profile(self) -> None:
         if self.profiler_activities is None:
-            return
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is not in progress. Call /start_profile first.",
+            )
 
         logger.info("Stop profiling...")
         if self.torch_profiler is not None:
@@ -2146,18 +2180,15 @@ class Scheduler(
         self.torch_profiler_output_dir = None
         self.profiler_activities = None
 
-        if self.profiler_target_forward_ct:
-            self.send_to_tokenizer.send_pyobj(
-                ProfileReqOutput(success=True, message="Succeeded.")
-            )
+        return ProfileReqOutput(success=True, message="Succeeded")
 
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
-            expert_distribution_recorder.start_record()
+            get_global_expert_distribution_recorder().start_record()
         elif recv_req == ExpertDistributionReq.STOP_RECORD:
-            expert_distribution_recorder.stop_record()
+            get_global_expert_distribution_recorder().stop_record()
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
-            expert_distribution_recorder.dump_record()
+            get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError("Unrecognized ExpertDistributionReq value")
         return ExpertDistributionReqOutput()
@@ -2195,6 +2226,13 @@ class Scheduler(
             prefix += f" PP{self.pp_rank}"
         return prefix
 
+    def _publish_kv_events(self):
+        if self.enable_kv_cache_events:
+            events = self.tree_cache.take_events()
+            if events:
+                batch = KVEventBatch(ts=time.time(), events=events)
+                self.kv_event_publisher.publish(batch)
+
 
 def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -2250,6 +2288,10 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
+    embedding_cache_size = 100
+    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
+        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
+    init_embedding_cache(embedding_cache_size * 1024 * 1024)
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
sglang/srt/managers/session_controller.py
@@ -54,7 +54,7 @@ class SessionReqNode:
             prefix += " -- " + self.childs[0].req.rid
             ret = self.childs[0]._str_helper(prefix)
         for child in self.childs[1:]:
-            prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+            prefix = " " * len(origin_prefix) + r" \- " + child.req.rid
             ret += child._str_helper(prefix)
         return ret
 
sglang/srt/managers/tokenizer_manager.py
@@ -16,6 +16,7 @@
 import asyncio
 import copy
 import dataclasses
+import json
 import logging
 import os
 import pickle
@@ -90,6 +91,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
+    SetInternalStateReq,
+    SetInternalStateReqOutput,
     SlowDownReqInput,
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
@@ -169,6 +172,11 @@ class TokenizerManager:
         self.enable_metrics = server_args.enable_metrics
         self.log_requests = server_args.log_requests
         self.log_requests_level = server_args.log_requests_level
+        self.preferred_sampling_params = (
+            json.loads(server_args.preferred_sampling_params)
+            if server_args.preferred_sampling_params
+            else None
+        )
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -228,6 +236,7 @@ class TokenizerManager:
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
         self.dump_requests_folder = ""  # By default do not dump
@@ -255,6 +264,10 @@ class TokenizerManager:
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,
                 },
+                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
+                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
+                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
+                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
             )
 
         # Communicators
@@ -282,12 +295,16 @@ class TokenizerManager:
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.start_profile_communicator = _Communicator(
+        self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.set_internal_state_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.expert_distribution_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -343,12 +360,16 @@ class TokenizerManager:
                 ),
                 (
                     ProfileReqOutput,
-                    self.start_profile_communicator.handle_recv,
+                    self.profile_communicator.handle_recv,
                 ),
                 (
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
+                (
+                    SetInternalStateReqOutput,
+                    self.set_internal_state_communicator.handle_recv,
+                ),
                 (
                     ExpertDistributionReqOutput,
                     self.expert_distribution_communicator.handle_recv,
@@ -438,14 +459,16 @@ class TokenizerManager:
             )
             input_ids = self.tokenizer.encode(input_text)
 
-        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            image_data=obj.image_data,
-            input_text=input_text or input_ids,
-            request_obj=obj,
-            max_req_input_len=self.max_req_input_len,
-        )
-        if image_inputs and "input_ids" in image_inputs:
-            input_ids = image_inputs["input_ids"]
+        image_inputs: Optional[Dict] = None
+        if obj.contains_mm_input():
+            image_inputs = await self.mm_processor.process_mm_data_async(
+                image_data=obj.image_data,
+                input_text=input_text or input_ids,
+                request_obj=obj,
+                max_req_input_len=self.max_req_input_len,
+            )
+            if image_inputs and "input_ids" in image_inputs:
+                input_ids = image_inputs["input_ids"]
 
         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
@@ -508,7 +531,14 @@ class TokenizerManager:
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
 
-        sampling_params = SamplingParams(**obj.sampling_params)
+        # Parse sampling parameters
+        # Note: if there are preferred sampling params, we use them if they are not
+        # explicitly passed in sampling_params
+        if self.preferred_sampling_params:
+            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
+        else:
+            sampling_kwargs = obj.sampling_params
+        sampling_params = SamplingParams(**sampling_kwargs)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
 
@@ -667,7 +697,6 @@ class TokenizerManager:
 
         generators = []
         rids = []
-
         if getattr(obj, "parallel_sample_num", 1) == 1:
             if self.server_args.enable_tokenizer_batch_encode:
                 # Validate batch tokenization constraints
@@ -765,6 +794,7 @@ class TokenizerManager:
         with_stack: Optional[bool] = None,
         record_shapes: Optional[bool] = None,
     ):
+        self.auto_create_handle_loop()
         req = ProfileReq(
             type=ProfileReqType.START_PROFILE,
             output_dir=output_dir,
@@ -774,22 +804,29 @@ class TokenizerManager:
             record_shapes=record_shapes,
             profile_id=str(time.time()),
         )
-        result = (await self.start_profile_communicator(req))[0]
+        return await self._execute_profile(req)
+
+    async def stop_profile(self):
+        self.auto_create_handle_loop()
+        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
+        return await self._execute_profile(req)
+
+    async def _execute_profile(self, req: ProfileReq):
+        result = (await self.profile_communicator(req))[0]
         if not result.success:
             raise RuntimeError(result.message)
         return result
 
-    def stop_profile(self):
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        self.send_to_scheduler.send_pyobj(req)
-
     async def start_expert_distribution_record(self):
+        self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
 
     async def stop_expert_distribution_record(self):
+        self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
 
     async def dump_expert_distribution_record(self):
+        self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
     async def update_weights_from_disk(
@@ -856,8 +893,8 @@ class TokenizerManager:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be for update weights from distributed"
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
 
         # This means that weight sync
         # cannot run while requests are in progress.
@@ -872,8 +909,8 @@ class TokenizerManager:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for update weights from distributed"
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
 
         # This means that weight sync
         # cannot run while requests are in progress.
@@ -946,6 +983,14 @@ class TokenizerManager:
             # Many DP ranks
             return [res.internal_state for res in responses]
 
+    async def set_internal_state(
+        self, obj: SetInternalStateReq
+    ) -> SetInternalStateReqOutput:
+        responses: List[SetInternalStateReqOutput] = (
+            await self.set_internal_state_communicator(obj)
+        )
+        return [res.internal_state for res in responses]
+
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1015,11 +1060,17 @@ class TokenizerManager:
                 loop.create_task(print_exception_wrapper(self.handle_loop))
             )
 
+        self.event_loop = loop
+
         # We cannot add signal handler when the tokenizer manager is not in
         # the main thread due to the CPython limitation.
         if threading.current_thread() is threading.main_thread():
             signal_handler = SignalHandler(self)
-            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
+            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
+            loop.add_signal_handler(
+                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
+            )
         else:
             logger.warning(
                 "Signal handler is not added because the tokenizer manager is "
@@ -1037,6 +1088,15 @@ class TokenizerManager:
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+
+            if self.health_check_failed:
+                # if health check failed, we should exit immediately
+                logger.error(
+                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
+                    remain_num_req,
+                )
+                break
+
             logger.info(
                 f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
@@ -1120,7 +1180,16 @@ class TokenizerManager:
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchMultimodalOut):
-                raise NotImplementedError()
+                if isinstance(recv_obj.outputs[i], str):
+                    out_dict = {
+                        "text": recv_obj.outputs[i],
+                        "meta_info": meta_info,
+                    }
+                else:
+                    out_dict = {
+                        "outputs": json.dumps(recv_obj.outputs[i]),
+                        "meta_info": meta_info,
+                    }
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -1366,12 +1435,18 @@ class SignalHandler:
     def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager
 
-    def signal_handler(self, signum=None, frame=None):
+    def sigterm_handler(self, signum=None, frame=None):
         logger.warning(
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True
 
+    def running_phase_sigquit_handler(self, signum=None, frame=None):
+        logger.error(
+            "Received sigquit from a child process. It usually means the child failed."
+        )
+        kill_process_tree(os.getpid())
+
 
 T = TypeVar("T")
 
sglang/srt/mem_cache/base_prefix_cache.py
@@ -48,3 +48,6 @@ class BasePrefixCache(ABC):
 
     def pretty_print(self):
         raise NotImplementedError()
+
+    def take_events(self):
+        return []