sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +1 -0
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +41 -5
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
  41. sglang/srt/layers/parameter.py +2 -1
  42. sglang/srt/layers/quantization/__init__.py +20 -23
  43. sglang/srt/layers/quantization/fp8.py +6 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  45. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  46. sglang/srt/layers/radix_attention.py +2 -2
  47. sglang/srt/layers/rotary_embedding.py +1179 -31
  48. sglang/srt/layers/sampler.py +39 -1
  49. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  50. sglang/srt/lora/lora.py +1 -9
  51. sglang/srt/managers/configure_logging.py +3 -0
  52. sglang/srt/managers/data_parallel_controller.py +79 -72
  53. sglang/srt/managers/detokenizer_manager.py +23 -6
  54. sglang/srt/managers/image_processor.py +158 -2
  55. sglang/srt/managers/io_struct.py +25 -2
  56. sglang/srt/managers/schedule_batch.py +49 -22
  57. sglang/srt/managers/schedule_policy.py +26 -12
  58. sglang/srt/managers/scheduler.py +277 -178
  59. sglang/srt/managers/session_controller.py +1 -0
  60. sglang/srt/managers/tokenizer_manager.py +206 -121
  61. sglang/srt/managers/tp_worker.py +6 -4
  62. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  63. sglang/srt/managers/utils.py +44 -0
  64. sglang/srt/mem_cache/memory_pool.py +10 -32
  65. sglang/srt/metrics/collector.py +15 -6
  66. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  67. sglang/srt/model_executor/model_runner.py +37 -15
  68. sglang/srt/model_loader/loader.py +8 -6
  69. sglang/srt/model_loader/weight_utils.py +55 -2
  70. sglang/srt/models/baichuan.py +6 -6
  71. sglang/srt/models/chatglm.py +2 -2
  72. sglang/srt/models/commandr.py +3 -3
  73. sglang/srt/models/dbrx.py +4 -4
  74. sglang/srt/models/deepseek.py +3 -3
  75. sglang/srt/models/deepseek_v2.py +8 -8
  76. sglang/srt/models/exaone.py +2 -2
  77. sglang/srt/models/gemma.py +2 -2
  78. sglang/srt/models/gemma2.py +6 -24
  79. sglang/srt/models/gpt2.py +3 -5
  80. sglang/srt/models/gpt_bigcode.py +1 -1
  81. sglang/srt/models/granite.py +2 -2
  82. sglang/srt/models/grok.py +3 -3
  83. sglang/srt/models/internlm2.py +2 -2
  84. sglang/srt/models/llama.py +7 -5
  85. sglang/srt/models/minicpm.py +2 -2
  86. sglang/srt/models/minicpm3.py +6 -6
  87. sglang/srt/models/minicpmv.py +1238 -0
  88. sglang/srt/models/mixtral.py +3 -3
  89. sglang/srt/models/mixtral_quant.py +3 -3
  90. sglang/srt/models/mllama.py +2 -2
  91. sglang/srt/models/olmo.py +3 -3
  92. sglang/srt/models/olmo2.py +4 -4
  93. sglang/srt/models/olmoe.py +7 -13
  94. sglang/srt/models/phi3_small.py +2 -2
  95. sglang/srt/models/qwen.py +2 -2
  96. sglang/srt/models/qwen2.py +41 -4
  97. sglang/srt/models/qwen2_moe.py +3 -3
  98. sglang/srt/models/qwen2_vl.py +22 -122
  99. sglang/srt/models/stablelm.py +2 -2
  100. sglang/srt/models/torch_native_llama.py +3 -3
  101. sglang/srt/models/xverse.py +6 -6
  102. sglang/srt/models/xverse_moe.py +6 -6
  103. sglang/srt/openai_api/protocol.py +2 -0
  104. sglang/srt/sampling/custom_logit_processor.py +38 -0
  105. sglang/srt/sampling/sampling_batch_info.py +139 -4
  106. sglang/srt/sampling/sampling_params.py +3 -1
  107. sglang/srt/server.py +4 -1090
  108. sglang/srt/server_args.py +57 -14
  109. sglang/srt/utils.py +103 -65
  110. sglang/test/runners.py +8 -13
  111. sglang/test/test_programs.py +1 -1
  112. sglang/test/test_utils.py +3 -1
  113. sglang/utils.py +12 -2
  114. sglang/version.py +1 -1
  115. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
  116. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
  117. sglang/launch_server_llavavid.py +0 -25
  118. sglang/srt/constrained/__init__.py +0 -16
  119. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  120. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  121. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
@@ -22,8 +22,10 @@ import time
  import warnings
  from collections import deque
  from concurrent import futures
+ from dataclasses import dataclass
+ from http import HTTPStatus
  from types import SimpleNamespace
- from typing import Dict, List, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple, Union
 
  import psutil
  import setproctitle
@@ -32,7 +34,9 @@ import zmq
 
  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.io_struct import (
  AbortReq,
@@ -76,6 +80,7 @@ from sglang.srt.managers.schedule_policy import (
  from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+ from sglang.srt.managers.utils import validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -93,7 +98,7 @@ from sglang.srt.utils import (
  set_random_seed,
  suppress_other_loggers,
  )
- from sglang.utils import get_exception_traceback
+ from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
  logger = logging.getLogger(__name__)
 
@@ -101,6 +106,19 @@ logger = logging.getLogger(__name__)
  test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
+ @dataclass
+ class GenerationBatchResult:
+ logits_output: LogitsProcessorOutput
+ next_token_ids: List[int]
+ bid: int
+
+
+ @dataclass
+ class EmbeddingBatchResult:
+ embeddings: torch.Tensor
+ bid: int
+
+
  class Scheduler:
  """A scheduler that manages a tensor parallel GPU worker."""
 
@@ -132,26 +150,36 @@ class Scheduler:
  else 1
  )
 
+ # Distributed rank info
+ self.dp_size = server_args.dp_size
+ self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+ compute_dp_attention_world_info(
+ server_args.enable_dp_attention,
+ self.tp_rank,
+ self.tp_size,
+ self.dp_size,
+ )
+ )
+
  # Init inter-process communication
  context = zmq.Context(2)
-
- if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+ if self.attn_tp_rank == 0:
  self.recv_from_tokenizer = get_zmq_socket(
- context, zmq.PULL, port_args.scheduler_input_ipc_name
+ context, zmq.PULL, port_args.scheduler_input_ipc_name, False
  )
  self.send_to_tokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.tokenizer_ipc_name
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
  )
 
  if server_args.skip_tokenizer_init:
  # Directly send to the TokenizerManager
  self.send_to_detokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.tokenizer_ipc_name
+ context, zmq.PUSH, port_args.tokenizer_ipc_name, False
  )
  else:
  # Send to the DetokenizerManager
  self.send_to_detokenizer = get_zmq_socket(
- context, zmq.PUSH, port_args.detokenizer_ipc_name
+ context, zmq.PUSH, port_args.detokenizer_ipc_name, False
  )
  else:
  self.recv_from_tokenizer = None
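
For context, the rank bookkeeping added above partitions the tensor-parallel group into per-data-parallel attention groups. A minimal sketch of the arithmetic, assuming the usual even split of tp_size across dp_size groups (the shipped helper lives in sglang/srt/layers/dp_attention.py and may differ in details):

    def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
        # Without DP attention, every scheduler stays in the full TP group.
        if not enable_dp_attention:
            return tp_rank, tp_size, 0
        # With DP attention, the tp_size ranks are partitioned into dp_size groups
        # of attn_tp_size ranks each; dp_rank identifies the group, attn_tp_rank
        # the position inside it (so the group leader is dp_rank * attn_tp_size).
        attn_tp_size = tp_size // dp_size
        dp_rank = tp_rank // attn_tp_size
        attn_tp_rank = tp_rank % attn_tp_size
        return attn_tp_rank, attn_tp_size, dp_rank
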
@@ -179,6 +207,7 @@ class Scheduler:
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  )
  self.tokenizer = self.processor.tokenizer
  else:
@@ -186,6 +215,7 @@ class Scheduler:
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  )
 
  # Check whether overlap can be enabled
@@ -214,7 +244,7 @@ class Scheduler:
  nccl_port=port_args.nccl_port,
  )
 
- # Launch worker for speculative decoding if need
+ # Launch a worker for speculative decoding if needed
  if self.spec_algorithm.is_eagle():
  from sglang.srt.speculative.eagle_worker import EAGLEWorker
 
@@ -244,10 +274,10 @@ class Scheduler:
  _,
  ) = self.tp_worker.get_worker_info()
  self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+ self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
  self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
  global_server_args_dict.update(worker_global_server_args_dict)
  set_random_seed(self.random_seed)
-
  # Print debug info
  logger.info(
  f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -287,9 +317,13 @@ class Scheduler:
  self.forward_ct = 0
  self.forward_ct_decode = 0
  self.num_generated_tokens = 0
+ self.spec_num_total_accepted_tokens = 0
+ self.spec_num_total_forward_ct = 0
  self.last_decode_stats_tic = time.time()
  self.stream_interval = server_args.stream_interval
  self.current_stream = torch.get_device_module(self.device).current_stream()
+ if self.device == "cpu":
+ self.current_stream.synchronize = lambda: None  # No-op for CPU
 
  # Session info
  self.sessions: Dict[str, Session] = {}
@@ -306,28 +340,9 @@ class Scheduler:
  # Init the grammar backend for constrained generation
  self.grammar_queue: List[Req] = []
  if not server_args.skip_tokenizer_init:
- if server_args.grammar_backend == "outlines":
- from sglang.srt.constrained.outlines_backend import (
- OutlinesGrammarBackend,
- )
-
- self.grammar_backend = OutlinesGrammarBackend(
- self.tokenizer,
- whitespace_pattern=server_args.constrained_json_whitespace_pattern,
- allow_jump_forward=not server_args.disable_jump_forward,
- )
- elif server_args.grammar_backend == "xgrammar":
- from sglang.srt.constrained.xgrammar_backend import (
- XGrammarGrammarBackend,
- )
-
- self.grammar_backend = XGrammarGrammarBackend(
- self.tokenizer, vocab_size=self.model_config.vocab_size
- )
- else:
- raise ValueError(
- f"Invalid grammar backend: {server_args.grammar_backend}"
- )
+ self.grammar_backend = create_grammar_backend(
+ server_args, self.tokenizer, self.model_config.vocab_size
+ )
  else:
  self.grammar_backend = None
 
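
The backend selection that used to be inlined here is now centralized in create_grammar_backend (sglang/srt/constrained/base_grammar_backend.py, +21 lines in the file list). A sketch of that factory, reconstructed from the branches removed above; the exact signature and internals of the shipped helper may differ:

    def create_grammar_backend(server_args, tokenizer, vocab_size):
        if server_args.grammar_backend == "outlines":
            from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

            return OutlinesGrammarBackend(
                tokenizer,
                whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                allow_jump_forward=not server_args.disable_jump_forward,
            )
        elif server_args.grammar_backend == "xgrammar":
            from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

            return XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
        else:
            raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
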
@@ -393,22 +408,51 @@ class Scheduler:
  },
  )
 
+ # Init request dispatcher
+ self._request_dispatcher = TypeBasedDispatcher(
+ [
+ (TokenizedGenerateReqInput, self.handle_generate_request),
+ (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+ (FlushCacheReq, self.flush_cache_wrapped),
+ (AbortReq, self.abort_request),
+ (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+ (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+ (
+ UpdateWeightsFromDistributedReqInput,
+ self.update_weights_from_distributed,
+ ),
+ (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+ (GetWeightsByNameReqInput, self.get_weights_by_name),
+ (ProfileReq, self.profile),
+ (OpenSessionReqInput, self.open_session),
+ (CloseSessionReqInput, self.close_session),
+ (
+ ReleaseMemoryOccupationReqInput,
+ lambda _: self.release_memory_occupation(),
+ ),
+ (
+ ResumeMemoryOccupationReqInput,
+ lambda _: self.resume_memory_occupation(),
+ ),
+ ]
+ )
+
  def watchdog_thread(self):
  """A watch dog thread that will try to kill the server itself if one batch takes too long."""
  self.watchdog_last_forward_ct = 0
  self.watchdog_last_time = time.time()
 
  while True:
+ current = time.time()
  if self.cur_batch is not None:
  if self.watchdog_last_forward_ct == self.forward_ct:
- if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+ if current > self.watchdog_last_time + self.watchdog_timeout:
  logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
  break
  else:
  self.watchdog_last_forward_ct = self.forward_ct
- self.watchdog_last_time = time.time()
- time.sleep(self.watchdog_timeout / 2)
-
+ self.watchdog_last_time = current
+ time.sleep(self.watchdog_timeout // 2)
  # Wait sometimes so that the parent process can print the error.
  time.sleep(5)
  self.parent_process.send_signal(signal.SIGQUIT)
@@ -421,10 +465,6 @@ class Scheduler:
  self.process_input_requests(recv_reqs)
 
  batch = self.get_next_batch_to_run()
-
- if self.server_args.enable_dp_attention: # TODO: simplify this
- batch = self.prepare_dp_attn_batch(batch)
-
  self.cur_batch = batch
 
  if batch:
@@ -454,7 +494,7 @@ class Scheduler:
  result_queue.append((batch.copy(), result))
 
  if self.last_batch is None:
- # Create a dummy first batch to start the pipeline for overlap scheduler.
+ # Create a dummy first batch to start the pipeline for overlap schedule.
  # It is now used for triggering the sampling_info_done event.
  tmp_batch = ScheduleBatch(
  reqs=None,
@@ -479,7 +519,7 @@ class Scheduler:
 
  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
- if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+ if self.attn_tp_rank == 0:
  recv_reqs = []
 
  while True:
@@ -491,63 +531,48 @@ class Scheduler:
  else:
  recv_reqs = None
 
- if self.tp_size != 1 and not self.server_args.enable_dp_attention:
+ if self.server_args.enable_dp_attention:
+ if self.attn_tp_rank == 0:
+ work_reqs = [
+ req
+ for req in recv_reqs
+ if isinstance(
+ req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ )
+ ]
+ control_reqs = [
+ req
+ for req in recv_reqs
+ if not isinstance(
+ req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ )
+ ]
+ else:
+ work_reqs = None
+ control_reqs = None
+
+ if self.attn_tp_size != 1:
+ attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+ work_reqs = broadcast_pyobj(
+ work_reqs,
+ self.attn_tp_rank,
+ self.attn_tp_cpu_group,
+ src=attn_tp_rank_0,
+ )
+ if self.tp_size != 1:
+ control_reqs = broadcast_pyobj(
+ control_reqs, self.tp_rank, self.tp_cpu_group
+ )
+ recv_reqs = work_reqs + control_reqs
+ elif self.tp_size != 1:
  recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
  return recv_reqs
 
  def process_input_requests(self, recv_reqs: List):
  for recv_req in recv_reqs:
- if isinstance(recv_req, TokenizedGenerateReqInput):
- self.handle_generate_request(recv_req)
- elif isinstance(recv_req, TokenizedEmbeddingReqInput):
- self.handle_embedding_request(recv_req)
- elif isinstance(recv_req, FlushCacheReq):
- self.flush_cache()
- elif isinstance(recv_req, AbortReq):
- self.abort_request(recv_req)
- elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
- success, message = self.update_weights_from_disk(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightFromDiskReqOutput(success, message)
- )
- elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
- success, message = self.init_weights_update_group(recv_req)
- self.send_to_tokenizer.send_pyobj(
- InitWeightsUpdateGroupReqOutput(success, message)
- )
- elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
- success, message = self.update_weights_from_distributed(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightsFromDistributedReqOutput(success, message)
- )
- elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
- success, message = self.update_weights_from_tensor(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightsFromTensorReqOutput(success, message)
- )
- elif isinstance(recv_req, GetWeightsByNameReqInput):
- parameter = self.get_weights_by_name(recv_req)
- self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
- elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
- self.release_memory_occupation()
- self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
- elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
- self.resume_memory_occupation()
- self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
- elif isinstance(recv_req, ProfileReq):
- if recv_req == ProfileReq.START_PROFILE:
- self.start_profile()
- else:
- self.stop_profile()
- elif isinstance(recv_req, OpenSessionReqInput):
- session_id, success = self.open_session(recv_req)
- self.send_to_tokenizer.send_pyobj(
- OpenSessionReqOutput(session_id=session_id, success=success)
- )
- elif isinstance(recv_req, CloseSessionReqInput):
- self.close_session(recv_req)
- else:
- raise ValueError(f"Invalid request: {recv_req}")
+ output = self._request_dispatcher(recv_req)
+ if output is not None:
+ self.send_to_tokenizer.send_pyobj(output)
 
  def handle_generate_request(
  self,
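
With this change every incoming request is routed through the TypeBasedDispatcher built in __init__, and any non-None return value from a handler is sent back to the tokenizer. A minimal sketch of such a dispatcher, assuming it simply calls the first handler whose registered type matches (the real class lives in sglang/utils.py and may differ):

    class TypeBasedDispatcher:
        def __init__(self, mapping):
            # mapping: list of (request type, handler callable) pairs,
            # e.g. (TokenizedGenerateReqInput, self.handle_generate_request)
            self._mapping = mapping

        def __call__(self, obj):
            # Dispatch to the handler registered for the object's type.
            for ty, fn in self._mapping:
                if isinstance(obj, ty):
                    return fn(obj)
            raise ValueError(f"Invalid request: {obj}")
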
@@ -566,6 +591,19 @@ class Scheduler:
  fake_input_ids = [1] * seq_length
  recv_req.input_ids = fake_input_ids
 
+ # Handle custom logit processor passed to the request
+ custom_logit_processor = recv_req.custom_logit_processor
+ if (
+ not self.server_args.enable_custom_logit_processor
+ and custom_logit_processor is not None
+ ):
+ logger.warning(
+ "The SGLang server is not configured to enable custom logit processor."
+ "The custom logit processor passed in will be ignored."
+ "Please set --enable-custom-logits-processor to enable this feature."
+ )
+ custom_logit_processor = None
+
  req = Req(
  recv_req.rid,
  recv_req.input_text,
@@ -576,6 +614,7 @@ class Scheduler:
  stream=recv_req.stream,
  lora_path=recv_req.lora_path,
  input_embeds=recv_req.input_embeds,
+ custom_logit_processor=custom_logit_processor,
  eos_token_ids=self.model_config.hf_eos_token_id,
  )
  req.tokenizer = self.tokenizer
@@ -607,15 +646,16 @@ class Scheduler:
  req.extend_image_inputs(image_inputs)
 
  if len(req.origin_input_ids) >= self.max_req_input_len:
- logger.error(
+ error_msg = (
  "Multimodal prompt is too long after expanding multimodal tokens. "
- f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
+ f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
  )
+ logger.error(error_msg)
  req.origin_input_ids = [0]
  req.image_inputs = None
  req.sampling_params.max_new_tokens = 0
  req.finished_reason = FINISH_ABORT(
- "Multimodal prompt is too long. Check server logs for details."
+ error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
  )
  self.waiting_queue.append(req)
  return
@@ -627,13 +667,16 @@ class Scheduler:
  # By default, only return the logprobs for output tokens
  req.logprob_start_len = len(req.origin_input_ids) - 1
 
- # Truncate prompts that are too long
- if len(req.origin_input_ids) > self.max_req_input_len:
- logger.warning(
- "Request length is longer than the KV cache pool size or "
- "the max context length. Truncated!!!"
- )
- req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+ # Validate prompts length
+ error_msg = validate_input_length(
+ req,
+ self.max_req_input_len,
+ self.server_args.allow_auto_truncate,
+ )
+
+ if error_msg:
+ self.waiting_queue.append(req)
+ return
 
  req.sampling_params.max_new_tokens = min(
  (
@@ -681,13 +724,12 @@ class Scheduler:
  )
  req.tokenizer = self.tokenizer
 
- # Truncate prompts that are too long
- if len(req.origin_input_ids) >= self.max_req_input_len:
- logger.warning(
- "Request length is longer than the KV cache pool size or "
- "the max context length. Truncated!!!"
- )
- req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+ # Validate prompts length
+ validate_input_length(
+ req,
+ self.max_req_input_len,
+ self.server_args.allow_auto_truncate,
+ )
 
  self.waiting_queue.append(req)
 
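
Both request paths now delegate length checking to validate_input_length from sglang/srt/managers/utils.py (new in this release, +44 lines). A sketch that is consistent with the call sites above, where a non-None error message means the request was marked aborted rather than truncated; the message text and internal details are assumptions:

    from http import HTTPStatus

    from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req


    def validate_input_length(req: Req, max_req_input_len: int, allow_auto_truncate: bool):
        """Validate the input length of a request; return an error message or None."""
        if len(req.origin_input_ids) >= max_req_input_len:
            if allow_auto_truncate:
                # Previous behavior: silently clip the prompt to the maximum length.
                req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
                return None
            error_msg = (
                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
                f"the maximum allowed length ({max_req_input_len} tokens)."
            )
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )
            return error_msg
        return None
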
@@ -733,21 +775,40 @@ class Scheduler:
  self.num_generated_tokens = 0
  self.last_decode_stats_tic = time.time()
  num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
- logger.info(
- f"Decode batch. "
- f"#running-req: {num_running_reqs}, "
- f"#token: {num_used}, "
- f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
- f"gen throughput (token/s): {gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}"
- )
 
+ if self.spec_algorithm.is_none():
+ msg = (
+ f"Decode batch. "
+ f"#running-req: {num_running_reqs}, "
+ f"#token: {num_used}, "
+ f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+ f"gen throughput (token/s): {gen_throughput:.2f}, "
+ f"#queue-req: {len(self.waiting_queue)}"
+ )
+ spec_accept_length = 0
+ else:
+ spec_accept_length = (
+ self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+ )
+ self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+ msg = (
+ f"Decode batch. "
+ f"#running-req: {num_running_reqs}, "
+ f"#token: {num_used}, "
+ f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+ f"accept len: {spec_accept_length:.2f}, "
+ f"gen throughput (token/s): {gen_throughput:.2f}, "
+ f"#queue-req: {len(self.waiting_queue)}"
+ )
+
+ logger.info(msg)
  if self.enable_metrics:
  self.stats.num_running_reqs = num_running_reqs
  self.stats.num_used_tokens = num_used
  self.stats.token_usage = num_used / self.max_total_num_tokens
  self.stats.gen_throughput = gen_throughput
  self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.spec_accept_length = spec_accept_length
  self.metrics_collector.log_stats(self.stats)
 
  def check_memory(self):
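
With the two counters added earlier, the logged "accept len" is an average over the logging window: every speculative forward adds one sampled token per request plus the accepted draft tokens, and the denominator counts requests times forward passes. A small worked example with illustrative numbers:

    # A batch of 2 requests runs 3 speculative decode forwards; the draft worker
    # reports 3 accepted tokens per forward (summed over the batch).
    batch_size = 2
    num_forwards = 3
    accepted_per_forward = 3

    spec_num_total_accepted_tokens = (accepted_per_forward + batch_size) * num_forwards  # 15
    spec_num_total_forward_ct = batch_size * num_forwards                                # 6

    spec_accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
    print(f"{spec_accept_length:.2f}")  # 2.50 tokens per request per forward pass
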
@@ -790,16 +851,23 @@ class Scheduler:
  else:
  self.running_batch.merge_batch(self.last_batch)
 
- # Run prefill first if possible
  new_batch = self.get_new_batch_prefill()
  if new_batch is not None:
- return new_batch
+ # Run prefill first if possible
+ ret = new_batch
+ else:
+ # Run decode
+ if self.running_batch is None:
+ ret = None
+ else:
+ self.running_batch = self.update_running_batch(self.running_batch)
+ ret = self.running_batch
 
- # Run decode
- if self.running_batch is None:
- return None
- self.running_batch = self.update_running_batch(self.running_batch)
- return self.running_batch
+ # Handle DP attention
+ if self.server_args.enable_dp_attention:
+ ret = self.prepare_dp_attn_batch(ret)
+
+ return ret
 
  def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
  # Check if the grammar is ready in the grammar queue
@@ -823,9 +891,9 @@ class Scheduler:
  # Prefill policy
  adder = PrefillAdder(
  self.tree_cache,
+ self.token_to_kv_pool,
  self.running_batch,
  self.new_token_ratio,
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
  self.max_prefill_tokens,
  self.chunked_prefill_size,
  running_bs if self.is_mixed_chunk else 0,
@@ -886,7 +954,7 @@ class Scheduler:
  self.being_chunked_req.is_being_chunked += 1
 
  # Print stats
- if self.tp_rank == 0:
+ if self.attn_tp_rank == 0:
  self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
 
  # Create a new batch
@@ -898,6 +966,7 @@ class Scheduler:
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
+ self.server_args.enable_custom_logit_processor,
  )
  new_batch.prepare_for_extend()
 
@@ -968,12 +1037,14 @@ class Scheduler:
  batch.prepare_for_decode()
  return batch
 
- def run_batch(self, batch: ScheduleBatch):
+ def run_batch(
+ self, batch: ScheduleBatch
+ ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
  """Run a batch."""
  self.forward_ct += 1
 
  if self.is_generation:
- if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
+ if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()
  logits_output, next_token_ids = (
@@ -986,45 +1057,65 @@ class Scheduler:
  model_worker_batch,
  num_accepted_tokens,
  ) = self.draft_worker.forward_batch_speculative_generation(batch)
+ self.spec_num_total_accepted_tokens += (
+ num_accepted_tokens + batch.batch_size()
+ )
+ self.spec_num_total_forward_ct += batch.batch_size()
  self.num_generated_tokens += num_accepted_tokens
- elif batch.forward_mode.is_idle():
- model_worker_batch = batch.get_model_worker_batch()
- self.tp_worker.forward_batch_idle(model_worker_batch)
- return
  else:
- logits_output = None
- if self.skip_tokenizer_init:
- next_token_ids = torch.full(
- (batch.batch_size(),), self.tokenizer.eos_token_id
- )
- else:
- next_token_ids = torch.full((batch.batch_size(),), 0)
+ assert False, "batch.extend_num_tokens == 0, this is unexpected!"
  batch.output_ids = next_token_ids
- ret = logits_output, next_token_ids, model_worker_batch.bid
+
+ ret = GenerationBatchResult(
+ logits_output=logits_output,
+ next_token_ids=next_token_ids,
+ bid=model_worker_batch.bid,
+ )
  else: # embedding or reward model
  assert batch.extend_num_tokens != 0
  model_worker_batch = batch.get_model_worker_batch()
  embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
- ret = embeddings, model_worker_batch.bid
+ ret = EmbeddingBatchResult(
+ embeddings=embeddings, bid=model_worker_batch.bid
+ )
  return ret
 
- def process_batch_result(self, batch: ScheduleBatch, result):
+ def process_batch_result(
+ self,
+ batch: ScheduleBatch,
+ result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ ):
  if batch.forward_mode.is_decode():
  self.process_batch_result_decode(batch, result)
  if batch.is_empty():
  self.running_batch = None
  elif batch.forward_mode.is_extend():
  self.process_batch_result_prefill(batch, result)
+ elif batch.forward_mode.is_idle():
+ if self.enable_overlap:
+ self.tp_worker.resolve_batch_result(result.bid)
  elif batch.forward_mode.is_dummy_first():
  batch.next_batch_sampling_info.update_regex_vocab_mask()
  self.current_stream.synchronize()
  batch.next_batch_sampling_info.sampling_info_done.set()
 
- def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+ def process_batch_result_prefill(
+ self,
+ batch: ScheduleBatch,
+ result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ ):
  skip_stream_req = None
 
  if self.is_generation:
- logits_output, next_token_ids, bid = result
+ (
+ logits_output,
+ next_token_ids,
+ bid,
+ ) = (
+ result.logits_output,
+ result.next_token_ids,
+ result.bid,
+ )
 
  if self.enable_overlap:
  logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
@@ -1038,9 +1129,6 @@ class Scheduler:
  logits_output.input_token_logprobs = (
  logits_output.input_token_logprobs.tolist()
  )
- logits_output.normalized_prompt_logprobs = (
- logits_output.normalized_prompt_logprobs.tolist()
- )
 
  # Check finish conditions
  logprob_pt = 0
@@ -1085,7 +1173,7 @@ class Scheduler:
  batch.next_batch_sampling_info.sampling_info_done.set()
 
  else: # embedding or reward model
- embeddings, bid = result
+ embeddings, bid = result.embeddings, result.bid
  embeddings = embeddings.tolist()
 
  # Check finish conditions
@@ -1109,8 +1197,16 @@ class Scheduler:
 
  self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
- def process_batch_result_decode(self, batch: ScheduleBatch, result):
- logits_output, next_token_ids, bid = result
+ def process_batch_result_decode(
+ self,
+ batch: ScheduleBatch,
+ result: GenerationBatchResult,
+ ):
+ logits_output, next_token_ids, bid = (
+ result.logits_output,
+ result.next_token_ids,
+ result.bid,
+ )
  self.num_generated_tokens += len(batch.reqs)
 
  if self.enable_overlap:
@@ -1168,7 +1264,7 @@ class Scheduler:
 
  self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
  if (
- self.tp_rank == 0
+ self.attn_tp_rank == 0
  and self.forward_ct_decode % self.server_args.decode_log_interval == 0
  ):
  self.log_decode_stats()
@@ -1188,9 +1284,6 @@ class Scheduler:
  # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
  num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
 
- if req.normalized_prompt_logprob is None:
- req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
-
  if req.input_token_logprobs_val is None:
  input_token_logprobs_val = output.input_token_logprobs[
  pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
@@ -1288,15 +1381,12 @@ class Scheduler:
  input_top_logprobs_idx = []
  output_top_logprobs_val = []
  output_top_logprobs_idx = []
- normalized_prompt_logprob = []
  else:
  input_token_logprobs_val = input_token_logprobs_idx = (
  output_token_logprobs_val
  ) = output_token_logprobs_idx = input_top_logprobs_val = (
  input_top_logprobs_idx
- ) = output_top_logprobs_val = output_top_logprobs_idx = (
- normalized_prompt_logprob
- ) = None
+ ) = output_top_logprobs_val = output_top_logprobs_idx = None
 
  for req in reqs:
  if req is skip_req:
@@ -1343,7 +1433,6 @@ class Scheduler:
  input_top_logprobs_idx.append(req.input_top_logprobs_idx)
  output_top_logprobs_val.append(req.output_top_logprobs_val)
  output_top_logprobs_idx.append(req.output_top_logprobs_idx)
- normalized_prompt_logprob.append(req.normalized_prompt_logprob)
 
  # Send to detokenizer
  if rids:
@@ -1370,7 +1459,6 @@ class Scheduler:
  input_top_logprobs_idx,
  output_top_logprobs_val,
  output_top_logprobs_idx,
- normalized_prompt_logprob,
  )
  )
  else: # embedding or reward model
@@ -1412,12 +1500,7 @@ class Scheduler:
  # Check forward mode for cuda graph
  if not self.server_args.disable_cuda_graph:
  forward_mode_state = torch.tensor(
- (
- 1
- if local_batch.forward_mode.is_decode()
- or local_batch.forward_mode.is_idle()
- else 0
- ),
+ (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
  dtype=torch.int32,
  )
  torch.distributed.all_reduce(
@@ -1438,6 +1521,7 @@ class Scheduler:
  self.model_config,
  self.enable_overlap,
  self.spec_algorithm,
+ self.server_args.enable_custom_logit_processor,
  )
  idle_batch.prepare_for_idle()
  return idle_batch
@@ -1466,6 +1550,9 @@ class Scheduler:
  self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
+ def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+ self.flush_cache()
+
  def flush_cache(self):
  """Flush the memory pool and cache."""
  if len(self.waiting_queue) == 0 and (
@@ -1518,12 +1605,12 @@ class Scheduler:
  assert flash_cache_success, "Cache flush failed after updating weights"
  else:
  logger.error(message)
- return success, message
+ return UpdateWeightFromDiskReqOutput(success, message)
 
  def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
  """Initialize the online model parameter update group."""
  success, message = self.tp_worker.init_weights_update_group(recv_req)
- return success, message
+ return InitWeightsUpdateGroupReqOutput(success, message)
 
  def update_weights_from_distributed(
  self,
@@ -1536,7 +1623,7 @@ class Scheduler:
  assert flash_cache_success, "Cache flush failed after updating weights"
  else:
  logger.error(message)
- return success, message
+ return UpdateWeightsFromDistributedReqOutput(success, message)
 
  def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
  """Update the online model parameter from tensors."""
@@ -1547,11 +1634,11 @@ class Scheduler:
  assert flash_cache_success, "Cache flush failed after updating weights"
  else:
  logger.error(message)
- return success, message
+ return UpdateWeightsFromTensorReqOutput(success, message)
 
  def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
  parameter = self.tp_worker.get_weights_by_name(recv_req)
- return parameter
+ return GetWeightsByNameReqOutput(parameter)
 
  def release_memory_occupation(self):
  self.stashed_model_static_state = _export_static_state(
@@ -1559,6 +1646,7 @@ class Scheduler:
  )
  self.memory_saver_adapter.pause()
  self.flush_cache()
+ return ReleaseMemoryOccupationReqOutput()
 
  def resume_memory_occupation(self):
  self.memory_saver_adapter.resume()
@@ -1566,6 +1654,13 @@ class Scheduler:
  self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
  )
  del self.stashed_model_static_state
+ return ResumeMemoryOccupationReqOutput()
+
+ def profile(self, recv_req: ProfileReq):
+ if recv_req == ProfileReq.START_PROFILE:
+ self.start_profile()
+ else:
+ self.stop_profile()
 
  def start_profile(self) -> None:
  if self.profiler is None:
@@ -1581,20 +1676,20 @@ class Scheduler:
  )
  logger.info("Profiler is done")
 
- def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
+ def open_session(self, recv_req: OpenSessionReqInput):
  # handle error
  session_id = recv_req.session_id
  if session_id in self.sessions:
  logger.warning(f"session id {session_id} already exist, cannot open.")
- return session_id, False
+ return OpenSessionReqOutput(session_id, False)
  elif session_id is None:
  logger.warning(f"session id is None, cannot open.")
- return session_id, False
+ return OpenSessionReqOutput(session_id, False)
  else:
  self.sessions[session_id] = Session(
  recv_req.capacity_of_str_len, session_id
  )
- return session_id, True
+ return OpenSessionReqOutput(session_id, True)
 
  def close_session(self, recv_req: CloseSessionReqInput):
  # handle error
@@ -1651,7 +1746,11 @@ def run_scheduler_process(
  try:
  scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  pipe_writer.send(
- {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+ {
+ "status": "ready",
+ "max_total_num_tokens": scheduler.max_total_num_tokens,
+ "max_req_input_len": scheduler.max_req_input_len,
+ }
  )
  if scheduler.enable_overlap:
  scheduler.event_loop_overlap()