sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +1 -0
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +41 -5
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
  41. sglang/srt/layers/parameter.py +2 -1
  42. sglang/srt/layers/quantization/__init__.py +20 -23
  43. sglang/srt/layers/quantization/fp8.py +6 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  45. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  46. sglang/srt/layers/radix_attention.py +2 -2
  47. sglang/srt/layers/rotary_embedding.py +1179 -31
  48. sglang/srt/layers/sampler.py +39 -1
  49. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  50. sglang/srt/lora/lora.py +1 -9
  51. sglang/srt/managers/configure_logging.py +3 -0
  52. sglang/srt/managers/data_parallel_controller.py +79 -72
  53. sglang/srt/managers/detokenizer_manager.py +23 -6
  54. sglang/srt/managers/image_processor.py +158 -2
  55. sglang/srt/managers/io_struct.py +25 -2
  56. sglang/srt/managers/schedule_batch.py +49 -22
  57. sglang/srt/managers/schedule_policy.py +26 -12
  58. sglang/srt/managers/scheduler.py +277 -178
  59. sglang/srt/managers/session_controller.py +1 -0
  60. sglang/srt/managers/tokenizer_manager.py +206 -121
  61. sglang/srt/managers/tp_worker.py +6 -4
  62. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  63. sglang/srt/managers/utils.py +44 -0
  64. sglang/srt/mem_cache/memory_pool.py +10 -32
  65. sglang/srt/metrics/collector.py +15 -6
  66. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  67. sglang/srt/model_executor/model_runner.py +37 -15
  68. sglang/srt/model_loader/loader.py +8 -6
  69. sglang/srt/model_loader/weight_utils.py +55 -2
  70. sglang/srt/models/baichuan.py +6 -6
  71. sglang/srt/models/chatglm.py +2 -2
  72. sglang/srt/models/commandr.py +3 -3
  73. sglang/srt/models/dbrx.py +4 -4
  74. sglang/srt/models/deepseek.py +3 -3
  75. sglang/srt/models/deepseek_v2.py +8 -8
  76. sglang/srt/models/exaone.py +2 -2
  77. sglang/srt/models/gemma.py +2 -2
  78. sglang/srt/models/gemma2.py +6 -24
  79. sglang/srt/models/gpt2.py +3 -5
  80. sglang/srt/models/gpt_bigcode.py +1 -1
  81. sglang/srt/models/granite.py +2 -2
  82. sglang/srt/models/grok.py +3 -3
  83. sglang/srt/models/internlm2.py +2 -2
  84. sglang/srt/models/llama.py +7 -5
  85. sglang/srt/models/minicpm.py +2 -2
  86. sglang/srt/models/minicpm3.py +6 -6
  87. sglang/srt/models/minicpmv.py +1238 -0
  88. sglang/srt/models/mixtral.py +3 -3
  89. sglang/srt/models/mixtral_quant.py +3 -3
  90. sglang/srt/models/mllama.py +2 -2
  91. sglang/srt/models/olmo.py +3 -3
  92. sglang/srt/models/olmo2.py +4 -4
  93. sglang/srt/models/olmoe.py +7 -13
  94. sglang/srt/models/phi3_small.py +2 -2
  95. sglang/srt/models/qwen.py +2 -2
  96. sglang/srt/models/qwen2.py +41 -4
  97. sglang/srt/models/qwen2_moe.py +3 -3
  98. sglang/srt/models/qwen2_vl.py +22 -122
  99. sglang/srt/models/stablelm.py +2 -2
  100. sglang/srt/models/torch_native_llama.py +3 -3
  101. sglang/srt/models/xverse.py +6 -6
  102. sglang/srt/models/xverse_moe.py +6 -6
  103. sglang/srt/openai_api/protocol.py +2 -0
  104. sglang/srt/sampling/custom_logit_processor.py +38 -0
  105. sglang/srt/sampling/sampling_batch_info.py +139 -4
  106. sglang/srt/sampling/sampling_params.py +3 -1
  107. sglang/srt/server.py +4 -1090
  108. sglang/srt/server_args.py +57 -14
  109. sglang/srt/utils.py +103 -65
  110. sglang/test/runners.py +8 -13
  111. sglang/test/test_programs.py +1 -1
  112. sglang/test/test_utils.py +3 -1
  113. sglang/utils.py +12 -2
  114. sglang/version.py +1 -1
  115. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
  116. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
  117. sglang/launch_server_llavavid.py +0 -25
  118. sglang/srt/constrained/__init__.py +0 -16
  119. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  120. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  121. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
@@ -131,6 +131,7 @@ class Session:
             sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
+            custom_logit_processor=req.custom_logit_processor,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
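This hunk threads the new `custom_logit_processor` field through requests created inside a session; the processor type itself is defined in the newly added sglang/srt/sampling/custom_logit_processor.py (file 104 above). As a rough sketch of the kind of processor this enables — the class and interface below are illustrative, not the actual sglang API:

```python
import torch

# Illustrative only: the real interface lives in
# sglang/srt/sampling/custom_logit_processor.py and may differ.
class BanTokensProcessor:
    """Mask a fixed set of token ids so they can never be sampled."""

    def __init__(self, banned_token_ids):
        self.banned_token_ids = banned_token_ids

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # A logit of -inf gives the token zero probability after softmax.
        logits[..., self.banned_token_ids] = float("-inf")
        return logits
```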
@@ -21,9 +21,11 @@ import os
 import pickle
 import signal
 import sys
+import threading
 import time
 import uuid
 from datetime import datetime
+from http import HTTPStatus
 from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 import fastapi
@@ -78,6 +80,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -110,17 +113,19 @@ class TokenizerManager:
         port_args: PortArgs,
     ):
         # Parse args
+
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
         self.log_requests = server_args.log_requests
+        self.log_requests_level = 0

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = get_zmq_socket(
-            context, zmq.PULL, port_args.tokenizer_ipc_name
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
         self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
         )

         # Read model args
@@ -153,6 +158,7 @@ class TokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -166,10 +172,11 @@ class TokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )

         # Store states
-        self.to_create_loop = True
+        self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
@@ -205,6 +212,8 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        # Set after scheduler is initialized
+        self.max_req_input_len = None

         # Metrics
         if self.enable_metrics:
@@ -215,6 +224,44 @@ class TokenizerManager:
                 },
             )

+        self._result_dispatcher = TypeBasedDispatcher(
+            [
+                (
+                    (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
+                    self._handle_batch_output,
+                ),
+                (OpenSessionReqOutput, self._handle_open_session_req_output),
+                (
+                    UpdateWeightFromDiskReqOutput,
+                    self._handle_update_weights_from_disk_req_output,
+                ),
+                (
+                    InitWeightsUpdateGroupReqOutput,
+                    self.init_weights_update_group_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromDistributedReqOutput,
+                    self.update_weights_from_distributed_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromTensorReqOutput,
+                    self.update_weights_from_tensor_communicator.handle_recv,
+                ),
+                (
+                    GetWeightsByNameReqOutput,
+                    self.get_weights_by_name_communicator.handle_recv,
+                ),
+                (
+                    ReleaseMemoryOccupationReqOutput,
+                    self.release_memory_occupation_communicator.handle_recv,
+                ),
+                (
+                    ResumeMemoryOccupationReqOutput,
+                    self.resume_memory_occupation_communicator.handle_recv,
+                ),
+            ]
+        )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
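The `TypeBasedDispatcher` imported from `sglang.utils` replaces the long `isinstance` chain in `handle_loop` (removed further below). A minimal sketch of how such a dispatcher can work — an illustration, not necessarily the exact `sglang.utils` implementation:

```python
from typing import Any, Callable, List, Tuple, Type, Union

class TypeBasedDispatcher:
    """Route an object to the first handler whose registered type matches.

    Minimal sketch; the actual sglang.utils.TypeBasedDispatcher may differ.
    """

    def __init__(
        self,
        mapping: List[Tuple[Union[Type, Tuple[Type, ...]], Callable[[Any], Any]]],
    ):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        for ty, fn in self._mapping:
            # `ty` may be a single type or a tuple of types, as in the
            # (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut) entry above.
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj=}")
```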
@@ -233,7 +280,10 @@ class TokenizerManager:
         obj.normalize_batch_and_arguments()

         if self.log_requests:
-            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+            max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+            logger.info(
+                f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
+            )

         async with self.model_update_lock.reader_lock:
             is_single = obj.is_single
@@ -265,15 +315,21 @@ class TokenizerManager:
                 )
             input_embeds = obj.input_embeds
             input_ids = obj.input_ids
-        elif obj.input_ids is None:
-            input_ids = self.tokenizer.encode(input_text)
-        else:
+        elif obj.input_ids is not None:
             input_ids = obj.input_ids
+        else:
+            if self.tokenizer is None:
+                raise ValueError(
+                    "The engine initialized with skip_tokenizer_init=True cannot "
+                    "accept text prompts. Please provide input_ids or re-initialize "
+                    "the engine with skip_tokenizer_init=False."
+                )
+            input_ids = self.tokenizer.encode(input_text)

         if self.is_generation:
             # TODO: also support getting embeddings for multimodal models
             image_inputs: Dict = await self.image_processor.process_images_async(
-                obj.image_data, input_text or input_ids, obj
+                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
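The new else-branch gives a clear error when an engine started with `skip_tokenizer_init=True` receives a text prompt instead of token IDs. A hedged usage sketch against the new engine entrypoint (sglang/srt/entrypoints/engine.py, file 28 above); the exact constructor and generate() signature may differ in detail:

```python
import sglang as sgl

# Sketch: with skip_tokenizer_init=True the caller tokenizes externally.
# Passing a text prompt here would now raise the ValueError shown above.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    skip_tokenizer_init=True,
)
out = engine.generate(
    input_ids=[1, 15043, 3186],  # pre-tokenized prompt
    sampling_params={"max_new_tokens": 16},
)
```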
@@ -284,12 +340,28 @@ class TokenizerManager:
             SessionParams(**obj.session_params) if obj.session_params else None
         )

-        if obj.input_ids is not None and len(input_ids) >= self.context_len:
+        input_token_num = len(input_ids) if input_ids is not None else 0
+        if input_token_num >= self.context_len:
             raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )

+        if (
+            obj.sampling_params.get("max_new_tokens") is not None
+            and obj.sampling_params.get("max_new_tokens") + input_token_num
+            >= self.context_len
+        ):
+            raise ValueError(
+                f"Requested token count exceeds the model's maximum context length "
+                f"of {self.context_len} tokens. You requested a total of "
+                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"tokens: {input_token_num} tokens from the input messages and "
+                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+                f"completion. Please reduce the number of tokens in the input "
+                f"messages or the completion to fit within the limit."
+            )
+
         # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
@@ -310,6 +382,7 @@ class TokenizerManager:
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
                 session_params=session_params,
+                custom_logit_processor=obj.custom_logit_processor,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -354,9 +427,20 @@ class TokenizerManager:
             state.out_list = []
             if state.finished:
                 if self.log_requests:
-                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                    max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+                    msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]
+
+                # Check if this was an abort/error created by scheduler
+                if isinstance(out["meta_info"].get("finish_reason"), dict):
+                    finish_reason = out["meta_info"]["finish_reason"]
+                    if (
+                        finish_reason.get("type") == "abort"
+                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
+                    ):
+                        raise ValueError(finish_reason["message"])
+
                 yield out
                 break

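For reference, the `finish_reason` dict inspected above is produced when the scheduler aborts a bad request (compare the FINISH_ABORT call in the new sglang/srt/managers/utils.py at the end of this diff). Its rough shape, with illustrative values:

```python
from http import HTTPStatus

# Illustrative shape only; the exact serialization comes from the scheduler.
finish_reason = {
    "type": "abort",
    "status_code": HTTPStatus.BAD_REQUEST,
    "message": "Input length (5000 tokens) exceeds the maximum allowed length (4096 tokens). ...",
}
```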
@@ -601,12 +685,13 @@ class TokenizerManager:
     async def close_session(
         self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        assert not self.to_create_loop, "close session should not be the first request"
         await self.send_to_scheduler.send_pyobj(obj)

     def configure_logging(self, obj: ConfigureLoggingReq):
         if obj.log_requests is not None:
             self.log_requests = obj.log_requests
+        if obj.log_requests_level is not None:
+            self.log_requests_level = obj.log_requests_level
         if obj.dump_requests_folder is not None:
             self.dump_requests_folder = obj.dump_requests_folder
         if obj.dump_requests_threshold is not None:
@@ -628,16 +713,29 @@ class TokenizerManager:
         return background_tasks

     def auto_create_handle_loop(self):
-        if not self.to_create_loop:
+        if self.no_create_loop:
             return

-        self.to_create_loop = False
+        self.no_create_loop = True
         loop = asyncio.get_event_loop()
-        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )

-        signal_handler = SignalHandler(self)
-        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
+        # We cannot add signal handler when the tokenizer manager is not in
+        # the main thread due to the CPython limitation.
+        if threading.current_thread() is threading.main_thread():
+            signal_handler = SignalHandler(self)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        else:
+            logger.warning(
+                "Signal handler is not added because the tokenizer manager is "
+                "not in the main thread. This disables graceful shutdown of the "
+                "tokenizer manager when SIGTERM is received."
+            )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
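The main-thread guard above reflects a CPython restriction: signals are delivered only to the main thread, so registering a handler anywhere else raises an error. A standalone illustration of the same guard, independent of sglang:

```python
import asyncio
import signal
import threading


def try_install_sigterm(loop: asyncio.AbstractEventLoop) -> bool:
    """Install a SIGTERM handler only when running in the main thread."""
    if threading.current_thread() is not threading.main_thread():
        # CPython delivers signals only to the main thread;
        # loop.add_signal_handler fails anywhere else.
        return False
    loop.add_signal_handler(signal.SIGTERM, lambda: print("SIGTERM received"))
    return True
```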
@@ -661,106 +759,64 @@ class TokenizerManager:
         """The event loop that handles requests"""

         while True:
-            recv_obj: Union[
-                BatchStrOut,
-                BatchEmbeddingOut,
-                BatchTokenIDOut,
-                UpdateWeightFromDiskReqOutput,
-                UpdateWeightsFromDistributedReqOutput,
-                GetWeightsByNameReqOutput,
-                InitWeightsUpdateGroupReqOutput,
-                ReleaseMemoryOccupationReqOutput,
-                ResumeMemoryOccupationReqOutput,
-            ] = await self.recv_from_detokenizer.recv_pyobj()
-
-            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
-                for i, rid in enumerate(recv_obj.rids):
-                    state = self.rid_to_state.get(rid, None)
-                    if state is None:
-                        continue
-
-                    meta_info = {
-                        "id": rid,
-                        "finish_reason": recv_obj.finished_reasons[i],
-                        "prompt_tokens": recv_obj.prompt_tokens[i],
-                    }
+            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+            self._result_dispatcher(recv_obj)

-                    if getattr(state.obj, "return_logprob", False):
-                        self.convert_logprob_style(
-                            meta_info,
-                            state.obj.top_logprobs_num,
-                            state.obj.return_text_in_logprobs,
-                            recv_obj,
-                            i,
-                        )
-
-                    if not isinstance(recv_obj, BatchEmbeddingOut):
-                        meta_info.update(
-                            {
-                                "completion_tokens": recv_obj.completion_tokens[i],
-                                "cached_tokens": recv_obj.cached_tokens[i],
-                            }
-                        )
-
-                    if isinstance(recv_obj, BatchStrOut):
-                        out_dict = {
-                            "text": recv_obj.output_strs[i],
-                            "meta_info": meta_info,
-                        }
-                    elif isinstance(recv_obj, BatchTokenIDOut):
-                        out_dict = {
-                            "token_ids": recv_obj.output_ids[i],
-                            "meta_info": meta_info,
-                        }
-                    else:
-                        assert isinstance(recv_obj, BatchEmbeddingOut)
-                        out_dict = {
-                            "embedding": recv_obj.embeddings[i],
-                            "meta_info": meta_info,
-                        }
-                    state.out_list.append(out_dict)
-                    state.finished = recv_obj.finished_reasons[i] is not None
-                    state.event.set()
-
-                    if self.enable_metrics:
-                        self.collect_metrics(state, recv_obj, i)
-                    if self.dump_requests_folder and state.finished:
-                        self.dump_requests(state, out_dict)
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id if recv_obj.success else None
+    def _handle_batch_output(
+        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+    ):
+        for i, rid in enumerate(recv_obj.rids):
+            state = self.rid_to_state.get(rid, None)
+            if state is None:
+                continue
+
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
+                )
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
                 )
-            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
-                if self.server_args.dp_size == 1:
-                    self.model_update_result.set_result(recv_obj)
-                else:  # self.server_args.dp_size > 1
-                    self.model_update_tmp.append(recv_obj)
-                    # set future if the all results are recevied
-                    if len(self.model_update_tmp) == self.server_args.dp_size:
-                        self.model_update_result.set_result(self.model_update_tmp)
-            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for init parameter update group"
-                self.init_weights_update_group_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-                self.get_weights_by_name_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
-                self.release_memory_occupation_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
-                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
+
+            if isinstance(recv_obj, BatchStrOut):
+                out_dict = {
+                    "text": recv_obj.output_strs[i],
+                    "meta_info": meta_info,
+                }
+            elif isinstance(recv_obj, BatchTokenIDOut):
+                out_dict = {
+                    "token_ids": recv_obj.output_ids[i],
+                    "meta_info": meta_info,
+                }
             else:
-                raise ValueError(f"Invalid object: {recv_obj=}")
+                assert isinstance(recv_obj, BatchEmbeddingOut)
+                out_dict = {
+                    "embedding": recv_obj.embeddings[i],
+                    "meta_info": meta_info,
+                }
+            state.out_list.append(out_dict)
+            state.finished = recv_obj.finished_reasons[i] is not None
+            state.event.set()
+
+            if self.enable_metrics and state.obj.log_metrics:
+                self.collect_metrics(state, recv_obj, i)
+            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+                self.dump_requests(state, out_dict)

     def convert_logprob_style(
         self,
@@ -780,9 +836,6 @@ class TokenizerManager:
             recv_obj.output_token_logprobs_idx[recv_obj_index],
             return_text_in_logprobs,
         )
-        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
-            recv_obj_index
-        ]

         if top_logprobs_num > 0:
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
@@ -874,19 +927,51 @@ class TokenizerManager:
         )

         if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
             to_dump = self.dump_request_list
             self.dump_request_list = []

             def background_task():
                 os.makedirs(self.dump_requests_folder, exist_ok=True)
-                current_time = datetime.now()
-                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
-                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                with open(filename, "wb") as f:
                     pickle.dump(to_dump, f)

             # Schedule the task to run in the background without awaiting it
             asyncio.create_task(asyncio.to_thread(background_task))

+    def _handle_open_session_req_output(self, recv_obj):
+        self.session_futures[recv_obj.session_id].set_result(
+            recv_obj.session_id if recv_obj.success else None
+        )
+
+    def _handle_update_weights_from_disk_req_output(self, recv_obj):
+        if self.server_args.dp_size == 1:
+            self.model_update_result.set_result(recv_obj)
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp.append(recv_obj)
+            # set future if the all results are recevied
+            if len(self.model_update_tmp) == self.server_args.dp_size:
+                self.model_update_result.set_result(self.model_update_tmp)
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+

 class SignalHandler:
     def __init__(self, tokenizer_manager):
@@ -83,6 +83,7 @@ class TpModelWorker:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
         else:
@@ -90,6 +91,7 @@
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
         self.device = self.model_runner.device

@@ -101,6 +103,7 @@
                 self.max_total_num_tokens // 2
                 if server_args.max_running_requests is None
                 else server_args.max_running_requests
+                // (server_args.dp_size if server_args.enable_dp_attention else 1)
             ),
             self.model_runner.req_to_token_pool.size,
         )
@@ -142,16 +145,15 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group

+    def get_attention_tp_cpu_group(self):
+        return self.model_runner.attention_tp_group.cpu_group
+
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
             self.model_runner.token_to_kv_pool,
         )

-    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        self.model_runner.forward(forward_batch)
-
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
@@ -82,6 +82,8 @@ class TpModelWorkerClient:
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
         self.scheduler_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -92,6 +94,9 @@
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()

+    def get_attention_tp_cpu_group(self):
+        return self.worker.get_attention_tp_cpu_group()
+
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
@@ -151,11 +156,6 @@
                 logits_output.input_token_logprobs = (
                     logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                 )
-                logits_output.normalized_prompt_logprobs = (
-                    logits_output.normalized_prompt_logprobs.to(
-                        "cpu", non_blocking=True
-                    )
-                )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()

@@ -174,9 +174,6 @@
             logits_output.input_token_logprobs = (
                 logits_output.input_token_logprobs.tolist()
             )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.tolist()
-            )
         next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids

@@ -0,0 +1,44 @@
+import logging
+from http import HTTPStatus
+from typing import Optional
+
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+
+logger = logging.getLogger(__name__)
+
+
+def validate_input_length(
+    req: Req, max_req_input_len: int, allow_auto_truncate: bool
+) -> Optional[str]:
+    """Validate and potentially truncate input length.
+
+    Args:
+        req: The request containing input_ids to validate
+        max_req_input_len: Maximum allowed input length
+        allow_auto_truncate: Whether to truncate long inputs
+
+    Returns:
+        Error message if validation fails, None if successful
+    """
+    if len(req.origin_input_ids) >= max_req_input_len:
+        if allow_auto_truncate:
+            logger.warning(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
+            )
+            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
+            return None
+        else:
+            error_msg = (
+                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
+                f"the maximum allowed length ({max_req_input_len} tokens). "
+                f"Use a shorter input or enable --allow-auto-truncate."
+            )
+            logger.error(error_msg)
+            req.finished_reason = FINISH_ABORT(
+                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+            )
+            return error_msg
+
+    return None
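A hedged sketch of how a caller might use this new helper; the surrounding function below is hypothetical (a real call site presumably lives in sglang/srt/managers/scheduler.py, which this release also reworks), only validate_input_length itself comes from the diff:

```python
from sglang.srt.managers.utils import validate_input_length


# Hypothetical call site for illustration only.
def admit_request(req, max_req_input_len: int, allow_auto_truncate: bool) -> bool:
    """Return True if the request may proceed to scheduling."""
    error_msg = validate_input_length(req, max_req_input_len, allow_auto_truncate)
    if error_msg is not None:
        # validate_input_length already attached FINISH_ABORT to the request,
        # so the caller only needs to skip scheduling it.
        return False
    return True
```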