sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -13,6 +13,7 @@
  # ==============================================================================
  """A scheduler that manages a tensor parallel GPU worker."""
 
+ import faulthandler
  import logging
  import os
  import signal
@@ -21,8 +22,10 @@ import time
  import warnings
  from collections import deque
  from concurrent import futures
+ from dataclasses import dataclass
+ from http import HTTPStatus
  from types import SimpleNamespace
- from typing import Dict, List, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple, Union
 
  import psutil
  import setproctitle
@@ -31,7 +34,9 @@ import zmq
 
  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.io_struct import (
  AbortReq,
@@ -46,6 +51,10 @@ from sglang.srt.managers.io_struct import (
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
+ ReleaseMemoryOccupationReqInput,
+ ReleaseMemoryOccupationReqOutput,
+ ResumeMemoryOccupationReqInput,
+ ResumeMemoryOccupationReqOutput,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
  UpdateWeightFromDiskReqInput,
@@ -71,12 +80,14 @@ from sglang.srt.managers.schedule_policy import (
  from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+ from sglang.srt.managers.utils import validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.utils import (
  broadcast_pyobj,
  configure_logger,
@@ -87,7 +98,7 @@ from sglang.srt.utils import (
  set_random_seed,
  suppress_other_loggers,
  )
- from sglang.utils import get_exception_traceback
+ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
  logger = logging.getLogger(__name__)
 
@@ -95,6 +106,19 @@ logger = logging.getLogger(__name__)
  test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
+ @dataclass
+ class GenerationBatchResult:
+ logits_output: LogitsProcessorOutput
+ next_token_ids: List[int]
+ bid: int
+
+
+ @dataclass
+ class EmbeddingBatchResult:
+ embeddings: torch.Tensor
+ bid: int
+
+
  class Scheduler:
  """A scheduler that manages a tensor parallel GPU worker."""
 
@@ -126,26 +150,36 @@ class Scheduler:
  else 1
  )
 
+ # Distributed rank info
+ self.dp_size = server_args.dp_size
+ self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+ compute_dp_attention_world_info(
+ server_args.enable_dp_attention,
+ self.tp_rank,
+ self.tp_size,
+ self.dp_size,
+ )
+ )
+
  # Init inter-process communication
  context = zmq.Context(2)
-
- if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+ if self.attn_tp_rank == 0:
  self.recv_from_tokenizer = get_zmq_socket(
- context, zmq.PULL, port_args.scheduler_input_ipc_name
+ context, zmq.PULL, port_args.scheduler_input_ipc_name, False
  )
  self.send_to_tokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.tokenizer_ipc_name
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
  )
 
  if server_args.skip_tokenizer_init:
  # Directly send to the TokenizerManager
  self.send_to_detokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.tokenizer_ipc_name
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
  )
  else:
  # Send to the DetokenizerManager
  self.send_to_detokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.detokenizer_ipc_name
+ context, zmq.PUSH, port_args.detokenizer_ipc_name, False
  )
  else:
  self.recv_from_tokenizer = None
@@ -173,6 +207,7 @@ class Scheduler:
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  )
  self.tokenizer = self.processor.tokenizer
  else:
@@ -180,6 +215,7 @@ class Scheduler:
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  )
 
  # Check whether overlap can be enabled
@@ -208,7 +244,7 @@ class Scheduler:
  nccl_port=port_args.nccl_port,
  )
 
- # Launch worker for speculative decoding if need
+ # Launch a worker for speculative decoding if needed
  if self.spec_algorithm.is_eagle():
  from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
@@ -238,10 +274,10 @@ class Scheduler:
  _,
  ) = self.tp_worker.get_worker_info()
  self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+ self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
  self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
  global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)
-
  # Print debug info
  logger.info(
  f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -281,9 +317,13 @@ class Scheduler:
  self.forward_ct = 0
  self.forward_ct_decode = 0
  self.num_generated_tokens = 0
+ self.spec_num_total_accepted_tokens = 0
+ self.spec_num_total_forward_ct = 0
  self.last_decode_stats_tic = time.time()
  self.stream_interval = server_args.stream_interval
  self.current_stream = torch.get_device_module(self.device).current_stream()
+ if self.device == "cpu":
+ self.current_stream.synchronize = lambda: None # No-op for CPU
 
  # Session info
  self.sessions: Dict[str, Session] = {}
@@ -300,28 +340,9 @@ class Scheduler:
  # Init the grammar backend for constrained generation
  self.grammar_queue: List[Req] = []
  if not server_args.skip_tokenizer_init:
- if server_args.grammar_backend == "outlines":
- from sglang.srt.constrained.outlines_backend import (
- OutlinesGrammarBackend,
- )
-
- self.grammar_backend = OutlinesGrammarBackend(
- self.tokenizer,
- whitespace_pattern=server_args.constrained_json_whitespace_pattern,
- allow_jump_forward=not server_args.disable_jump_forward,
- )
- elif server_args.grammar_backend == "xgrammar":
- from sglang.srt.constrained.xgrammar_backend import (
- XGrammarGrammarBackend,
- )
-
- self.grammar_backend = XGrammarGrammarBackend(
- self.tokenizer, vocab_size=self.model_config.vocab_size
- )
- else:
- raise ValueError(
- f"Invalid grammar backend: {server_args.grammar_backend}"
- )
+ self.grammar_backend = create_grammar_backend(
+ server_args, self.tokenizer, self.model_config.vocab_size
+ )
  else:
  self.grammar_backend = None
 
@@ -356,6 +377,10 @@ class Scheduler:
  t.start()
  self.parent_process = psutil.Process().parent()
 
+ self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=server_args.enable_memory_saver
+ )
+
  # Init profiler
  if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
  self.profiler = None
@@ -383,22 +408,53 @@ class Scheduler:
  },
  )
 
+ # Init request dispatcher
+ self._request_dispatcher = TypeBasedDispatcher(
+ [
+ (TokenizedGenerateReqInput, self.handle_generate_request),
+ (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+ (FlushCacheReq, self.flush_cache_wrapped),
+ (AbortReq, self.abort_request),
+ (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+ (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+ (
+ UpdateWeightsFromDistributedReqInput,
+ self.update_weights_from_distributed,
+ ),
+ (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+ (GetWeightsByNameReqInput, self.get_weights_by_name),
+ (ProfileReq, self.profile),
+ (OpenSessionReqInput, self.open_session),
+ (CloseSessionReqInput, self.close_session),
+ (
+ ReleaseMemoryOccupationReqInput,
+ lambda _: self.release_memory_occupation(),
+ ),
+ (
+ ResumeMemoryOccupationReqInput,
+ lambda _: self.resume_memory_occupation(),
+ ),
+ ]
+ )
+
  def watchdog_thread(self):
  """A watch dog thread that will try to kill the server itself if one batch takes too long."""
  self.watchdog_last_forward_ct = 0
  self.watchdog_last_time = time.time()
 
  while True:
+ current = time.time()
  if self.cur_batch is not None:
  if self.watchdog_last_forward_ct == self.forward_ct:
- if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+ if current > self.watchdog_last_time + self.watchdog_timeout:
  logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
  break
  else:
  self.watchdog_last_forward_ct = self.forward_ct
- self.watchdog_last_time = time.time()
- time.sleep(self.watchdog_timeout / 2)
-
+ self.watchdog_last_time = current
+ time.sleep(self.watchdog_timeout // 2)
+ # Wait sometimes so that the parent process can print the error.
+ time.sleep(5)
  self.parent_process.send_signal(signal.SIGQUIT)
 
  @torch.no_grad()
@@ -409,10 +465,6 @@ class Scheduler:
  self.process_input_requests(recv_reqs)
 
  batch = self.get_next_batch_to_run()
-
- if self.server_args.enable_dp_attention: # TODO: simplify this
- batch = self.prepare_dp_attn_batch(batch)
-
  self.cur_batch = batch
 
  if batch:
@@ -442,7 +494,7 @@
  result_queue.append((batch.copy(), result))
 
  if self.last_batch is None:
- # Create a dummy first batch to start the pipeline for overlap scheduler.
+ # Create a dummy first batch to start the pipeline for overlap schedule.
  # It is now used for triggering the sampling_info_done event.
  tmp_batch = ScheduleBatch(
  reqs=None,
@@ -467,7 +519,7 @@ class Scheduler:
 
  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
- if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+ if self.attn_tp_rank == 0:
  recv_reqs = []
 
  while True:
@@ -479,57 +531,48 @@ class Scheduler:
  else:
  recv_reqs = None
 
- if self.tp_size != 1 and not self.server_args.enable_dp_attention:
+ if self.server_args.enable_dp_attention:
+ if self.attn_tp_rank == 0:
+ work_reqs = [
+ req
+ for req in recv_reqs
+ if isinstance(
+ req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ )
+ ]
+ control_reqs = [
+ req
+ for req in recv_reqs
+ if not isinstance(
+ req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ )
+ ]
+ else:
+ work_reqs = None
+ control_reqs = None
+
+ if self.attn_tp_size != 1:
+ attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+ work_reqs = broadcast_pyobj(
+ work_reqs,
+ self.attn_tp_rank,
+ self.attn_tp_cpu_group,
+ src=attn_tp_rank_0,
+ )
+ if self.tp_size != 1:
+ control_reqs = broadcast_pyobj(
+ control_reqs, self.tp_rank, self.tp_cpu_group
+ )
+ recv_reqs = work_reqs + control_reqs
+ elif self.tp_size != 1:
  recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
  return recv_reqs
 
  def process_input_requests(self, recv_reqs: List):
  for recv_req in recv_reqs:
- if isinstance(recv_req, TokenizedGenerateReqInput):
- self.handle_generate_request(recv_req)
- elif isinstance(recv_req, TokenizedEmbeddingReqInput):
- self.handle_embedding_request(recv_req)
- elif isinstance(recv_req, FlushCacheReq):
- self.flush_cache()
- elif isinstance(recv_req, AbortReq):
- self.abort_request(recv_req)
- elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
- success, message = self.update_weights_from_disk(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightFromDiskReqOutput(success, message)
- )
- elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
- success, message = self.init_weights_update_group(recv_req)
- self.send_to_tokenizer.send_pyobj(
- InitWeightsUpdateGroupReqOutput(success, message)
- )
- elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
- success, message = self.update_weights_from_distributed(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightsFromDistributedReqOutput(success, message)
- )
- elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
- success, message = self.update_weights_from_tensor(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightsFromTensorReqOutput(success, message)
- )
- elif isinstance(recv_req, GetWeightsByNameReqInput):
- parameter = self.get_weights_by_name(recv_req)
- self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
- elif isinstance(recv_req, ProfileReq):
- if recv_req == ProfileReq.START_PROFILE:
- self.start_profile()
- else:
- self.stop_profile()
- elif isinstance(recv_req, OpenSessionReqInput):
- session_id, success = self.open_session(recv_req)
- self.send_to_tokenizer.send_pyobj(
- OpenSessionReqOutput(session_id=session_id, success=success)
- )
- elif isinstance(recv_req, CloseSessionReqInput):
- self.close_session(recv_req)
- else:
- raise ValueError(f"Invalid request: {recv_req}")
+ output = self._request_dispatcher(recv_req)
+ if output is not None:
+ self.send_to_tokenizer.send_pyobj(output)
 
  def handle_generate_request(
  self,
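Note on the two hunks above: the long isinstance if/elif chain in process_input_requests is replaced by a table-driven lookup, where requests are routed through self._request_dispatcher and any non-None handler result is sent back to the tokenizer. A minimal sketch of what such a type-based dispatcher can look like (illustration only; the actual sglang.utils.TypeBasedDispatcher may differ in details):

    from typing import Any, Callable, List, Tuple, Type

    class TypeBasedDispatcher:
        """Route an object to the first handler whose registered type matches (sketch)."""

        def __init__(self, mapping: List[Tuple[Type, Callable]]):
            # List of (request type, handler) pairs, checked in registration order.
            self._mapping = mapping

        def __call__(self, obj: Any):
            for ty, fn in self._mapping:
                if isinstance(obj, ty):
                    return fn(obj)
            raise ValueError(f"Invalid object: {obj}")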
@@ -548,6 +591,19 @@ class Scheduler:
  fake_input_ids = [1] * seq_length
  recv_req.input_ids = fake_input_ids
 
+ # Handle custom logit processor passed to the request
+ custom_logit_processor = recv_req.custom_logit_processor
+ if (
+ not self.server_args.enable_custom_logit_processor
+ and custom_logit_processor is not None
+ ):
+ logger.warning(
+ "The SGLang server is not configured to enable custom logit processor."
+ "The custom logit processor passed in will be ignored."
+ "Please set --enable-custom-logits-processor to enable this feature."
+ )
+ custom_logit_processor = None
+
  req = Req(
  recv_req.rid,
  recv_req.input_text,
@@ -558,6 +614,7 @@ class Scheduler:
  stream=recv_req.stream,
  lora_path=recv_req.lora_path,
  input_embeds=recv_req.input_embeds,
+ custom_logit_processor=custom_logit_processor,
  eos_token_ids=self.model_config.hf_eos_token_id,
  )
  req.tokenizer = self.tokenizer
@@ -589,15 +646,16 @@ class Scheduler:
  req.extend_image_inputs(image_inputs)
 
  if len(req.origin_input_ids) >= self.max_req_input_len:
- logger.error(
+ error_msg = (
  "Multimodal prompt is too long after expanding multimodal tokens. "
- f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
+ f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
  )
+ logger.error(error_msg)
  req.origin_input_ids = [0]
  req.image_inputs = None
  req.sampling_params.max_new_tokens = 0
  req.finished_reason = FINISH_ABORT(
- "Multimodal prompt is too long. Check server logs for details."
+ error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
  )
  self.waiting_queue.append(req)
  return
@@ -609,13 +667,16 @@ class Scheduler:
  # By default, only return the logprobs for output tokens
  req.logprob_start_len = len(req.origin_input_ids) - 1
 
- # Truncate prompts that are too long
- if len(req.origin_input_ids) > self.max_req_input_len:
- logger.warning(
- "Request length is longer than the KV cache pool size or "
- "the max context length. Truncated!!!"
- )
- req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+ # Validate prompts length
+ error_msg = validate_input_length(
+ req,
+ self.max_req_input_len,
+ self.server_args.allow_auto_truncate,
+ )
+
+ if error_msg:
+ self.waiting_queue.append(req)
+ return
 
  req.sampling_params.max_new_tokens = min(
  (
@@ -663,13 +724,12 @@ class Scheduler:
  )
  req.tokenizer = self.tokenizer
 
- # Truncate prompts that are too long
- if len(req.origin_input_ids) >= self.max_req_input_len:
- logger.warning(
- "Request length is longer than the KV cache pool size or "
- "the max context length. Truncated!!!"
- )
- req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+ # Validate prompts length
+ validate_input_length(
+ req,
+ self.max_req_input_len,
+ self.server_args.allow_auto_truncate,
+ )
 
  self.waiting_queue.append(req)
 
@@ -715,21 +775,40 @@ class Scheduler:
  self.num_generated_tokens = 0
  self.last_decode_stats_tic = time.time()
  num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
- logger.info(
- f"Decode batch. "
- f"#running-req: {num_running_reqs}, "
- f"#token: {num_used}, "
- f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
- f"gen throughput (token/s): {gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}"
- )
 
+ if self.spec_algorithm.is_none():
+ msg = (
+ f"Decode batch. "
+ f"#running-req: {num_running_reqs}, "
+ f"#token: {num_used}, "
+ f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+ f"gen throughput (token/s): {gen_throughput:.2f}, "
+ f"#queue-req: {len(self.waiting_queue)}"
+ )
+ spec_accept_length = 0
+ else:
+ spec_accept_length = (
+ self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+ )
+ self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+ msg = (
+ f"Decode batch. "
+ f"#running-req: {num_running_reqs}, "
+ f"#token: {num_used}, "
+ f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+ f"accept len: {spec_accept_length:.2f}, "
+ f"gen throughput (token/s): {gen_throughput:.2f}, "
+ f"#queue-req: {len(self.waiting_queue)}"
+ )
+
+ logger.info(msg)
  if self.enable_metrics:
  self.stats.num_running_reqs = num_running_reqs
  self.stats.num_used_tokens = num_used
  self.stats.token_usage = num_used / self.max_total_num_tokens
  self.stats.gen_throughput = gen_throughput
  self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.spec_accept_length = spec_accept_length
  self.metrics_collector.log_stats(self.stats)
 
  def check_memory(self):
@@ -772,16 +851,23 @@ class Scheduler:
  else:
  self.running_batch.merge_batch(self.last_batch)
 
- # Run prefill first if possible
  new_batch = self.get_new_batch_prefill()
  if new_batch is not None:
- return new_batch
+ # Run prefill first if possible
+ ret = new_batch
+ else:
+ # Run decode
+ if self.running_batch is None:
+ ret = None
+ else:
+ self.running_batch = self.update_running_batch(self.running_batch)
+ ret = self.running_batch
 
- # Run decode
- if self.running_batch is None:
- return None
- self.running_batch = self.update_running_batch(self.running_batch)
- return self.running_batch
+ # Handle DP attention
+ if self.server_args.enable_dp_attention:
+ ret = self.prepare_dp_attn_batch(ret)
+
+ return ret
 
  def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
  # Check if the grammar is ready in the grammar queue
@@ -805,9 +891,9 @@ class Scheduler:
  # Prefill policy
  adder = PrefillAdder(
  self.tree_cache,
+ self.token_to_kv_pool,
  self.running_batch,
  self.new_token_ratio,
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
  self.max_prefill_tokens,
  self.chunked_prefill_size,
  running_bs if self.is_mixed_chunk else 0,
@@ -868,7 +954,7 @@ class Scheduler:
  self.being_chunked_req.is_being_chunked += 1
 
  # Print stats
- if self.tp_rank == 0:
+ if self.attn_tp_rank == 0:
  self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
 
  # Create a new batch
@@ -880,6 +966,7 @@ class Scheduler:
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
+ self.server_args.enable_custom_logit_processor,
  )
  new_batch.prepare_for_extend()
 
@@ -950,12 +1037,14 @@ class Scheduler:
  batch.prepare_for_decode()
  return batch
 
- def run_batch(self, batch: ScheduleBatch):
+ def run_batch(
+ self, batch: ScheduleBatch
+ ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
  """Run a batch."""
  self.forward_ct += 1
 
  if self.is_generation:
- if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
+ if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()
  logits_output, next_token_ids = (
@@ -968,45 +1057,65 @@ class Scheduler:
  model_worker_batch,
  num_accepted_tokens,
  ) = self.draft_worker.forward_batch_speculative_generation(batch)
+ self.spec_num_total_accepted_tokens += (
+ num_accepted_tokens + batch.batch_size()
+ )
+ self.spec_num_total_forward_ct += batch.batch_size()
  self.num_generated_tokens += num_accepted_tokens
- elif batch.forward_mode.is_idle():
- model_worker_batch = batch.get_model_worker_batch()
- self.tp_worker.forward_batch_idle(model_worker_batch)
- return
  else:
- logits_output = None
- if self.skip_tokenizer_init:
- next_token_ids = torch.full(
- (batch.batch_size(),), self.tokenizer.eos_token_id
- )
- else:
- next_token_ids = torch.full((batch.batch_size(),), 0)
+ assert False, "batch.extend_num_tokens == 0, this is unexpected!"
  batch.output_ids = next_token_ids
- ret = logits_output, next_token_ids, model_worker_batch.bid
+
+ ret = GenerationBatchResult(
+ logits_output=logits_output,
+ next_token_ids=next_token_ids,
+ bid=model_worker_batch.bid,
+ )
  else: # embedding or reward model
  assert batch.extend_num_tokens != 0
  model_worker_batch = batch.get_model_worker_batch()
  embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
- ret = embeddings, model_worker_batch.bid
+ ret = EmbeddingBatchResult(
+ embeddings=embeddings, bid=model_worker_batch.bid
+ )
  return ret
 
- def process_batch_result(self, batch: ScheduleBatch, result):
+ def process_batch_result(
+ self,
+ batch: ScheduleBatch,
+ result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ ):
  if batch.forward_mode.is_decode():
  self.process_batch_result_decode(batch, result)
  if batch.is_empty():
  self.running_batch = None
  elif batch.forward_mode.is_extend():
  self.process_batch_result_prefill(batch, result)
+ elif batch.forward_mode.is_idle():
+ if self.enable_overlap:
+ self.tp_worker.resolve_batch_result(result.bid)
  elif batch.forward_mode.is_dummy_first():
  batch.next_batch_sampling_info.update_regex_vocab_mask()
  self.current_stream.synchronize()
  batch.next_batch_sampling_info.sampling_info_done.set()
 
- def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+ def process_batch_result_prefill(
+ self,
+ batch: ScheduleBatch,
+ result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ ):
  skip_stream_req = None
 
  if self.is_generation:
- logits_output, next_token_ids, bid = result
+ (
+ logits_output,
+ next_token_ids,
+ bid,
+ ) = (
+ result.logits_output,
+ result.next_token_ids,
+ result.bid,
+ )
 
  if self.enable_overlap:
  logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
@@ -1020,9 +1129,6 @@ class Scheduler:
  logits_output.input_token_logprobs = (
  logits_output.input_token_logprobs.tolist()
  )
- logits_output.normalized_prompt_logprobs = (
- logits_output.normalized_prompt_logprobs.tolist()
- )
 
  # Check finish conditions
  logprob_pt = 0
@@ -1067,7 +1173,7 @@ class Scheduler:
  batch.next_batch_sampling_info.sampling_info_done.set()
 
  else: # embedding or reward model
- embeddings, bid = result
+ embeddings, bid = result.embeddings, result.bid
  embeddings = embeddings.tolist()
 
  # Check finish conditions
@@ -1091,8 +1197,16 @@ class Scheduler:
 
  self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
- def process_batch_result_decode(self, batch: ScheduleBatch, result):
- logits_output, next_token_ids, bid = result
+ def process_batch_result_decode(
+ self,
+ batch: ScheduleBatch,
+ result: GenerationBatchResult,
+ ):
+ logits_output, next_token_ids, bid = (
+ result.logits_output,
+ result.next_token_ids,
+ result.bid,
+ )
  self.num_generated_tokens += len(batch.reqs)
 
  if self.enable_overlap:
@@ -1150,7 +1264,7 @@ class Scheduler:
 
  self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
  if (
- self.tp_rank == 0
+ self.attn_tp_rank == 0
  and self.forward_ct_decode % self.server_args.decode_log_interval == 0
  ):
  self.log_decode_stats()
@@ -1170,9 +1284,6 @@ class Scheduler:
  # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
  num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
 
- if req.normalized_prompt_logprob is None:
- req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
-
  if req.input_token_logprobs_val is None:
  input_token_logprobs_val = output.input_token_logprobs[
  pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
@@ -1253,7 +1364,6 @@ class Scheduler:
  decode_ids_list = []
  read_offsets = []
  output_ids = []
- origin_input_ids = []
 
  skip_special_tokens = []
  spaces_between_special_tokens = []
@@ -1271,15 +1381,12 @@ class Scheduler:
  input_top_logprobs_idx = []
  output_top_logprobs_val = []
  output_top_logprobs_idx = []
- normalized_prompt_logprob = []
  else:
  input_token_logprobs_val = input_token_logprobs_idx = (
  output_token_logprobs_val
  ) = output_token_logprobs_idx = input_top_logprobs_val = (
  input_top_logprobs_idx
- ) = output_top_logprobs_val = output_top_logprobs_idx = (
- normalized_prompt_logprob
- ) = None
+ ) = output_top_logprobs_val = output_top_logprobs_idx = None
 
  for req in reqs:
  if req is skip_req:
@@ -1305,14 +1412,8 @@ class Scheduler:
  decode_ids, read_offset = req.init_incremental_detokenize()
  decode_ids_list.append(decode_ids)
  read_offsets.append(read_offset)
- if self.skip_tokenizer_init or self.server_args.return_token_ids:
+ if self.skip_tokenizer_init:
  output_ids.append(req.output_ids)
- else:
- output_ids = None
- if self.server_args.return_token_ids:
- origin_input_ids.append(req.origin_input_ids)
- else:
- origin_input_ids = None
  skip_special_tokens.append(req.sampling_params.skip_special_tokens)
  spaces_between_special_tokens.append(
  req.sampling_params.spaces_between_special_tokens
@@ -1332,7 +1433,6 @@ class Scheduler:
  input_top_logprobs_idx.append(req.input_top_logprobs_idx)
  output_top_logprobs_val.append(req.output_top_logprobs_val)
  output_top_logprobs_idx.append(req.output_top_logprobs_idx)
- normalized_prompt_logprob.append(req.normalized_prompt_logprob)
 
  # Send to detokenizer
  if rids:
@@ -1344,7 +1444,6 @@ class Scheduler:
  decoded_texts,
  decode_ids_list,
  read_offsets,
- origin_input_ids,
  output_ids,
  skip_special_tokens,
  spaces_between_special_tokens,
@@ -1360,7 +1459,6 @@ class Scheduler:
  input_top_logprobs_idx,
  output_top_logprobs_val,
  output_top_logprobs_idx,
- normalized_prompt_logprob,
  )
  )
  else: # embedding or reward model
@@ -1402,12 +1500,7 @@ class Scheduler:
  # Check forward mode for cuda graph
  if not self.server_args.disable_cuda_graph:
  forward_mode_state = torch.tensor(
- (
- 1
- if local_batch.forward_mode.is_decode()
- or local_batch.forward_mode.is_idle()
- else 0
- ),
+ (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
  dtype=torch.int32,
  )
  torch.distributed.all_reduce(
@@ -1428,6 +1521,7 @@ class Scheduler:
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
+ self.server_args.enable_custom_logit_processor,
  )
  idle_batch.prepare_for_idle()
  return idle_batch
@@ -1456,6 +1550,9 @@ class Scheduler:
  self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
+ def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+ self.flush_cache()
+
  def flush_cache(self):
  """Flush the memory pool and cache."""
  if len(self.waiting_queue) == 0 and (
@@ -1508,12 +1605,12 @@ class Scheduler:
  assert flash_cache_success, "Cache flush failed after updating weights"
  else:
  logger.error(message)
- return success, message
+ return UpdateWeightFromDiskReqOutput(success, message)
 
  def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
  """Initialize the online model parameter update group."""
  success, message = self.tp_worker.init_weights_update_group(recv_req)
- return success, message
+ return InitWeightsUpdateGroupReqOutput(success, message)
 
  def update_weights_from_distributed(
  self,
@@ -1526,7 +1623,7 @@ class Scheduler:
  assert flash_cache_success, "Cache flush failed after updating weights"
  else:
  logger.error(message)
- return success, message
+ return UpdateWeightsFromDistributedReqOutput(success, message)
 
  def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
  """Update the online model parameter from tensors."""
@@ -1537,11 +1634,33 @@ class Scheduler:
  assert flash_cache_success, "Cache flush failed after updating weights"
  else:
  logger.error(message)
- return success, message
+ return UpdateWeightsFromTensorReqOutput(success, message)
 
  def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
  parameter = self.tp_worker.get_weights_by_name(recv_req)
- return parameter
+ return GetWeightsByNameReqOutput(parameter)
+
+ def release_memory_occupation(self):
+ self.stashed_model_static_state = _export_static_state(
+ self.tp_worker.worker.model_runner.model
+ )
+ self.memory_saver_adapter.pause()
+ self.flush_cache()
+ return ReleaseMemoryOccupationReqOutput()
+
+ def resume_memory_occupation(self):
+ self.memory_saver_adapter.resume()
+ _import_static_state(
+ self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+ )
+ del self.stashed_model_static_state
+ return ResumeMemoryOccupationReqOutput()
+
+ def profile(self, recv_req: ProfileReq):
+ if recv_req == ProfileReq.START_PROFILE:
+ self.start_profile()
+ else:
+ self.stop_profile()
 
  def start_profile(self) -> None:
  if self.profiler is None:
@@ -1557,20 +1676,20 @@ class Scheduler:
  )
  logger.info("Profiler is done")
 
- def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
+ def open_session(self, recv_req: OpenSessionReqInput):
  # handle error
  session_id = recv_req.session_id
  if session_id in self.sessions:
  logger.warning(f"session id {session_id} already exist, cannot open.")
- return session_id, False
+ return OpenSessionReqOutput(session_id, False)
  elif session_id is None:
  logger.warning(f"session id is None, cannot open.")
- return session_id, False
+ return OpenSessionReqOutput(session_id, False)
  else:
  self.sessions[session_id] = Session(
  recv_req.capacity_of_str_len, session_id
  )
- return session_id, True
+ return OpenSessionReqOutput(session_id, True)
 
  def close_session(self, recv_req: CloseSessionReqInput):
  # handle error
@@ -1581,6 +1700,20 @@ class Scheduler:
  del self.sessions[session_id]
 
 
+ def _export_static_state(model):
+ return dict(
+ buffers=[
+ (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+ ]
+ )
+
+
+ def _import_static_state(model, static_params):
+ self_named_buffers = dict(model.named_buffers())
+ for name, tensor in static_params["buffers"]:
+ self_named_buffers[name][...] = tensor
+
+
  def run_scheduler_process(
  server_args: ServerArgs,
  port_args: PortArgs,
@@ -1590,6 +1723,7 @@ def run_scheduler_process(
  pipe_writer,
  ):
  setproctitle.setproctitle("sglang::scheduler")
+ faulthandler.enable()
 
  # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
  if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
@@ -1612,7 +1746,11 @@ def run_scheduler_process(
  try:
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  pipe_writer.send(
- {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+ {
+ "status": "ready",
+ "max_total_num_tokens": scheduler.max_total_num_tokens,
+ "max_req_input_len": scheduler.max_req_input_len,
+ }
  )
  if scheduler.enable_overlap:
  scheduler.event_loop_overlap()