sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -16,6 +16,7 @@
 import asyncio
 import copy
 import dataclasses
+import json
 import logging
 import os
 import pickle
@@ -90,6 +91,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
+    SetInternalStateReq,
+    SetInternalStateReqOutput,
     SlowDownReqInput,
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
@@ -125,10 +128,10 @@ logger = logging.getLogger(__name__)
 class ReqState:
     """Store the state a request."""
 
-    out_list: List
+    out_list: List[Dict[Any, Any]]
     finished: bool
     event: asyncio.Event
-    obj: Any
+    obj: Union[GenerateReqInput, EmbeddingReqInput]
 
     # For metrics
     created_time: float
@@ -139,6 +142,21 @@ class ReqState:
 
     # For streaming output
     last_output_offset: int = 0
+    # For incremental state update.
+    text: str = ""
+    output_ids: List[int] = dataclasses.field(default_factory=list)
+    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
 class TokenizerManager:
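
All of the new accumulator fields on ReqState use dataclasses.field(default_factory=list) rather than a bare [] default. A minimal sketch of why that is required for mutable dataclass fields (the class and field names here are illustrative, not from sglang):

import dataclasses
from typing import List

@dataclasses.dataclass
class Accumulator:
    # `items: List[int] = []` would raise ValueError at class creation:
    # dataclasses forbid mutable defaults because one shared list object
    # would leak state across instances. default_factory builds a fresh
    # list per instance instead.
    items: List[int] = dataclasses.field(default_factory=list)

a, b = Accumulator(), Accumulator()
a.items.extend([1, 2, 3])
assert b.items == []  # no state leaks between requests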
@@ -154,6 +172,11 @@ class TokenizerManager:
         self.enable_metrics = server_args.enable_metrics
         self.log_requests = server_args.log_requests
         self.log_requests_level = server_args.log_requests_level
+        self.preferred_sampling_params = (
+            json.loads(server_args.preferred_sampling_params)
+            if server_args.preferred_sampling_params
+            else None
+        )
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -213,6 +236,7 @@ class TokenizerManager:
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
         self.dump_requests_folder = ""  # By default do not dump
@@ -240,6 +264,10 @@ class TokenizerManager:
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,
                 },
+                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
+                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
+                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
+                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
             )
 
         # Communicators
@@ -267,12 +295,16 @@
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.start_profile_communicator = _Communicator(
+        self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.set_internal_state_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.expert_distribution_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
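
Each _Communicator above is constructed with a fan-out of server_args.dp_size, so one control request is answered once per data-parallel scheduler rank before the caller resumes. A self-contained sketch of that gather pattern, assuming asyncio and illustrative names (not sglang's actual class):

import asyncio
from typing import Any, List

class FanOutGather:
    """Release the awaiting caller only after `fan_out` replies arrive."""

    def __init__(self, fan_out: int):
        self._fan_out = fan_out
        self._results: List[Any] = []
        self._event = asyncio.Event()

    def handle_recv(self, obj: Any):
        # Called once per scheduler rank's reply.
        self._results.append(obj)
        if len(self._results) == self._fan_out:
            self._event.set()

    async def wait_all(self) -> List[Any]:
        await self._event.wait()
        return self._results

async def demo():
    gather = FanOutGather(fan_out=2)  # e.g. dp_size == 2
    gather.handle_recv("reply from dp rank 0")
    gather.handle_recv("reply from dp rank 1")
    print(await gather.wait_all())

asyncio.run(demo())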
@@ -288,6 +320,7 @@
                     ),
                     self._handle_batch_output,
                 ),
+                (AbortReq, self._handle_abort_req),
                 (OpenSessionReqOutput, self._handle_open_session_req_output),
                 (
                     UpdateWeightFromDiskReqOutput,
@@ -327,12 +360,16 @@
                 ),
                 (
                     ProfileReqOutput,
-                    self.start_profile_communicator.handle_recv,
+                    self.profile_communicator.handle_recv,
                 ),
                 (
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
+                (
+                    SetInternalStateReqOutput,
+                    self.set_internal_state_communicator.handle_recv,
+                ),
                 (
                     ExpertDistributionReqOutput,
                     self.expert_distribution_communicator.handle_recv,
@@ -341,13 +378,14 @@
             ]
         )
 
+        # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        # for disaggregtion, start kv boostrap server on prefill
+        # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -421,14 +459,16 @@
             )
             input_ids = self.tokenizer.encode(input_text)
 
-        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            image_data=obj.image_data,
-            input_text=input_text or input_ids,
-            request_obj=obj,
-            max_req_input_len=self.max_req_input_len,
-        )
-        if image_inputs and "input_ids" in image_inputs:
-            input_ids = image_inputs["input_ids"]
+        image_inputs: Optional[Dict] = None
+        if obj.contains_mm_input():
+            image_inputs = await self.mm_processor.process_mm_data_async(
+                image_data=obj.image_data,
+                input_text=input_text or input_ids,
+                request_obj=obj,
+                max_req_input_len=self.max_req_input_len,
+            )
+            if image_inputs and "input_ids" in image_inputs:
+                input_ids = image_inputs["input_ids"]
 
         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
@@ -482,8 +522,23 @@
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logits-processor` to enable this feature."
+            )
 
-        sampling_params = SamplingParams(**obj.sampling_params)
+        # Parse sampling parameters
+        # Note: if there are preferred sampling params, we use them if they are not
+        # explicitly passed in sampling_params
+        if self.preferred_sampling_params:
+            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
+        else:
+            sampling_kwargs = obj.sampling_params
+        sampling_params = SamplingParams(**sampling_kwargs)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
 
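The merge {**self.preferred_sampling_params, **obj.sampling_params} relies on the right-hand unpacking winning key collisions, so request-supplied values override the server-wide defaults parsed from --preferred-sampling-params. A small worked example (values illustrative):

# Server-level defaults, e.g. parsed from --preferred-sampling-params.
preferred = {"temperature": 0.7, "top_p": 0.9}
# What the client actually sent with this request.
per_request = {"temperature": 0.2}

# The later dict wins on duplicate keys: the request's temperature is
# kept, while top_p falls back to the server default.
merged = {**preferred, **per_request}
assert merged == {"temperature": 0.2, "top_p": 0.9}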
@@ -570,9 +625,9 @@
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
-        self.send_to_scheduler.send_pyobj(tokenized_obj)
 
     async def _wait_one_response(
         self,
@@ -587,10 +642,11 @@
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                     )
                 continue
 
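The loop above waits in 4-second slices so a long wait periodically gets a chance to notice a dropped client. A stripped-down sketch of the same polling pattern; the client_gone coroutine is a stand-in for FastAPI's request.is_disconnected(), not sglang's API:

import asyncio

async def client_gone() -> bool:
    return False  # stand-in for fastapi's request.is_disconnected()

async def wait_with_disconnect_check(event: asyncio.Event):
    """Wait for `event`, re-checking the client every 4 seconds."""
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=4)
            return  # the response (or next chunk) is ready
        except asyncio.TimeoutError:
            if await client_gone():
                # Raising unwinds the whole call stack, which also
                # ends the surrounding asyncio task.
                raise ValueError("client disconnected; aborting request")

async def demo():
    event = asyncio.Event()
    asyncio.get_running_loop().call_later(0.1, event.set)
    await wait_with_disconnect_check(event)
    print("response ready")

asyncio.run(demo())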
@@ -605,7 +661,6 @@
                 else:
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                 logger.info(msg)
-                del self.rid_to_state[obj.rid]
 
             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -625,10 +680,11 @@
                 yield out
             else:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                     )
 
     async def _handle_batch_request(
  async def _handle_batch_request(
@@ -641,7 +697,6 @@ class TokenizerManager:
641
697
 
642
698
  generators = []
643
699
  rids = []
644
-
645
700
  if getattr(obj, "parallel_sample_num", 1) == 1:
646
701
  if self.server_args.enable_tokenizer_batch_encode:
647
702
  # Validate batch tokenization constraints
@@ -728,7 +783,6 @@
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
-        del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)
 
@@ -737,30 +791,42 @@
         output_dir: Optional[str] = None,
         num_steps: Optional[int] = None,
         activities: Optional[List[str]] = None,
+        with_stack: Optional[bool] = None,
+        record_shapes: Optional[bool] = None,
     ):
+        self.auto_create_handle_loop()
         req = ProfileReq(
             type=ProfileReqType.START_PROFILE,
             output_dir=output_dir,
             num_steps=num_steps,
             activities=activities,
+            with_stack=with_stack,
+            record_shapes=record_shapes,
             profile_id=str(time.time()),
         )
-        result = (await self.start_profile_communicator(req))[0]
+        return await self._execute_profile(req)
+
+    async def stop_profile(self):
+        self.auto_create_handle_loop()
+        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
+        return await self._execute_profile(req)
+
+    async def _execute_profile(self, req: ProfileReq):
+        result = (await self.profile_communicator(req))[0]
         if not result.success:
             raise RuntimeError(result.message)
         return result
 
-    def stop_profile(self):
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        self.send_to_scheduler.send_pyobj(req)
-
     async def start_expert_distribution_record(self):
+        self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
 
     async def stop_expert_distribution_record(self):
+        self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
 
     async def dump_expert_distribution_record(self):
+        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
     async def update_weights_from_disk(
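
The new with_stack and record_shapes fields travel through ProfileReq to the scheduler; both names match real torch.profiler.profile arguments, which is presumably where they land. A hedged sketch of that downstream usage (illustrative, not sglang's scheduler code):

import torch
from torch.profiler import ProfilerActivity, profile

# with_stack / record_shapes mirror the ProfileReq fields added above;
# both are genuine torch.profiler.profile arguments.
with profile(
    activities=[ProfilerActivity.CPU],
    with_stack=True,      # capture Python source stacks per op
    record_shapes=True,   # record input tensor shapes per op
) as prof:
    a = torch.randn(128, 128)
    (a @ a).sum()

print(prof.key_averages(group_by_input_shape=True).table(row_limit=5))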
@@ -827,8 +893,8 @@
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be for update weights from distributed"
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
 
         # This means that weight sync
         # cannot run while requests are in progress.
@@ -843,8 +909,8 @@
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for update weights from distributed"
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
 
         # This means that weight sync
         # cannot run while requests are in progress.
@@ -909,12 +975,21 @@
     ):
         await self.send_to_scheduler.send_pyobj(obj)
 
-    async def get_internal_state(self) -> Dict[Any, Any]:
+    async def get_internal_state(self) -> List[Dict[Any, Any]]:
         req = GetInternalStateReq()
-        res: List[GetInternalStateReqOutput] = (
+        responses: List[GetInternalStateReqOutput] = (
             await self.get_internal_state_communicator(req)
         )
-        return res[0].internal_state
+        # Many DP ranks
+        return [res.internal_state for res in responses]
+
+    async def set_internal_state(
+        self, obj: SetInternalStateReq
+    ) -> SetInternalStateReqOutput:
+        responses: List[SetInternalStateReqOutput] = (
+            await self.set_internal_state_communicator(obj)
+        )
+        return [res.internal_state for res in responses]
 
     def get_log_request_metadata(self):
         max_length = None
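
get_internal_state now returns one dict per data-parallel rank instead of only rank 0's. A sketch of the shape change, using a stand-in dataclass and illustrative field values:

from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class GetInternalStateReqOutput:  # stand-in for the sglang io_struct class
    internal_state: Dict[str, Any]

# One reply per DP rank, as gathered by the communicator.
responses: List[GetInternalStateReqOutput] = [
    GetInternalStateReqOutput({"queue_len": 3}),  # illustrative payload
    GetInternalStateReqOutput({"queue_len": 5}),
]

# Old behavior returned only rank 0's dict: responses[0].internal_state.
# New behavior surfaces every rank:
states = [res.internal_state for res in responses]
assert states == [{"queue_len": 3}, {"queue_len": 5}]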
@@ -964,7 +1039,7 @@
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(1)
+            await asyncio.sleep(2)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -985,11 +1060,17 @@
             loop.create_task(print_exception_wrapper(self.handle_loop))
         )
 
+        self.event_loop = loop
+
         # We cannot add signal handler when the tokenizer manager is not in
         # the main thread due to the CPython limitation.
         if threading.current_thread() is threading.main_thread():
             signal_handler = SignalHandler(self)
-            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
+            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
+            loop.add_signal_handler(
+                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
+            )
         else:
             logger.warning(
                 "Signal handler is not added because the tokenizer manager is "
@@ -1007,6 +1088,15 @@
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+
+            if self.health_check_failed:
+                # if health check failed, we should exit immediately
+                logger.error(
+                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
+                    remain_num_req,
+                )
+                break
+
             logger.info(
                 f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
@@ -1035,6 +1125,9 @@
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
+                logger.error(
+                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
+                )
                 continue
 
             # Build meta_info and return value
@@ -1047,9 +1140,11 @@
             if getattr(state.obj, "return_logprob", False):
                 self.convert_logprob_style(
                     meta_info,
+                    state,
                     state.obj.top_logprobs_num,
                     state.obj.token_ids_logprob,
-                    state.obj.return_text_in_logprobs,
+                    state.obj.return_text_in_logprobs
+                    and not self.server_args.skip_tokenizer_init,
                     recv_obj,
                     i,
                 )
@@ -1066,25 +1161,35 @@
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
             if isinstance(recv_obj, BatchStrOut):
+                state.text += recv_obj.output_strs[i]
                 out_dict = {
-                    "text": recv_obj.output_strs[i],
+                    "text": state.text,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 if self.server_args.stream_output and state.obj.stream:
-                    output_token_ids = recv_obj.output_ids[i][
-                        state.last_output_offset :
-                    ]
-                    state.last_output_offset = len(recv_obj.output_ids[i])
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids[state.last_output_offset :]
+                    state.last_output_offset = len(state.output_ids)
                 else:
-                    output_token_ids = recv_obj.output_ids[i]
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids
 
                 out_dict = {
                     "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchMultimodalOut):
-                raise NotImplementedError()
+                if isinstance(recv_obj.outputs[i], str):
+                    out_dict = {
+                        "text": recv_obj.outputs[i],
+                        "meta_info": meta_info,
+                    }
+                else:
+                    out_dict = {
+                        "outputs": json.dumps(recv_obj.outputs[i]),
+                        "meta_info": meta_info,
+                    }
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
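
The scheduler can now emit only the newly generated token ids of each chunk; the manager accumulates them into state.output_ids and, when streaming, returns just the suffix past last_output_offset. A compact sketch of that bookkeeping (chunk values illustrative):

output_ids: list[int] = []   # mirrors state.output_ids
last_output_offset = 0       # mirrors state.last_output_offset

def on_chunk(chunk: list[int], stream: bool) -> list[int]:
    """Mirror of the BatchTokenIDOut handling above."""
    global last_output_offset
    output_ids.extend(chunk)
    if stream:
        new_tokens = output_ids[last_output_offset:]
        last_output_offset = len(output_ids)
        return new_tokens          # only the unseen suffix
    return output_ids              # full sequence so far

assert on_chunk([1, 2], stream=True) == [1, 2]
assert on_chunk([3], stream=True) == [3]
assert on_chunk([4], stream=False) == [1, 2, 3, 4]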
@@ -1098,6 +1203,7 @@
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+                del self.rid_to_state[rid]
 
             state.out_list.append(out_dict)
             state.event.set()
@@ -1111,45 +1217,85 @@
     def convert_logprob_style(
         self,
         meta_info: dict,
+        state: ReqState,
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if len(recv_obj.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                recv_obj.input_token_logprobs_val[recv_obj_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                recv_obj.input_token_logprobs_idx[recv_obj_index]
+            )
+        state.output_token_logprobs_val.extend(
+            recv_obj.output_token_logprobs_val[recv_obj_index]
+        )
+        state.output_token_logprobs_idx.extend(
+            recv_obj.output_token_logprobs_idx[recv_obj_index]
+        )
         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.input_token_logprobs_val[recv_obj_index],
-            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
             return_text_in_logprobs,
         )
         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.output_token_logprobs_val[recv_obj_index],
-            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
             return_text_in_logprobs,
         )
 
         if top_logprobs_num > 0:
+            if len(recv_obj.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    recv_obj.input_top_logprobs_val[recv_obj_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    recv_obj.input_top_logprobs_idx[recv_obj_index]
+                )
+            state.output_top_logprobs_val.extend(
+                recv_obj.output_top_logprobs_val[recv_obj_index]
+            )
+            state.output_top_logprobs_idx.extend(
+                recv_obj.output_top_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_top_logprobs_val[recv_obj_index],
-                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.output_top_logprobs_val[recv_obj_index],
-                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
                 return_text_in_logprobs,
             )
 
         if token_ids_logprob is not None:
+            if len(recv_obj.input_token_ids_logprobs_val) > 0:
+                state.input_token_ids_logprobs_val.extend(
+                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.input_token_ids_logprobs_idx.extend(
+                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
+                )
+            state.output_token_ids_logprobs_val.extend(
+                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
+            )
+            state.output_token_ids_logprobs_idx.extend(
+                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
-                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_token_ids_logprobs"] = (
                 self.detokenize_top_logprobs_tokens(
-                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
-                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
                     return_text_in_logprobs,
                 )
             )
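
convert_logprob_style now detokenizes from the accumulated per-request lists, so each response carries the full logprob history rather than only the latest chunk. A simplified stand-in for the val/idx pairing done by detokenize_logprob_tokens (the real method batch-decodes token ids with the tokenizer when return_text_in_logprobs is set):

from typing import Callable, List, Optional, Tuple

def pair_logprobs(
    vals: List[float],
    idxs: List[int],
    decode: Optional[Callable[[int], str]] = None,
) -> List[Tuple]:
    """Zip parallel (logprob, token_id) lists, optionally adding text."""
    if decode is None:
        return list(zip(vals, idxs))
    return [(v, i, decode(i)) for v, i in zip(vals, idxs)]

# Accumulated across chunks, as in state.output_token_logprobs_*:
vals, idxs = [-0.1, -0.7], [42, 7]
print(pair_logprobs(vals, idxs))              # [(-0.1, 42), (-0.7, 7)]
print(pair_logprobs(vals, idxs, decode=str))  # with a decoded-text stand-in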
@@ -1216,11 +1362,18 @@
             state.last_completion_tokens = completion_tokens
 
         if state.finished:
+            has_grammar = (
+                state.obj.sampling_params.get("json_schema", None)
+                or state.obj.sampling_params.get("regex", None)
+                or state.obj.sampling_params.get("ebnf", None)
+                or state.obj.sampling_params.get("structural_tag", None)
+            )
             self.metrics_collector.observe_one_finished_request(
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
                 state.finished_time - state.created_time,
+                has_grammar,
             )
 
     def dump_requests(self, state: ReqState, out_dict: dict):
@@ -1246,6 +1399,9 @@
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))
 
+    def _handle_abort_req(self, recv_obj):
+        self.rid_to_state.pop(recv_obj.rid)
+
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
             recv_obj.session_id if recv_obj.success else None
@@ -1256,7 +1412,7 @@
             self.model_update_result.set_result(recv_obj)
         else:  # self.server_args.dp_size > 1
             self.model_update_tmp.append(recv_obj)
-            # set future if the all results are recevied
+            # set future if the all results are received
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)
 
@@ -1279,12 +1435,18 @@ class SignalHandler:
     def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager
 
-    def signal_handler(self, signum=None, frame=None):
+    def sigterm_handler(self, signum=None, frame=None):
         logger.warning(
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True
 
+    def running_phase_sigquit_handler(self, signum=None, frame=None):
+        logger.error(
+            "Received sigquit from a child process. It usually means the child failed."
+        )
+        kill_process_tree(os.getpid())
+
 
 
 T = TypeVar("T")
@@ -1325,3 +1487,15 @@ class _Communicator(Generic[T]):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
+
+
+# Note: request abort handling logic
+# We should handle all of the following cases correctly.
+#
+# | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
+# | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
+# | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
+# | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
+#