sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/session_controller.py
@@ -131,6 +131,7 @@ class Session:
  sampling_params=req.sampling_params,
  lora_path=req.lora_path,
  session_id=self.session_id,
+ custom_logit_processor=req.custom_logit_processor,
  )
  if last_req is not None:
  new_req.image_inputs = last_req.image_inputs
sglang/srt/managers/tokenizer_manager.py
@@ -21,9 +21,11 @@ import os
  import pickle
  import signal
  import sys
+ import threading
  import time
  import uuid
  from datetime import datetime
+ from http import HTTPStatus
  from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

  import fastapi
@@ -78,6 +80,7 @@ from sglang.srt.utils import (
  get_zmq_socket,
  kill_process_tree,
  )
+ from sglang.utils import TypeBasedDispatcher, get_exception_traceback

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -110,17 +113,19 @@ class TokenizerManager:
  port_args: PortArgs,
  ):
  # Parse args
+
  self.server_args = server_args
  self.enable_metrics = server_args.enable_metrics
  self.log_requests = server_args.log_requests
+ self.log_requests_level = 0

  # Init inter-process communication
  context = zmq.asyncio.Context(2)
  self.recv_from_detokenizer = get_zmq_socket(
- context, zmq.PULL, port_args.tokenizer_ipc_name
+ context, zmq.PULL, port_args.tokenizer_ipc_name, True
  )
  self.send_to_scheduler = get_zmq_socket(
- context, zmq.PUSH, port_args.scheduler_input_ipc_name
+ context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
  )

  # Read model args
@@ -153,6 +158,7 @@ class TokenizerManager:
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  )
  self.tokenizer = self.processor.tokenizer
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -166,10 +172,11 @@ class TokenizerManager:
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  )

  # Store states
- self.to_create_loop = True
+ self.no_create_loop = False
  self.rid_to_state: Dict[str, ReqState] = {}
  self.dump_requests_folder = ""  # By default do not dump
  self.dump_requests_threshold = 1000
@@ -205,6 +212,8 @@ class TokenizerManager:
  self.resume_memory_occupation_communicator = _Communicator(
  self.send_to_scheduler, server_args.dp_size
  )
+ # Set after scheduler is initialized
+ self.max_req_input_len = None

  # Metrics
  if self.enable_metrics:
@@ -215,6 +224,44 @@ class TokenizerManager:
  },
  )

+ self._result_dispatcher = TypeBasedDispatcher(
+ [
+ (
+ (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
+ self._handle_batch_output,
+ ),
+ (OpenSessionReqOutput, self._handle_open_session_req_output),
+ (
+ UpdateWeightFromDiskReqOutput,
+ self._handle_update_weights_from_disk_req_output,
+ ),
+ (
+ InitWeightsUpdateGroupReqOutput,
+ self.init_weights_update_group_communicator.handle_recv,
+ ),
+ (
+ UpdateWeightsFromDistributedReqOutput,
+ self.update_weights_from_distributed_communicator.handle_recv,
+ ),
+ (
+ UpdateWeightsFromTensorReqOutput,
+ self.update_weights_from_tensor_communicator.handle_recv,
+ ),
+ (
+ GetWeightsByNameReqOutput,
+ self.get_weights_by_name_communicator.handle_recv,
+ ),
+ (
+ ReleaseMemoryOccupationReqOutput,
+ self.release_memory_occupation_communicator.handle_recv,
+ ),
+ (
+ ResumeMemoryOccupationReqOutput,
+ self.resume_memory_occupation_communicator.handle_recv,
+ ),
+ ]
+ )
+
  async def generate_request(
  self,
  obj: Union[GenerateReqInput, EmbeddingReqInput],
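Note: the dispatcher table above replaces the long isinstance chain that previously lived in handle_loop (see the @@ -661,106 hunk below). TypeBasedDispatcher itself is imported from sglang/utils.py and its implementation is not part of this excerpt; a minimal sketch of the idea, assuming it simply routes an object to the first handler whose registered type matches, could look like:

from typing import Any, Callable, List, Tuple, Type, Union

class TypeBasedDispatcher:
    # Minimal sketch for illustration only, not the sglang.utils implementation.
    def __init__(self, mapping: List[Tuple[Union[Type, Tuple[Type, ...]], Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any):
        # Dispatch to the first handler whose registered type(s) match the object.
        for types, handler in self._mapping:
            if isinstance(obj, types):
                return handler(obj)
        raise ValueError(f"Invalid object: {obj!r}")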
@@ -233,7 +280,10 @@ class TokenizerManager:
  obj.normalize_batch_and_arguments()

  if self.log_requests:
- logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+ max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+ logger.info(
+ f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
+ )

  async with self.model_update_lock.reader_lock:
  is_single = obj.is_single
@@ -265,15 +315,21 @@ class TokenizerManager:
  )
  input_embeds = obj.input_embeds
  input_ids = obj.input_ids
- elif obj.input_ids is None:
- input_ids = self.tokenizer.encode(input_text)
- else:
+ elif obj.input_ids is not None:
  input_ids = obj.input_ids
+ else:
+ if self.tokenizer is None:
+ raise ValueError(
+ "The engine initialized with skip_tokenizer_init=True cannot "
+ "accept text prompts. Please provide input_ids or re-initialize "
+ "the engine with skip_tokenizer_init=False."
+ )
+ input_ids = self.tokenizer.encode(input_text)

  if self.is_generation:
  # TODO: also support getting embeddings for multimodal models
  image_inputs: Dict = await self.image_processor.process_images_async(
- obj.image_data, input_text or input_ids, obj
+ obj.image_data, input_text or input_ids, obj, self.max_req_input_len
  )
  if image_inputs and "input_ids" in image_inputs:
  input_ids = image_inputs["input_ids"]
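Note: the new else branch makes the skip_tokenizer_init failure mode explicit: when the engine runs without a tokenizer, text prompts are rejected and callers must supply pre-tokenized ids. A hedged usage sketch (the exact Engine constructor and generate arguments are assumed here, and the model path and token ids are made-up example values):

import sglang as sgl

# Assumed API: Engine accepts skip_tokenizer_init and generate accepts input_ids.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct", skip_tokenizer_init=True)

# Passing a text prompt here would now raise the ValueError shown above;
# pre-tokenized input_ids are accepted instead.
out = engine.generate(input_ids=[[1, 15043, 3186]], sampling_params={"max_new_tokens": 8})
print(out)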
@@ -284,12 +340,28 @@ class TokenizerManager:
  SessionParams(**obj.session_params) if obj.session_params else None
  )

- if obj.input_ids is not None and len(input_ids) >= self.context_len:
+ input_token_num = len(input_ids) if input_ids is not None else 0
+ if input_token_num >= self.context_len:
  raise ValueError(
- f"The input ({len(input_ids)} tokens) is longer than the "
+ f"The input ({input_token_num} tokens) is longer than the "
  f"model's context length ({self.context_len} tokens)."
  )

+ if (
+ obj.sampling_params.get("max_new_tokens") is not None
+ and obj.sampling_params.get("max_new_tokens") + input_token_num
+ >= self.context_len
+ ):
+ raise ValueError(
+ f"Requested token count exceeds the model's maximum context length "
+ f"of {self.context_len} tokens. You requested a total of "
+ f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+ f"tokens: {input_token_num} tokens from the input messages and "
+ f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+ f"completion. Please reduce the number of tokens in the input "
+ f"messages or the completion to fit within the limit."
+ )
+
  # Parse sampling parameters
  sampling_params = SamplingParams(**obj.sampling_params)
  sampling_params.normalize(self.tokenizer)
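Note: this adds a second validation on top of the plain context-length check: the prompt length plus the requested max_new_tokens must also fit in the context window. A worked example of the arithmetic (values are made up):

# Illustration of the new check, not code from the diff.
context_len = 4096
input_token_num = 4000
max_new_tokens = 200

assert input_token_num < context_len             # passes the first check
total_requested = max_new_tokens + input_token_num
print(total_requested >= context_len)            # True -> the request is rejected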
@@ -310,6 +382,7 @@ class TokenizerManager:
  lora_path=obj.lora_path,
  input_embeds=input_embeds,
  session_params=session_params,
+ custom_logit_processor=obj.custom_logit_processor,
  )
  elif isinstance(obj, EmbeddingReqInput):
  tokenized_obj = TokenizedEmbeddingReqInput(
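Note: custom_logit_processor is carried from the request into TokenizedGenerateReqInput here; the supporting base class is added separately in sglang/srt/sampling/custom_logit_processor.py (+38 lines, not shown in this excerpt). Purely as an illustration of the kind of hook this enables, and not the actual sglang interface, such a processor is typically a callable that rewrites next-token logits:

import torch

def ban_token_processor(logits: torch.Tensor, banned_id: int = 0) -> torch.Tensor:
    # Hypothetical example only: mask one vocabulary id so it can never be sampled.
    logits = logits.clone()
    logits[..., banned_id] = float("-inf")
    return logits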
@@ -354,9 +427,20 @@ class TokenizerManager:
  state.out_list = []
  if state.finished:
  if self.log_requests:
- msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+ max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+ msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
  logger.info(msg)
  del self.rid_to_state[obj.rid]
+
+ # Check if this was an abort/error created by scheduler
+ if isinstance(out["meta_info"].get("finish_reason"), dict):
+ finish_reason = out["meta_info"]["finish_reason"]
+ if (
+ finish_reason.get("type") == "abort"
+ and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
+ ):
+ raise ValueError(finish_reason["message"])
+
  yield out
  break

@@ -601,12 +685,13 @@ class TokenizerManager:
  async def close_session(
  self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
  ):
- assert not self.to_create_loop, "close session should not be the first request"
  await self.send_to_scheduler.send_pyobj(obj)

  def configure_logging(self, obj: ConfigureLoggingReq):
  if obj.log_requests is not None:
  self.log_requests = obj.log_requests
+ if obj.log_requests_level is not None:
+ self.log_requests_level = obj.log_requests_level
  if obj.dump_requests_folder is not None:
  self.dump_requests_folder = obj.dump_requests_folder
  if obj.dump_requests_threshold is not None:
@@ -628,16 +713,29 @@ class TokenizerManager:
  return background_tasks

  def auto_create_handle_loop(self):
- if not self.to_create_loop:
+ if self.no_create_loop:
  return

- self.to_create_loop = False
+ self.no_create_loop = True
  loop = asyncio.get_event_loop()
- self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
+ self.asyncio_tasks.add(
+ loop.create_task(print_exception_wrapper(self.handle_loop))
+ )

- signal_handler = SignalHandler(self)
- loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
- self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
+ # We cannot add signal handler when the tokenizer manager is not in
+ # the main thread due to the CPython limitation.
+ if threading.current_thread() is threading.main_thread():
+ signal_handler = SignalHandler(self)
+ loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+ else:
+ logger.warning(
+ "Signal handler is not added because the tokenizer manager is "
+ "not in the main thread. This disables graceful shutdown of the "
+ "tokenizer manager when SIGTERM is received."
+ )
+ self.asyncio_tasks.add(
+ loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+ )

  async def sigterm_watchdog(self):
  while not self.gracefully_exit:
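Note: the threading.main_thread() guard works around a CPython restriction: signal handlers (and therefore loop.add_signal_handler) can only be installed from the main thread, which may not be the case when the TokenizerManager is created outside of it. A standalone sketch of the same guard:

import asyncio
import signal
import threading

def install_sigterm_handler(loop: asyncio.AbstractEventLoop, on_sigterm) -> bool:
    # Returns True if the handler was installed, False when not on the main thread.
    if threading.current_thread() is threading.main_thread():
        loop.add_signal_handler(signal.SIGTERM, on_sigterm)
        return True
    return False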
@@ -661,106 +759,68 @@ class TokenizerManager:
  """The event loop that handles requests"""

  while True:
- recv_obj: Union[
- BatchStrOut,
- BatchEmbeddingOut,
- BatchTokenIDOut,
- UpdateWeightFromDiskReqOutput,
- UpdateWeightsFromDistributedReqOutput,
- GetWeightsByNameReqOutput,
- InitWeightsUpdateGroupReqOutput,
- ReleaseMemoryOccupationReqOutput,
- ResumeMemoryOccupationReqOutput,
- ] = await self.recv_from_detokenizer.recv_pyobj()
-
- if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
- for i, rid in enumerate(recv_obj.rids):
- state = self.rid_to_state.get(rid, None)
- if state is None:
- continue
-
- meta_info = {
- "id": rid,
- "finish_reason": recv_obj.finished_reasons[i],
- "prompt_tokens": recv_obj.prompt_tokens[i],
- }
+ recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+ self._result_dispatcher(recv_obj)
+
+ def _handle_batch_output(
+ self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+ ):
+ for i, rid in enumerate(recv_obj.rids):
+ state = self.rid_to_state.get(rid, None)
+ if state is None:
+ continue

- if getattr(state.obj, "return_logprob", False):
- self.convert_logprob_style(
- meta_info,
- state.obj.top_logprobs_num,
- state.obj.return_text_in_logprobs,
- recv_obj,
- i,
- )
-
- if not isinstance(recv_obj, BatchEmbeddingOut):
- meta_info.update(
- {
- "completion_tokens": recv_obj.completion_tokens[i],
- "cached_tokens": recv_obj.cached_tokens[i],
- }
- )
-
- if isinstance(recv_obj, BatchStrOut):
- out_dict = {
- "text": recv_obj.output_strs[i],
- "meta_info": meta_info,
- }
- elif isinstance(recv_obj, BatchTokenIDOut):
- out_dict = {
- "token_ids": recv_obj.output_ids[i],
- "meta_info": meta_info,
- }
- else:
- assert isinstance(recv_obj, BatchEmbeddingOut)
- out_dict = {
- "embedding": recv_obj.embeddings[i],
- "meta_info": meta_info,
- }
- state.out_list.append(out_dict)
- state.finished = recv_obj.finished_reasons[i] is not None
- state.event.set()
-
- if self.enable_metrics:
- self.collect_metrics(state, recv_obj, i)
- if self.dump_requests_folder and state.finished:
- self.dump_requests(state, out_dict)
- elif isinstance(recv_obj, OpenSessionReqOutput):
- self.session_futures[recv_obj.session_id].set_result(
- recv_obj.session_id if recv_obj.success else None
+ meta_info = {
+ "id": rid,
+ "finish_reason": recv_obj.finished_reasons[i],
+ "prompt_tokens": recv_obj.prompt_tokens[i],
+ }
+
+ if getattr(state.obj, "return_logprob", False):
+ self.convert_logprob_style(
+ meta_info,
+ state.obj.top_logprobs_num,
+ state.obj.return_text_in_logprobs,
+ recv_obj,
+ i,
  )
- elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
- if self.server_args.dp_size == 1:
- self.model_update_result.set_result(recv_obj)
- else: # self.server_args.dp_size > 1
- self.model_update_tmp.append(recv_obj)
- # set future if the all results are recevied
- if len(self.model_update_tmp) == self.server_args.dp_size:
- self.model_update_result.set_result(self.model_update_tmp)
- elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for init parameter update group"
- self.init_weights_update_group_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for update weights from distributed"
- self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for update weights from distributed"
- self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, GetWeightsByNameReqOutput):
- self.get_weights_by_name_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
- self.release_memory_occupation_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
- self.resume_memory_occupation_communicator.handle_recv(recv_obj)
+
+ if self.server_args.speculative_algorithm:
+ meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
+ if not isinstance(recv_obj, BatchEmbeddingOut):
+ meta_info.update(
+ {
+ "completion_tokens": recv_obj.completion_tokens[i],
+ "cached_tokens": recv_obj.cached_tokens[i],
+ }
+ )
+
+ if isinstance(recv_obj, BatchStrOut):
+ out_dict = {
+ "text": recv_obj.output_strs[i],
+ "meta_info": meta_info,
+ }
+ elif isinstance(recv_obj, BatchTokenIDOut):
+ out_dict = {
+ "token_ids": recv_obj.output_ids[i],
+ "meta_info": meta_info,
+ }
  else:
- raise ValueError(f"Invalid object: {recv_obj=}")
+ assert isinstance(recv_obj, BatchEmbeddingOut)
+ out_dict = {
+ "embedding": recv_obj.embeddings[i],
+ "meta_info": meta_info,
+ }
+
+ state.out_list.append(out_dict)
+ state.finished = recv_obj.finished_reasons[i] is not None
+ state.event.set()
+
+ if self.enable_metrics and state.obj.log_metrics:
+ self.collect_metrics(state, recv_obj, i)
+ if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+ self.dump_requests(state, out_dict)

  def convert_logprob_style(
  self,
@@ -780,9 +840,6 @@ class TokenizerManager:
  recv_obj.output_token_logprobs_idx[recv_obj_index],
  return_text_in_logprobs,
  )
- meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
- recv_obj_index
- ]

  if top_logprobs_num > 0:
  meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
@@ -874,19 +931,51 @@ class TokenizerManager:
  )

  if len(self.dump_request_list) >= self.dump_requests_threshold:
+ filename = os.path.join(
+ self.dump_requests_folder,
+ datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+ )
+ logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
  to_dump = self.dump_request_list
  self.dump_request_list = []

  def background_task():
  os.makedirs(self.dump_requests_folder, exist_ok=True)
- current_time = datetime.now()
- filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
- with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+ with open(filename, "wb") as f:
  pickle.dump(to_dump, f)

  # Schedule the task to run in the background without awaiting it
  asyncio.create_task(asyncio.to_thread(background_task))

+ def _handle_open_session_req_output(self, recv_obj):
+ self.session_futures[recv_obj.session_id].set_result(
+ recv_obj.session_id if recv_obj.success else None
+ )
+
+ def _handle_update_weights_from_disk_req_output(self, recv_obj):
+ if self.server_args.dp_size == 1:
+ self.model_update_result.set_result(recv_obj)
+ else: # self.server_args.dp_size > 1
+ self.model_update_tmp.append(recv_obj)
+ # set future if the all results are recevied
+ if len(self.model_update_tmp) == self.server_args.dp_size:
+ self.model_update_result.set_result(self.model_update_tmp)
+
+
+ async def print_exception_wrapper(func):
+ """
+ Sometimes an asyncio function does not print exception.
+ We do another wrapper to handle the exception.
+ """
+ try:
+ await func()
+ except Exception:
+ traceback = get_exception_traceback()
+ logger.error(f"TokenizerManager hit an exception: {traceback}")
+ kill_process_tree(os.getpid(), include_parent=True)
+ sys.exit(1)
+

  class SignalHandler:
  def __init__(self, tokenizer_manager):
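Note: print_exception_wrapper exists because an exception raised inside a fire-and-forget asyncio task is not propagated to the caller; it is only reported when the task is awaited or garbage collected. A minimal sketch of the behaviour that motivates the wrapper:

import asyncio

async def crashing_loop():
    raise RuntimeError("boom")

async def main():
    task = asyncio.create_task(crashing_loop())
    await asyncio.sleep(0.1)
    # The failure is silent until the exception is explicitly retrieved.
    print(task.done(), type(task.exception()).__name__)  # True RuntimeError

asyncio.run(main())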
sglang/srt/managers/tp_worker.py
@@ -83,6 +83,7 @@ class TpModelWorker:
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  )
  self.tokenizer = self.processor.tokenizer
  else:
@@ -90,6 +91,7 @@ class TpModelWorker:
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  )
  self.device = self.model_runner.device

@@ -101,6 +103,7 @@ class TpModelWorker:
  self.max_total_num_tokens // 2
  if server_args.max_running_requests is None
  else server_args.max_running_requests
+ // (server_args.dp_size if server_args.enable_dp_attention else 1)
  ),
  self.model_runner.req_to_token_pool.size,
  )
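Note: with data-parallel attention enabled, the --max-running-requests cap is now divided across the dp_size replicas so each worker admits only its share. A worked example of the arithmetic (values are made up):

# Illustration of the new per-worker cap, not code from the diff.
max_running_requests = 256
dp_size = 4
enable_dp_attention = True

per_worker = max_running_requests // (dp_size if enable_dp_attention else 1)
print(per_worker)  # 64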
@@ -142,16 +145,15 @@ class TpModelWorker:
  def get_tp_cpu_group(self):
  return self.model_runner.tp_group.cpu_group

+ def get_attention_tp_cpu_group(self):
+ return self.model_runner.attention_tp_group.cpu_group
+
  def get_memory_pool(self):
  return (
  self.model_runner.req_to_token_pool,
  self.model_runner.token_to_kv_pool,
  )

- def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
- forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- self.model_runner.forward(forward_batch)
-
  def forward_batch_generation(
  self,
  model_worker_batch: ModelWorkerBatch,
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -82,6 +82,8 @@ class TpModelWorkerClient:
  self.forward_thread.start()
  self.parent_process = psutil.Process().parent()
  self.scheduler_stream = torch.get_device_module(self.device).current_stream()
+ if self.device == "cpu":
+ self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

  def get_worker_info(self):
  return self.worker.get_worker_info()
@@ -92,6 +94,9 @@ class TpModelWorkerClient:
  def get_tp_cpu_group(self):
  return self.worker.get_tp_cpu_group()

+ def get_attention_tp_cpu_group(self):
+ return self.worker.get_attention_tp_cpu_group()
+
  def get_memory_pool(self):
  return (
  self.worker.model_runner.req_to_token_pool,
@@ -151,11 +156,6 @@ class TpModelWorkerClient:
  logits_output.input_token_logprobs = (
  logits_output.input_token_logprobs.to("cpu", non_blocking=True)
  )
- logits_output.normalized_prompt_logprobs = (
- logits_output.normalized_prompt_logprobs.to(
- "cpu", non_blocking=True
- )
- )
  next_token_ids = next_token_ids.to("cpu", non_blocking=True)
  copy_done.record()

@@ -174,9 +174,6 @@ class TpModelWorkerClient:
  logits_output.input_token_logprobs = (
  logits_output.input_token_logprobs.tolist()
  )
- logits_output.normalized_prompt_logprobs = (
- logits_output.normalized_prompt_logprobs.tolist()
- )
  next_token_ids = next_token_ids.tolist()
  return logits_output, next_token_ids

sglang/srt/managers/utils.py (new file)
@@ -0,0 +1,44 @@
+ import logging
+ from http import HTTPStatus
+ from typing import Optional
+
+ from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+
+ logger = logging.getLogger(__name__)
+
+
+ def validate_input_length(
+ req: Req, max_req_input_len: int, allow_auto_truncate: bool
+ ) -> Optional[str]:
+ """Validate and potentially truncate input length.
+
+ Args:
+ req: The request containing input_ids to validate
+ max_req_input_len: Maximum allowed input length
+ allow_auto_truncate: Whether to truncate long inputs
+
+ Returns:
+ Error message if validation fails, None if successful
+ """
+ if len(req.origin_input_ids) >= max_req_input_len:
+ if allow_auto_truncate:
+ logger.warning(
+ "Request length is longer than the KV cache pool size or "
+ "the max context length. Truncated. "
+ f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
+ )
+ req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
+ return None
+ else:
+ error_msg = (
+ f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
+ f"the maximum allowed length ({max_req_input_len} tokens). "
+ f"Use a shorter input or enable --allow-auto-truncate."
+ )
+ logger.error(error_msg)
+ req.finished_reason = FINISH_ABORT(
+ error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+ )
+ return error_msg
+
+ return None
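Note: a hedged usage sketch of the new helper (not taken from this diff; the caller is assumed to already hold a fully constructed Req whose origin_input_ids contains the tokenized prompt):

from typing import Optional

from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.utils import validate_input_length

def admit_request(req: Req, max_req_input_len: int = 8192) -> Optional[str]:
    # Returns None when the request may proceed. On failure the helper has already
    # set req.finished_reason = FINISH_ABORT(..., HTTPStatus.BAD_REQUEST, ...), so the
    # caller can finish the request immediately with the returned error message.
    return validate_input_length(req, max_req_input_len, allow_auto_truncate=False)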