sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -54,7 +54,11 @@ from sglang.srt.disaggregation.utils import (
     TransferBackend,
     get_kv_class,
 )
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -86,6 +90,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -119,10 +125,10 @@ logger = logging.getLogger(__name__)
 class ReqState:
     """Store the state a request."""

-    out_list: List
+    out_list: List[Dict[Any, Any]]
     finished: bool
     event: asyncio.Event
-    obj: Any
+    obj: Union[GenerateReqInput, EmbeddingReqInput]

     # For metrics
     created_time: float
@@ -133,6 +139,21 @@ class ReqState:

     # For streaming output
     last_output_offset: int = 0
+    # For incremental state update.
+    text: str = ""
+    output_ids: List[int] = dataclasses.field(default_factory=list)
+    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


 class TokenizerManager:
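
The new ReqState fields above let the TokenizerManager accumulate outputs incrementally across batch chunks (see the BatchStrOut/BatchTokenIDOut handling and convert_logprob_style hunks further down). A minimal standalone sketch of that accumulation pattern, using simplified names rather than sglang's actual classes:

import dataclasses
from typing import List


@dataclasses.dataclass
class SketchReqState:
    # Mirrors a small subset of the incremental fields added to ReqState above.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    last_output_offset: int = 0


def handle_chunk(state: SketchReqState, chunk_ids: List[int], stream: bool) -> List[int]:
    # Extend the accumulated ids, then return either only the new suffix
    # (streaming) or the full accumulated sequence (non-streaming).
    state.output_ids.extend(chunk_ids)
    if stream:
        new_ids = state.output_ids[state.last_output_offset :]
        state.last_output_offset = len(state.output_ids)
        return new_ids
    return state.output_ids


if __name__ == "__main__":
    s = SketchReqState()
    print(handle_chunk(s, [1, 2], stream=True))   # [1, 2]
    print(handle_chunk(s, [3], stream=True))      # [3]
    print(handle_chunk(s, [4], stream=False))     # [1, 2, 3, 4]
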
@@ -161,17 +182,7 @@ class TokenizerManager:
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)

         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
@@ -199,7 +210,7 @@ class TokenizerManager:
                 self.tokenizer = self.processor = None
             else:
                 self.processor = _processor
-                self.tokenizer = self.processor.tokenizer
+                self.tokenizer = get_tokenizer_from_processor(self.processor)
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
         else:
             self.mm_processor = get_dummy_processor()
@@ -265,6 +276,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.slow_down_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -289,6 +303,7 @@ class TokenizerManager:
                 ),
                 self._handle_batch_output,
             ),
+            (AbortReq, self._handle_abort_req),
             (OpenSessionReqOutput, self._handle_open_session_req_output),
             (
                 UpdateWeightFromDiskReqOutput,
@@ -318,6 +333,10 @@ class TokenizerManager:
                 ResumeMemoryOccupationReqOutput,
                 self.resume_memory_occupation_communicator.handle_recv,
             ),
+            (
+                SlowDownReqOutput,
+                self.slow_down_communicator.handle_recv,
+            ),
             (
                 FlushCacheReqOutput,
                 self.flush_cache_communicator.handle_recv,
@@ -338,13 +357,14 @@ class TokenizerManager:
             ]
         )

+        # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        # for disaggregtion, start kv boostrap server on prefill
+        # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -479,6 +499,14 @@ class TokenizerManager:
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logits-processor` to enable this feature."
+            )

         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
@@ -567,9 +595,9 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
-        self.send_to_scheduler.send_pyobj(tokenized_obj)

     async def _wait_one_response(
         self,
@@ -584,10 +612,11 @@ class TokenizerManager:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                     )
                 continue
@@ -602,7 +631,6 @@ class TokenizerManager:
                 else:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                     logger.info(msg)
-                del self.rid_to_state[obj.rid]

                 # Check if this was an abort/error created by scheduler
                 if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -622,10 +650,11 @@ class TokenizerManager:
                 yield out
             else:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                     )

     async def _handle_batch_request(
@@ -725,7 +754,6 @@ class TokenizerManager:
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
-        del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)
@@ -734,12 +762,16 @@ class TokenizerManager:
         output_dir: Optional[str] = None,
         num_steps: Optional[int] = None,
         activities: Optional[List[str]] = None,
+        with_stack: Optional[bool] = None,
+        record_shapes: Optional[bool] = None,
     ):
         req = ProfileReq(
             type=ProfileReqType.START_PROFILE,
             output_dir=output_dir,
             num_steps=num_steps,
             activities=activities,
+            with_stack=with_stack,
+            record_shapes=record_shapes,
             profile_id=str(time.time()),
         )
         result = (await self.start_profile_communicator(req))[0]
@@ -876,6 +908,14 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         await self.resume_memory_occupation_communicator(obj)

+    async def slow_down(
+        self,
+        obj: SlowDownReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.slow_down_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -898,12 +938,13 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

-    async def get_internal_state(self) -> Dict[Any, Any]:
+    async def get_internal_state(self) -> List[Dict[Any, Any]]:
         req = GetInternalStateReq()
-        res: List[GetInternalStateReqOutput] = (
+        responses: List[GetInternalStateReqOutput] = (
             await self.get_internal_state_communicator(req)
         )
-        return res[0].internal_state
+        # Many DP ranks
+        return [res.internal_state for res in responses]

     def get_log_request_metadata(self):
         max_length = None
@@ -953,7 +994,7 @@ class TokenizerManager:
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(1)
+            await asyncio.sleep(2)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -1024,6 +1065,9 @@ class TokenizerManager:
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
+                logger.error(
+                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
+                )
                 continue

             # Build meta_info and return value
@@ -1036,9 +1080,11 @@ class TokenizerManager:
             if getattr(state.obj, "return_logprob", False):
                 self.convert_logprob_style(
                     meta_info,
+                    state,
                     state.obj.top_logprobs_num,
                     state.obj.token_ids_logprob,
-                    state.obj.return_text_in_logprobs,
+                    state.obj.return_text_in_logprobs
+                    and not self.server_args.skip_tokenizer_init,
                     recv_obj,
                     i,
                 )
@@ -1055,18 +1101,19 @@ class TokenizerManager:
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

             if isinstance(recv_obj, BatchStrOut):
+                state.text += recv_obj.output_strs[i]
                 out_dict = {
-                    "text": recv_obj.output_strs[i],
+                    "text": state.text,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 if self.server_args.stream_output and state.obj.stream:
-                    output_token_ids = recv_obj.output_ids[i][
-                        state.last_output_offset :
-                    ]
-                    state.last_output_offset = len(recv_obj.output_ids[i])
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids[state.last_output_offset :]
+                    state.last_output_offset = len(state.output_ids)
                 else:
-                    output_token_ids = recv_obj.output_ids[i]
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids

                 out_dict = {
                     "output_ids": output_token_ids,
@@ -1087,6 +1134,7 @@ class TokenizerManager:
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+                del self.rid_to_state[rid]

             state.out_list.append(out_dict)
             state.event.set()
@@ -1100,45 +1148,85 @@ class TokenizerManager:
     def convert_logprob_style(
         self,
         meta_info: dict,
+        state: ReqState,
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if len(recv_obj.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                recv_obj.input_token_logprobs_val[recv_obj_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                recv_obj.input_token_logprobs_idx[recv_obj_index]
+            )
+        state.output_token_logprobs_val.extend(
+            recv_obj.output_token_logprobs_val[recv_obj_index]
+        )
+        state.output_token_logprobs_idx.extend(
+            recv_obj.output_token_logprobs_idx[recv_obj_index]
+        )
         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.input_token_logprobs_val[recv_obj_index],
-            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
             return_text_in_logprobs,
         )
         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.output_token_logprobs_val[recv_obj_index],
-            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
             return_text_in_logprobs,
         )

         if top_logprobs_num > 0:
+            if len(recv_obj.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    recv_obj.input_top_logprobs_val[recv_obj_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    recv_obj.input_top_logprobs_idx[recv_obj_index]
+                )
+            state.output_top_logprobs_val.extend(
+                recv_obj.output_top_logprobs_val[recv_obj_index]
+            )
+            state.output_top_logprobs_idx.extend(
+                recv_obj.output_top_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_top_logprobs_val[recv_obj_index],
-                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.output_top_logprobs_val[recv_obj_index],
-                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
                 return_text_in_logprobs,
             )

         if token_ids_logprob is not None:
+            if len(recv_obj.input_token_ids_logprobs_val) > 0:
+                state.input_token_ids_logprobs_val.extend(
+                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.input_token_ids_logprobs_idx.extend(
+                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
+                )
+            state.output_token_ids_logprobs_val.extend(
+                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
+            )
+            state.output_token_ids_logprobs_idx.extend(
+                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
-                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_token_ids_logprobs"] = (
                 self.detokenize_top_logprobs_tokens(
-                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
-                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
                     return_text_in_logprobs,
                 )
             )
@@ -1205,11 +1293,18 @@ class TokenizerManager:
             state.last_completion_tokens = completion_tokens

             if state.finished:
+                has_grammar = (
+                    state.obj.sampling_params.get("json_schema", None)
+                    or state.obj.sampling_params.get("regex", None)
+                    or state.obj.sampling_params.get("ebnf", None)
+                    or state.obj.sampling_params.get("structural_tag", None)
+                )
                 self.metrics_collector.observe_one_finished_request(
                     recv_obj.prompt_tokens[i],
                     completion_tokens,
                     recv_obj.cached_tokens[i],
                     state.finished_time - state.created_time,
+                    has_grammar,
                 )

     def dump_requests(self, state: ReqState, out_dict: dict):
@@ -1235,6 +1330,9 @@ class TokenizerManager:
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def _handle_abort_req(self, recv_obj):
+        self.rid_to_state.pop(recv_obj.rid)
+
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
             recv_obj.session_id if recv_obj.success else None
@@ -1245,7 +1343,7 @@ class TokenizerManager:
             self.model_update_result.set_result(recv_obj)
         else:  # self.server_args.dp_size > 1
             self.model_update_tmp.append(recv_obj)
-            # set future if the all results are recevied
+            # set future if the all results are received
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)
@@ -1314,3 +1412,15 @@ class _Communicator(Generic[T]):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
+
+
+# Note: request abort handling logic
+# We should handle all of the following cases correctly.
+#
+# | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
+# | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
+# | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
+# | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
+#
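
Several of the changes above route new control requests (SlowDownReqInput/SlowDownReqOutput) through _Communicator, which fans a request out to dp_size schedulers and gathers one reply per data-parallel rank; get_internal_state now returns that full per-rank list instead of only the first entry. A rough, self-contained approximation of the gather pattern (not the real class, which sends through the send_to_scheduler socket):

import asyncio
from typing import Any, Callable, List


class FanOutCommunicator:
    """Send one request, then wait until `fan_out` replies have arrived."""

    def __init__(self, send_fn: Callable[[Any], None], fan_out: int):
        self._send_fn = send_fn
        self._fan_out = fan_out
        self._results: List[Any] = []
        self._event = asyncio.Event()

    async def __call__(self, obj: Any) -> List[Any]:
        self._results = []
        self._event = asyncio.Event()
        self._send_fn(obj)            # e.g. broadcast to all DP scheduler ranks
        await self._event.wait()      # released by handle_recv()
        return list(self._results)    # one entry per rank

    def handle_recv(self, recv_obj: Any) -> None:
        self._results.append(recv_obj)
        if len(self._results) == self._fan_out:
            self._event.set()

A caller like the new get_internal_state would then return [r.internal_state for r in await communicator(req)], matching the per-rank list semantics introduced above.
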
sglang/srt/managers/tp_worker.py

@@ -20,8 +20,12 @@ from typing import Optional, Tuple, Union
 import torch

 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
@@ -61,20 +65,13 @@ class TpModelWorker:
         self.pp_rank = pp_rank

         # Init model and tokenizer
-        self.model_config = ModelConfig(
-            (
+        self.model_config = ModelConfig.from_server_args(
+            server_args,
+            model_path=(
                 server_args.model_path
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
             is_draft_model=is_draft_worker,
         )
@@ -102,7 +99,7 @@
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
-            self.tokenizer = self.processor.tokenizer
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
@@ -186,8 +183,11 @@
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
+        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[
+        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
+    ]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         pp_proxy_tensors = None
@@ -199,11 +199,11 @@
         )

         if self.pp_group.is_last_rank:
-            logits_output = self.model_runner.forward(
+            logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if model_worker_batch.launch_done is not None:
-                model_worker_batch.launch_done.set()
+            if launch_done is not None:
+                launch_done.set()

             if skip_sample:
                 next_token_ids = None
@@ -212,17 +212,17 @@
                     logits_output, model_worker_batch
                 )

-            return logits_output, next_token_ids
+            return logits_output, next_token_ids, can_run_cuda_graph
         else:
-            pp_proxy_tensors = self.model_runner.forward(
+            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None
+            return pp_proxy_tensors.tensors, None, can_run_cuda_graph

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
+        logits_output, _ = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings
         return embeddings
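
Both this file and the TokenizerManager hunk above replace a long hand-written ModelConfig(...) call with ModelConfig.from_server_args(...); the classmethod itself presumably lands in sglang/srt/configs/model_config.py (+18 -0 in this release). A hypothetical sketch of such a factory, with illustrative field names only (the real signature may differ):

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ServerArgsSketch:
    # Hypothetical stand-in for sglang's ServerArgs.
    model_path: str
    trust_remote_code: bool = False
    revision: Optional[str] = None
    context_length: Optional[int] = None
    dtype: str = "auto"
    quantization: Optional[str] = None


class ModelConfigSketch:
    def __init__(self, model_path: str, **kwargs: Any):
        self.model_path = model_path
        for key, value in kwargs.items():
            setattr(self, key, value)

    @classmethod
    def from_server_args(
        cls,
        server_args: ServerArgsSketch,
        model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        # Call sites such as a draft worker override model_path (and pass
        # is_draft_model=...) while every other field comes from one place.
        return cls(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            **kwargs,
        )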
 
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -18,7 +18,7 @@ import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional
+from typing import Optional, Tuple

 import psutil
 import torch
@@ -127,10 +127,12 @@ class TpModelWorkerClient:
         batch_lists = [None] * 2

         while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break

+            sync_event.wait()
+
             # Keep a reference of model_worker_batch by storing it into a list.
             # Otherwise, the tensor members of model_worker_batch will be released
             # by pytorch and cause CUDA illegal memory access errors.
@@ -145,8 +147,10 @@
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
-            logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.worker.forward_batch_generation(
+                    model_worker_batch, model_worker_batch.launch_done
+                )
             )

             # Update the future token ids map
@@ -171,14 +175,18 @@
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()

-            self.output_queue.put((copy_done, logits_output, next_token_ids))
+            self.output_queue.put(
+                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
+            )

     def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
         """
         This function is called to resolve the last batch result and
         wait for the current batch to be launched. Used in overlap mode.
         """
-        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
+            self.output_queue.get()
+        )

         if launch_done is not None:
             launch_done.wait()
@@ -193,9 +201,11 @@
                 logits_output.input_token_logprobs.tolist()
             )
         next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids
+        return logits_output, next_token_ids, can_run_cuda_graph

-    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+    def forward_batch_generation(
+        self, model_worker_batch: ModelWorkerBatch
+    ) -> Tuple[None, torch.Tensor, bool]:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -206,10 +216,11 @@
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        self.scheduler_stream.synchronize()
+        sync_event = torch.get_device_module(self.device).Event()
+        sync_event.record(self.scheduler_stream)

         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
@@ -223,7 +234,7 @@
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        return None, future_next_token_ids
+        return None, future_next_token_ids, False

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
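
The overlap-thread change above replaces a host-blocking scheduler_stream.synchronize() with a device event: the scheduler records the event on its stream and the forward thread waits on it before touching the batch tensors, so the scheduler thread never stalls. A minimal standalone demonstration of that record/wait pattern with the plain torch.cuda API (assuming a CUDA device is available):

import torch


def producer(scheduler_stream: torch.cuda.Stream, x: torch.Tensor) -> torch.cuda.Event:
    with torch.cuda.stream(scheduler_stream):
        x.add_(1)                # enqueue work on the scheduler stream
    ev = torch.cuda.Event()
    ev.record(scheduler_stream)  # mark the point the consumer must reach
    return ev                    # returns immediately, no host-side sync


def consumer(ev: torch.cuda.Event, x: torch.Tensor) -> torch.Tensor:
    ev.wait()                    # current stream waits for the event on-device
    return x * 2                 # safe: ordered after the producer's update


if __name__ == "__main__" and torch.cuda.is_available():
    stream = torch.cuda.Stream()
    data = torch.zeros(4, device="cuda")
    event = producer(stream, data)
    out = consumer(event, data)
    torch.cuda.synchronize()
    print(out)                   # tensor([2., 2., 2., 2.])
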
sglang/srt/mem_cache/chunk_cache.py

@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = page_size

     def reset(self):
         pass
sglang/srt/mem_cache/memory_pool.py

@@ -374,9 +374,9 @@ class MHATokenToKVPool(KVCache):
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
             with self.device_module.stream(self.alt_stream):
-                self.k_buffer[layer_id - self.start_layer][loc] = cache_k
-                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id - self.start_layer][loc] = cache_k
@@ -762,6 +762,8 @@ class HostKVCache(abc.ABC):
         self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
+        self.start_layer = device_pool.start_layer
+        self.end_layer = device_pool.end_layer

         assert (
             self.size > device_pool.size
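
The MHATokenToKVPool hunk above moves the K copy onto the current stream while the V copy stays on the alternate stream, so the two writes can overlap instead of running back-to-back on the alt stream. A minimal sketch of that overlap (standalone tensors, not the sglang pool, and assuming a CUDA device is available):

import torch


def set_kv_overlapped(k_buf, v_buf, loc, cache_k, cache_v, alt_stream):
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)  # alt stream sees all prior work
    k_buf[loc] = cache_k                    # K copy stays on the current stream
    with torch.cuda.stream(alt_stream):
        v_buf[loc] = cache_v                # V copy overlaps on the alt stream
    current_stream.wait_stream(alt_stream)  # rejoin before later kernels read V


if __name__ == "__main__" and torch.cuda.is_available():
    k = torch.zeros(8, 4, device="cuda")
    v = torch.zeros(8, 4, device="cuda")
    idx = torch.tensor([0, 3], device="cuda")
    set_kv_overlapped(
        k, v, idx,
        torch.ones(2, 4, device="cuda"),
        torch.full((2, 4), 2.0, device="cuda"),
        torch.cuda.Stream(),
    )
    torch.cuda.synchronize()
    print(k[idx], v[idx])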