sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -22,8 +22,10 @@ import time
  import warnings
  from collections import deque
  from concurrent import futures
+ from dataclasses import dataclass
+ from http import HTTPStatus
  from types import SimpleNamespace
- from typing import Dict, List, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple, Union

  import psutil
  import setproctitle
@@ -32,7 +34,9 @@ import zmq

  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.io_struct import (
      AbortReq,
@@ -76,6 +80,7 @@ from sglang.srt.managers.schedule_policy import (
  from sglang.srt.managers.session_controller import Session
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
+ from sglang.srt.managers.utils import validate_input_length
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -93,7 +98,7 @@ from sglang.srt.utils import (
      set_random_seed,
      suppress_other_loggers,
  )
- from sglang.utils import get_exception_traceback
+ from sglang.utils import TypeBasedDispatcher, get_exception_traceback

  logger = logging.getLogger(__name__)

@@ -101,6 +106,19 @@ logger = logging.getLogger(__name__)
  test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


+ @dataclass
+ class GenerationBatchResult:
+     logits_output: LogitsProcessorOutput
+     next_token_ids: List[int]
+     bid: int
+
+
+ @dataclass
+ class EmbeddingBatchResult:
+     embeddings: torch.Tensor
+     bid: int
+
+
  class Scheduler:
      """A scheduler that manages a tensor parallel GPU worker."""

@@ -132,26 +150,36 @@ class Scheduler:
              else 1
          )

+         # Distributed rank info
+         self.dp_size = server_args.dp_size
+         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+             compute_dp_attention_world_info(
+                 server_args.enable_dp_attention,
+                 self.tp_rank,
+                 self.tp_size,
+                 self.dp_size,
+             )
+         )
+
          # Init inter-process communication
          context = zmq.Context(2)
-
-         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+         if self.attn_tp_rank == 0:
              self.recv_from_tokenizer = get_zmq_socket(
-                 context, zmq.PULL, port_args.scheduler_input_ipc_name
+                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
              )
              self.send_to_tokenizer = get_zmq_socket(
-                 context, zmq.PUSH, port_args.tokenizer_ipc_name
+                 context, zmq.PUSH, port_args.tokenizer_ipc_name, False
              )

              if server_args.skip_tokenizer_init:
                  # Directly send to the TokenizerManager
                  self.send_to_detokenizer = get_zmq_socket(
-                     context, zmq.PUSH, port_args.tokenizer_ipc_name
+                     context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                  )
              else:
                  # Send to the DetokenizerManager
                  self.send_to_detokenizer = get_zmq_socket(
-                     context, zmq.PUSH, port_args.detokenizer_ipc_name
+                     context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                  )
          else:
              self.recv_from_tokenizer = None
@@ -179,6 +207,7 @@ class Scheduler:
                  server_args.tokenizer_path,
                  tokenizer_mode=server_args.tokenizer_mode,
                  trust_remote_code=server_args.trust_remote_code,
+                 revision=server_args.revision,
              )
              self.tokenizer = self.processor.tokenizer
          else:
@@ -186,6 +215,7 @@ class Scheduler:
                  server_args.tokenizer_path,
                  tokenizer_mode=server_args.tokenizer_mode,
                  trust_remote_code=server_args.trust_remote_code,
+                 revision=server_args.revision,
              )

          # Check whether overlap can be enabled
@@ -214,7 +244,7 @@ class Scheduler:
              nccl_port=port_args.nccl_port,
          )

-         # Launch worker for speculative decoding if need
+         # Launch a worker for speculative decoding if needed
          if self.spec_algorithm.is_eagle():
              from sglang.srt.speculative.eagle_worker import EAGLEWorker

@@ -244,13 +274,14 @@ class Scheduler:
              _,
          ) = self.tp_worker.get_worker_info()
          self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+         self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
          self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
          global_server_args_dict.update(worker_global_server_args_dict)
          set_random_seed(self.random_seed)
-
          # Print debug info
          logger.info(
              f"max_total_num_tokens={self.max_total_num_tokens}, "
+             f"chunked_prefill_size={server_args.chunked_prefill_size}, "
              f"max_prefill_tokens={self.max_prefill_tokens}, "
              f"max_running_requests={self.max_running_requests}, "
              f"context_len={self.model_config.context_len}"
@@ -287,9 +318,13 @@ class Scheduler:
          self.forward_ct = 0
          self.forward_ct_decode = 0
          self.num_generated_tokens = 0
+         self.spec_num_total_accepted_tokens = 0
+         self.spec_num_total_forward_ct = 0
          self.last_decode_stats_tic = time.time()
          self.stream_interval = server_args.stream_interval
          self.current_stream = torch.get_device_module(self.device).current_stream()
+         if self.device == "cpu":
+             self.current_stream.synchronize = lambda: None  # No-op for CPU

          # Session info
          self.sessions: Dict[str, Session] = {}
@@ -306,28 +341,9 @@ class Scheduler:
          # Init the grammar backend for constrained generation
          self.grammar_queue: List[Req] = []
          if not server_args.skip_tokenizer_init:
-             if server_args.grammar_backend == "outlines":
-                 from sglang.srt.constrained.outlines_backend import (
-                     OutlinesGrammarBackend,
-                 )
-
-                 self.grammar_backend = OutlinesGrammarBackend(
-                     self.tokenizer,
-                     whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-                     allow_jump_forward=not server_args.disable_jump_forward,
-                 )
-             elif server_args.grammar_backend == "xgrammar":
-                 from sglang.srt.constrained.xgrammar_backend import (
-                     XGrammarGrammarBackend,
-                 )
-
-                 self.grammar_backend = XGrammarGrammarBackend(
-                     self.tokenizer, vocab_size=self.model_config.vocab_size
-                 )
-             else:
-                 raise ValueError(
-                     f"Invalid grammar backend: {server_args.grammar_backend}"
-                 )
+             self.grammar_backend = create_grammar_backend(
+                 server_args, self.tokenizer, self.model_config.vocab_size
+             )
          else:
              self.grammar_backend = None

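The hunk above folds the per-backend branching into a single create_grammar_backend call imported from sglang.srt.constrained.base_grammar_backend. A minimal sketch of such a factory, reconstructed from the branch logic removed above (the real helper lives in that module and may differ):

# Sketch only: reconstructed from the removed branches; not the sglang implementation.
def create_grammar_backend(server_args, tokenizer, vocab_size):
    if server_args.grammar_backend == "outlines":
        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

        return OutlinesGrammarBackend(
            tokenizer,
            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
            allow_jump_forward=not server_args.disable_jump_forward,
        )
    elif server_args.grammar_backend == "xgrammar":
        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

        return XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
    else:
        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

Centralizing the choice lets other entry points reuse the same selection logic instead of duplicating the if/elif chain in the Scheduler.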
@@ -393,22 +409,56 @@ class Scheduler:
              },
          )

+         # The largest prefill length of a single request
+         self._largest_prefill_len: int = 0
+         # The largest context length (prefill + generation) of a single request
+         self._largest_prefill_decode_len: int = 0
+
+         # Init request dispatcher
+         self._request_dispatcher = TypeBasedDispatcher(
+             [
+                 (TokenizedGenerateReqInput, self.handle_generate_request),
+                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                 (FlushCacheReq, self.flush_cache_wrapped),
+                 (AbortReq, self.abort_request),
+                 (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+                 (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                 (
+                     UpdateWeightsFromDistributedReqInput,
+                     self.update_weights_from_distributed,
+                 ),
+                 (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+                 (GetWeightsByNameReqInput, self.get_weights_by_name),
+                 (ProfileReq, self.profile),
+                 (OpenSessionReqInput, self.open_session),
+                 (CloseSessionReqInput, self.close_session),
+                 (
+                     ReleaseMemoryOccupationReqInput,
+                     lambda _: self.release_memory_occupation(),
+                 ),
+                 (
+                     ResumeMemoryOccupationReqInput,
+                     lambda _: self.resume_memory_occupation(),
+                 ),
+             ]
+         )
+
      def watchdog_thread(self):
          """A watch dog thread that will try to kill the server itself if one batch takes too long."""
          self.watchdog_last_forward_ct = 0
          self.watchdog_last_time = time.time()

          while True:
+             current = time.time()
              if self.cur_batch is not None:
                  if self.watchdog_last_forward_ct == self.forward_ct:
-                     if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                     if current > self.watchdog_last_time + self.watchdog_timeout:
                          logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                          break
                  else:
                      self.watchdog_last_forward_ct = self.forward_ct
-                     self.watchdog_last_time = time.time()
-             time.sleep(self.watchdog_timeout / 2)
-
+                     self.watchdog_last_time = current
+             time.sleep(self.watchdog_timeout // 2)
          # Wait sometimes so that the parent process can print the error.
          time.sleep(5)
          self.parent_process.send_signal(signal.SIGQUIT)
@@ -421,10 +471,6 @@ class Scheduler:
              self.process_input_requests(recv_reqs)

              batch = self.get_next_batch_to_run()
-
-             if self.server_args.enable_dp_attention:  # TODO: simplify this
-                 batch = self.prepare_dp_attn_batch(batch)
-
              self.cur_batch = batch

              if batch:
@@ -440,7 +486,7 @@ class Scheduler:
      @torch.no_grad()
      def event_loop_overlap(self):
          """A scheduler loop that overlaps the CPU processing and GPU computation."""
-         result_queue = deque()
+         self.result_queue = deque()

          while True:
              recv_reqs = self.recv_requests()
@@ -451,10 +497,10 @@ class Scheduler:

              if batch:
                  result = self.run_batch(batch)
-                 result_queue.append((batch.copy(), result))
+                 self.result_queue.append((batch.copy(), result))

                  if self.last_batch is None:
-                     # Create a dummy first batch to start the pipeline for overlap scheduler.
+                     # Create a dummy first batch to start the pipeline for overlap schedule.
                      # It is now used for triggering the sampling_info_done event.
                      tmp_batch = ScheduleBatch(
                          reqs=None,
@@ -465,7 +511,7 @@ class Scheduler:

              if self.last_batch:
                  # Process the results of the last batch
-                 tmp_batch, tmp_result = result_queue.popleft()
+                 tmp_batch, tmp_result = self.result_queue.popleft()
                  tmp_batch.next_batch_sampling_info = (
                      self.tp_worker.cur_sampling_info if batch else None
                  )
@@ -479,7 +525,7 @@ class Scheduler:

      def recv_requests(self) -> List[Req]:
          """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+         if self.attn_tp_rank == 0:
              recv_reqs = []

              while True:
@@ -491,63 +537,48 @@ class Scheduler:
          else:
              recv_reqs = None

-         if self.tp_size != 1 and not self.server_args.enable_dp_attention:
+         if self.server_args.enable_dp_attention:
+             if self.attn_tp_rank == 0:
+                 work_reqs = [
+                     req
+                     for req in recv_reqs
+                     if isinstance(
+                         req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                     )
+                 ]
+                 control_reqs = [
+                     req
+                     for req in recv_reqs
+                     if not isinstance(
+                         req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                     )
+                 ]
+             else:
+                 work_reqs = None
+                 control_reqs = None
+
+             if self.attn_tp_size != 1:
+                 attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+                 work_reqs = broadcast_pyobj(
+                     work_reqs,
+                     self.attn_tp_rank,
+                     self.attn_tp_cpu_group,
+                     src=attn_tp_rank_0,
+                 )
+             if self.tp_size != 1:
+                 control_reqs = broadcast_pyobj(
+                     control_reqs, self.tp_rank, self.tp_cpu_group
+                 )
+             recv_reqs = work_reqs + control_reqs
+         elif self.tp_size != 1:
              recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
          return recv_reqs

      def process_input_requests(self, recv_reqs: List):
          for recv_req in recv_reqs:
-             if isinstance(recv_req, TokenizedGenerateReqInput):
-                 self.handle_generate_request(recv_req)
-             elif isinstance(recv_req, TokenizedEmbeddingReqInput):
-                 self.handle_embedding_request(recv_req)
-             elif isinstance(recv_req, FlushCacheReq):
-                 self.flush_cache()
-             elif isinstance(recv_req, AbortReq):
-                 self.abort_request(recv_req)
-             elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
-                 success, message = self.update_weights_from_disk(recv_req)
-                 self.send_to_tokenizer.send_pyobj(
-                     UpdateWeightFromDiskReqOutput(success, message)
-                 )
-             elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
-                 success, message = self.init_weights_update_group(recv_req)
-                 self.send_to_tokenizer.send_pyobj(
-                     InitWeightsUpdateGroupReqOutput(success, message)
-                 )
-             elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
-                 success, message = self.update_weights_from_distributed(recv_req)
-                 self.send_to_tokenizer.send_pyobj(
-                     UpdateWeightsFromDistributedReqOutput(success, message)
-                 )
-             elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
-                 success, message = self.update_weights_from_tensor(recv_req)
-                 self.send_to_tokenizer.send_pyobj(
-                     UpdateWeightsFromTensorReqOutput(success, message)
-                 )
-             elif isinstance(recv_req, GetWeightsByNameReqInput):
-                 parameter = self.get_weights_by_name(recv_req)
-                 self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
-             elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
-                 self.release_memory_occupation()
-                 self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
-             elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
-                 self.resume_memory_occupation()
-                 self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
-             elif isinstance(recv_req, ProfileReq):
-                 if recv_req == ProfileReq.START_PROFILE:
-                     self.start_profile()
-                 else:
-                     self.stop_profile()
-             elif isinstance(recv_req, OpenSessionReqInput):
-                 session_id, success = self.open_session(recv_req)
-                 self.send_to_tokenizer.send_pyobj(
-                     OpenSessionReqOutput(session_id=session_id, success=success)
-                 )
-             elif isinstance(recv_req, CloseSessionReqInput):
-                 self.close_session(recv_req)
-             else:
-                 raise ValueError(f"Invalid request: {recv_req}")
+             output = self._request_dispatcher(recv_req)
+             if output is not None:
+                 self.send_to_tokenizer.send_pyobj(output)

      def handle_generate_request(
          self,
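process_input_requests above now routes every request through the TypeBasedDispatcher imported from sglang.utils, and handlers return their output object (if any) instead of calling send_pyobj themselves. A minimal sketch of that dispatch pattern, assuming a simplified dispatcher rather than the actual sglang.utils implementation:

# Sketch only: a simplified type-based dispatcher, not the sglang.utils implementation.
from typing import Any, Callable, List, Tuple, Type


class TypeBasedDispatcher:
    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        # Call the handler of the first (type, handler) pair whose type matches.
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"Invalid object: {obj}")

Handlers that used to push their own *ReqOutput objects (update_weights_from_disk, open_session, and so on) now simply return them, and the single send_pyobj call in process_input_requests forwards whatever a handler returns; control requests such as ProfileReq and FlushCacheReq get thin wrapper methods (profile, flush_cache_wrapped) so every table entry has the same one-argument signature.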
@@ -566,6 +597,19 @@ class Scheduler:
              fake_input_ids = [1] * seq_length
              recv_req.input_ids = fake_input_ids

+         # Handle custom logit processor passed to the request
+         custom_logit_processor = recv_req.custom_logit_processor
+         if (
+             not self.server_args.enable_custom_logit_processor
+             and custom_logit_processor is not None
+         ):
+             logger.warning(
+                 "The SGLang server is not configured to enable custom logit processor."
+                 "The custom logit processor passed in will be ignored."
+                 "Please set --enable-custom-logits-processor to enable this feature."
+             )
+             custom_logit_processor = None
+
          req = Req(
              recv_req.rid,
              recv_req.input_text,
@@ -576,6 +620,7 @@ class Scheduler:
              stream=recv_req.stream,
              lora_path=recv_req.lora_path,
              input_embeds=recv_req.input_embeds,
+             custom_logit_processor=custom_logit_processor,
              eos_token_ids=self.model_config.hf_eos_token_id,
          )
          req.tokenizer = self.tokenizer
@@ -597,7 +642,7 @@ class Scheduler:
              self.waiting_queue.append(req)
              return

-         # Handle image inputs
+         # Handle multimodal inputs
          if recv_req.image_inputs is not None:
              image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
              # Expand a single image token into multiple dummy tokens for receiving image embeddings
@@ -607,33 +652,36 @@ class Scheduler:
              req.extend_image_inputs(image_inputs)

              if len(req.origin_input_ids) >= self.max_req_input_len:
-                 logger.error(
+                 error_msg = (
                      "Multimodal prompt is too long after expanding multimodal tokens. "
-                     f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
+                     f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                  )
+                 logger.error(error_msg)
                  req.origin_input_ids = [0]
                  req.image_inputs = None
                  req.sampling_params.max_new_tokens = 0
                  req.finished_reason = FINISH_ABORT(
-                     "Multimodal prompt is too long. Check server logs for details."
+                     error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                  )
                  self.waiting_queue.append(req)
                  return

-         # Copy more attributes
-         req.logprob_start_len = recv_req.logprob_start_len
+         # Validate prompts length
+         error_msg = validate_input_length(
+             req,
+             self.max_req_input_len,
+             self.server_args.allow_auto_truncate,
+         )
+         if error_msg:
+             self.waiting_queue.append(req)
+             return

-         if req.logprob_start_len == -1:
+         # Copy more attributes
+         if recv_req.logprob_start_len == -1:
              # By default, only return the logprobs for output tokens
              req.logprob_start_len = len(req.origin_input_ids) - 1
-
-         # Truncate prompts that are too long
-         if len(req.origin_input_ids) > self.max_req_input_len:
-             logger.warning(
-                 "Request length is longer than the KV cache pool size or "
-                 "the max context length. Truncated!!!"
-             )
-             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+         else:
+             req.logprob_start_len = recv_req.logprob_start_len

          req.sampling_params.max_new_tokens = min(
              (
@@ -681,17 +729,27 @@ class Scheduler:
          )
          req.tokenizer = self.tokenizer

-         # Truncate prompts that are too long
-         if len(req.origin_input_ids) >= self.max_req_input_len:
-             logger.warning(
-                 "Request length is longer than the KV cache pool size or "
-                 "the max context length. Truncated!!!"
-             )
-             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+         # Validate prompts length
+         error_msg = validate_input_length(
+             req,
+             self.max_req_input_len,
+             self.server_args.allow_auto_truncate,
+         )
+         if error_msg:
+             self.waiting_queue.append(req)
+             return

+         # Copy more attributes
+         req.logprob_start_len = len(req.origin_input_ids) - 1
          self.waiting_queue.append(req)

-     def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
+     def log_prefill_stats(
+         self,
+         adder: PrefillAdder,
+         can_run_list: List[Req],
+         running_bs: ScheduleBatch,
+         has_being_chunked: bool,
+     ):
          self.tree_cache_metrics["total"] += (
              adder.log_input_tokens + adder.log_hit_tokens
          ) / 10**9
@@ -733,21 +791,40 @@ class Scheduler:
          self.num_generated_tokens = 0
          self.last_decode_stats_tic = time.time()
          num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
-         logger.info(
-             f"Decode batch. "
-             f"#running-req: {num_running_reqs}, "
-             f"#token: {num_used}, "
-             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-             f"gen throughput (token/s): {gen_throughput:.2f}, "
-             f"#queue-req: {len(self.waiting_queue)}"
-         )

+         if self.spec_algorithm.is_none():
+             msg = (
+                 f"Decode batch. "
+                 f"#running-req: {num_running_reqs}, "
+                 f"#token: {num_used}, "
+                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                 f"gen throughput (token/s): {gen_throughput:.2f}, "
+                 f"#queue-req: {len(self.waiting_queue)}"
+             )
+             spec_accept_length = 0
+         else:
+             spec_accept_length = (
+                 self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+             )
+             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+             msg = (
+                 f"Decode batch. "
+                 f"#running-req: {num_running_reqs}, "
+                 f"#token: {num_used}, "
+                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                 f"accept len: {spec_accept_length:.2f}, "
+                 f"gen throughput (token/s): {gen_throughput:.2f}, "
+                 f"#queue-req: {len(self.waiting_queue)}"
+             )
+
+         logger.info(msg)
          if self.enable_metrics:
              self.stats.num_running_reqs = num_running_reqs
              self.stats.num_used_tokens = num_used
              self.stats.token_usage = num_used / self.max_total_num_tokens
              self.stats.gen_throughput = gen_throughput
              self.stats.num_queue_reqs = len(self.waiting_queue)
+             self.stats.spec_accept_length = spec_accept_length
              self.metrics_collector.log_stats(self.stats)

      def check_memory(self):
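The decode-stats hunk above reports an average accept length for speculative decoding, derived from the two counters introduced in __init__ and incremented in run_batch further below. A condensed sketch of that bookkeeping, pulled out of the surrounding Scheduler code for clarity (an assumed simplification, not a separate sglang API):

# Sketch only: the accept-length bookkeeping, isolated from the Scheduler for clarity.
class SpecDecodeStats:
    def __init__(self) -> None:
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0

    def record_batch(self, num_accepted_tokens: int, batch_size: int) -> None:
        # Each request contributes its accepted draft tokens plus the token
        # emitted by the verification forward pass itself.
        self.spec_num_total_accepted_tokens += num_accepted_tokens + batch_size
        self.spec_num_total_forward_ct += batch_size

    def accept_length(self) -> float:
        # Average tokens produced per request per forward pass; the Scheduler
        # resets both counters after logging.
        return self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct

For example, one forward step over a batch of 4 requests that accepts 6 draft tokens in total logs an accept length of (6 + 4) / 4 = 2.5.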
@@ -790,16 +867,23 @@ class Scheduler:
          else:
              self.running_batch.merge_batch(self.last_batch)

-         # Run prefill first if possible
          new_batch = self.get_new_batch_prefill()
          if new_batch is not None:
-             return new_batch
+             # Run prefill first if possible
+             ret = new_batch
+         else:
+             # Run decode
+             if self.running_batch is None:
+                 ret = None
+             else:
+                 self.running_batch = self.update_running_batch(self.running_batch)
+                 ret = self.running_batch

-         # Run decode
-         if self.running_batch is None:
-             return None
-         self.running_batch = self.update_running_batch(self.running_batch)
-         return self.running_batch
+         # Handle DP attention
+         if self.server_args.enable_dp_attention:
+             ret = self.prepare_dp_attn_batch(ret)
+
+         return ret

      def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
          # Check if the grammar is ready in the grammar queue
@@ -823,9 +907,9 @@ class Scheduler:
          # Prefill policy
          adder = PrefillAdder(
              self.tree_cache,
+             self.token_to_kv_pool,
              self.running_batch,
              self.new_token_ratio,
-             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
              self.max_prefill_tokens,
              self.chunked_prefill_size,
              running_bs if self.is_mixed_chunk else 0,
@@ -886,7 +970,7 @@ class Scheduler:
              self.being_chunked_req.is_being_chunked += 1

          # Print stats
-         if self.tp_rank == 0:
+         if self.attn_tp_rank == 0:
              self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

          # Create a new batch
@@ -898,6 +982,7 @@ class Scheduler:
              self.model_config,
              self.enable_overlap,
              self.spec_algorithm,
+             self.server_args.enable_custom_logit_processor,
          )
          new_batch.prepare_for_extend()

@@ -954,7 +1039,7 @@ class Scheduler:
          )

          # Check for jump-forward
-         if not self.disable_jump_forward:
+         if not self.disable_jump_forward and batch.has_grammar:
              jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
              self.waiting_queue.extend(jump_forward_reqs)
              if batch.is_empty():
@@ -968,63 +1053,81 @@ class Scheduler:
          batch.prepare_for_decode()
          return batch

-     def run_batch(self, batch: ScheduleBatch):
+     def run_batch(
+         self, batch: ScheduleBatch
+     ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
          """Run a batch."""
          self.forward_ct += 1

          if self.is_generation:
-             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                 if self.spec_algorithm.is_none():
-                     model_worker_batch = batch.get_model_worker_batch()
-                     logits_output, next_token_ids = (
-                         self.tp_worker.forward_batch_generation(model_worker_batch)
-                     )
-                 else:
-                     (
-                         logits_output,
-                         next_token_ids,
-                         model_worker_batch,
-                         num_accepted_tokens,
-                     ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                     self.num_generated_tokens += num_accepted_tokens
-             elif batch.forward_mode.is_idle():
+             if self.spec_algorithm.is_none():
                  model_worker_batch = batch.get_model_worker_batch()
-                 self.tp_worker.forward_batch_idle(model_worker_batch)
-                 return
+                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                     model_worker_batch
+                 )
              else:
-                 logits_output = None
-                 if self.skip_tokenizer_init:
-                     next_token_ids = torch.full(
-                         (batch.batch_size(),), self.tokenizer.eos_token_id
-                     )
-                 else:
-                     next_token_ids = torch.full((batch.batch_size(),), 0)
+                 (
+                     logits_output,
+                     next_token_ids,
+                     model_worker_batch,
+                     num_accepted_tokens,
+                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                 self.spec_num_total_accepted_tokens += (
+                     num_accepted_tokens + batch.batch_size()
+                 )
+                 self.spec_num_total_forward_ct += batch.batch_size()
+                 self.num_generated_tokens += num_accepted_tokens
              batch.output_ids = next_token_ids
-             ret = logits_output, next_token_ids, model_worker_batch.bid
+
+             ret = GenerationBatchResult(
+                 logits_output=logits_output,
+                 next_token_ids=next_token_ids,
+                 bid=model_worker_batch.bid,
+             )
          else:  # embedding or reward model
-             assert batch.extend_num_tokens != 0
              model_worker_batch = batch.get_model_worker_batch()
              embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-             ret = embeddings, model_worker_batch.bid
+             ret = EmbeddingBatchResult(
+                 embeddings=embeddings, bid=model_worker_batch.bid
+             )
          return ret

-     def process_batch_result(self, batch: ScheduleBatch, result):
+     def process_batch_result(
+         self,
+         batch: ScheduleBatch,
+         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+     ):
          if batch.forward_mode.is_decode():
              self.process_batch_result_decode(batch, result)
              if batch.is_empty():
                  self.running_batch = None
          elif batch.forward_mode.is_extend():
              self.process_batch_result_prefill(batch, result)
+         elif batch.forward_mode.is_idle():
+             if self.enable_overlap:
+                 self.tp_worker.resolve_batch_result(result.bid)
          elif batch.forward_mode.is_dummy_first():
              batch.next_batch_sampling_info.update_regex_vocab_mask()
              self.current_stream.synchronize()
              batch.next_batch_sampling_info.sampling_info_done.set()

-     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+     def process_batch_result_prefill(
+         self,
+         batch: ScheduleBatch,
+         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+     ):
          skip_stream_req = None

          if self.is_generation:
-             logits_output, next_token_ids, bid = result
+             (
+                 logits_output,
+                 next_token_ids,
+                 bid,
+             ) = (
+                 result.logits_output,
+                 result.next_token_ids,
+                 result.bid,
+             )

              if self.enable_overlap:
                  logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
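run_batch and the process_batch_result_* methods above swap positional result tuples for the GenerationBatchResult and EmbeddingBatchResult dataclasses defined near the top of the file. A self-contained sketch of the pattern with simplified field types (the real classes carry a LogitsProcessorOutput and a torch.Tensor):

# Sketch only: simplified result objects illustrating the refactor above.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class GenerationBatchResult:
    logits_output: object  # LogitsProcessorOutput in the real code
    next_token_ids: List[int]
    bid: int


@dataclass
class EmbeddingBatchResult:
    embeddings: object  # torch.Tensor in the real code
    bid: int


def process(result: Union[GenerationBatchResult, EmbeddingBatchResult]) -> int:
    # Named fields replace positional unpacking, so a field added for
    # speculative decoding or overlap scheduling cannot shift the others.
    if isinstance(result, GenerationBatchResult):
        assert len(result.next_token_ids) >= 0
    return result.bid


print(process(EmbeddingBatchResult(embeddings=[[0.1, 0.2]], bid=7)))  # -> 7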
@@ -1038,9 +1141,6 @@ class Scheduler:
                  logits_output.input_token_logprobs = (
                      logits_output.input_token_logprobs.tolist()
                  )
-                 logits_output.normalized_prompt_logprobs = (
-                     logits_output.normalized_prompt_logprobs.tolist()
-                 )

              # Check finish conditions
              logprob_pt = 0
@@ -1085,7 +1185,7 @@ class Scheduler:
                  batch.next_batch_sampling_info.sampling_info_done.set()

          else:  # embedding or reward model
-             embeddings, bid = result
+             embeddings, bid = result.embeddings, result.bid
              embeddings = embeddings.tolist()

              # Check finish conditions
@@ -1109,8 +1209,16 @@ class Scheduler:

          self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

-     def process_batch_result_decode(self, batch: ScheduleBatch, result):
-         logits_output, next_token_ids, bid = result
+     def process_batch_result_decode(
+         self,
+         batch: ScheduleBatch,
+         result: GenerationBatchResult,
+     ):
+         logits_output, next_token_ids, bid = (
+             result.logits_output,
+             result.next_token_ids,
+             result.bid,
+         )
          self.num_generated_tokens += len(batch.reqs)

          if self.enable_overlap:
@@ -1168,7 +1276,7 @@ class Scheduler:

          self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
          if (
-             self.tp_rank == 0
+             self.attn_tp_rank == 0
              and self.forward_ct_decode % self.server_args.decode_log_interval == 0
          ):
              self.log_decode_stats()
@@ -1188,9 +1296,6 @@ class Scheduler:
          # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
          num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len

-         if req.normalized_prompt_logprob is None:
-             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
-
          if req.input_token_logprobs_val is None:
              input_token_logprobs_val = output.input_token_logprobs[
                  pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
@@ -1278,6 +1383,7 @@ class Scheduler:
              prompt_tokens = []
              completion_tokens = []
              cached_tokens = []
+             spec_verify_ct = []

              if return_logprob:
                  input_token_logprobs_val = []
@@ -1288,15 +1394,12 @@ class Scheduler:
                  input_top_logprobs_idx = []
                  output_top_logprobs_val = []
                  output_top_logprobs_idx = []
-                 normalized_prompt_logprob = []
              else:
                  input_token_logprobs_val = input_token_logprobs_idx = (
                      output_token_logprobs_val
                  ) = output_token_logprobs_idx = input_top_logprobs_val = (
                      input_top_logprobs_idx
-                 ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                     normalized_prompt_logprob
-                 ) = None
+                 ) = output_top_logprobs_val = output_top_logprobs_idx = None

              for req in reqs:
                  if req is skip_req:
@@ -1334,6 +1437,9 @@ class Scheduler:
                  completion_tokens.append(len(req.output_ids))
                  cached_tokens.append(req.cached_tokens)

+                 if not self.spec_algorithm.is_none():
+                     spec_verify_ct.append(req.spec_verify_ct)
+
                  if return_logprob:
                      input_token_logprobs_val.append(req.input_token_logprobs_val)
                      input_token_logprobs_idx.append(req.input_token_logprobs_idx)
@@ -1343,7 +1449,6 @@ class Scheduler:
                      input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                      output_top_logprobs_val.append(req.output_top_logprobs_val)
                      output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                     normalized_prompt_logprob.append(req.normalized_prompt_logprob)

              # Send to detokenizer
              if rids:
@@ -1362,6 +1467,7 @@ class Scheduler:
                      prompt_tokens,
                      completion_tokens,
                      cached_tokens,
+                     spec_verify_ct,
                      input_token_logprobs_val,
                      input_token_logprobs_idx,
                      output_token_logprobs_val,
@@ -1370,7 +1476,6 @@ class Scheduler:
                      input_top_logprobs_idx,
                      output_top_logprobs_val,
                      output_top_logprobs_idx,
-                     normalized_prompt_logprob,
                  )
              )
          else:  # embedding or reward model
@@ -1412,12 +1517,7 @@ class Scheduler:
          # Check forward mode for cuda graph
          if not self.server_args.disable_cuda_graph:
              forward_mode_state = torch.tensor(
-                 (
-                     1
-                     if local_batch.forward_mode.is_decode()
-                     or local_batch.forward_mode.is_idle()
-                     else 0
-                 ),
+                 (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                  dtype=torch.int32,
              )
              torch.distributed.all_reduce(
@@ -1438,6 +1538,7 @@ class Scheduler:
              self.model_config,
              self.enable_overlap,
              self.spec_algorithm,
+             self.server_args.enable_custom_logit_processor,
          )
          idle_batch.prepare_for_idle()
          return idle_batch
@@ -1466,6 +1567,9 @@ class Scheduler:
          self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
          self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+     def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+         self.flush_cache()
+
      def flush_cache(self):
          """Flush the memory pool and cache."""
          if len(self.waiting_queue) == 0 and (
@@ -1477,6 +1581,15 @@ class Scheduler:
                  self.grammar_backend.reset()
              self.req_to_token_pool.clear()
              self.token_to_kv_pool.clear()
+
+             if not self.spec_algorithm.is_none():
+                 self.draft_worker.model_runner.req_to_token_pool.clear()
+                 self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+             self.num_generated_tokens = 0
+             self.forward_ct_decode = 0
+             self.spec_num_total_accepted_tokens = 0
+             self.spec_num_total_forward_ct = 0
              torch.cuda.empty_cache()
              logger.info("Cache flushed successfully!")
              if_success = True
@@ -1518,12 +1631,12 @@ class Scheduler:
              assert flash_cache_success, "Cache flush failed after updating weights"
          else:
              logger.error(message)
-         return success, message
+         return UpdateWeightFromDiskReqOutput(success, message)

      def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
          """Initialize the online model parameter update group."""
          success, message = self.tp_worker.init_weights_update_group(recv_req)
-         return success, message
+         return InitWeightsUpdateGroupReqOutput(success, message)

      def update_weights_from_distributed(
          self,
@@ -1536,7 +1649,7 @@ class Scheduler:
              assert flash_cache_success, "Cache flush failed after updating weights"
          else:
              logger.error(message)
-         return success, message
+         return UpdateWeightsFromDistributedReqOutput(success, message)

      def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
          """Update the online model parameter from tensors."""
@@ -1547,11 +1660,11 @@ class Scheduler:
              assert flash_cache_success, "Cache flush failed after updating weights"
          else:
              logger.error(message)
-         return success, message
+         return UpdateWeightsFromTensorReqOutput(success, message)

      def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
          parameter = self.tp_worker.get_weights_by_name(recv_req)
-         return parameter
+         return GetWeightsByNameReqOutput(parameter)

      def release_memory_occupation(self):
          self.stashed_model_static_state = _export_static_state(
@@ -1559,6 +1672,7 @@ class Scheduler:
          )
          self.memory_saver_adapter.pause()
          self.flush_cache()
+         return ReleaseMemoryOccupationReqOutput()

      def resume_memory_occupation(self):
          self.memory_saver_adapter.resume()
@@ -1566,6 +1680,13 @@ class Scheduler:
              self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
          )
          del self.stashed_model_static_state
+         return ResumeMemoryOccupationReqOutput()
+
+     def profile(self, recv_req: ProfileReq):
+         if recv_req == ProfileReq.START_PROFILE:
+             self.start_profile()
+         else:
+             self.stop_profile()

      def start_profile(self) -> None:
          if self.profiler is None:
@@ -1581,20 +1702,20 @@ class Scheduler:
          )
          logger.info("Profiler is done")

-     def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
+     def open_session(self, recv_req: OpenSessionReqInput):
          # handle error
          session_id = recv_req.session_id
          if session_id in self.sessions:
              logger.warning(f"session id {session_id} already exist, cannot open.")
-             return session_id, False
+             return OpenSessionReqOutput(session_id, False)
          elif session_id is None:
              logger.warning(f"session id is None, cannot open.")
-             return session_id, False
+             return OpenSessionReqOutput(session_id, False)
          else:
              self.sessions[session_id] = Session(
                  recv_req.capacity_of_str_len, session_id
              )
-             return session_id, True
+             return OpenSessionReqOutput(session_id, True)

      def close_session(self, recv_req: CloseSessionReqInput):
          # handle error
@@ -1651,7 +1772,11 @@ def run_scheduler_process(
      try:
          scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
          pipe_writer.send(
-             {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+             {
+                 "status": "ready",
+                 "max_total_num_tokens": scheduler.max_total_num_tokens,
+                 "max_req_input_len": scheduler.max_req_input_len,
+             }
          )
          if scheduler.enable_overlap:
              scheduler.event_loop_overlap()