sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -1,8 +1,11 @@
 from __future__ import annotations

+import logging
 import threading
+import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
@@ -15,6 +18,10 @@ if TYPE_CHECKING:
         Scheduler,
     )

+logger = logging.getLogger(__name__)
+
+DEFAULT_FORCE_STREAM_INTERVAL = 50
+

 class SchedulerOutputProcessorMixin:
     """
@@ -36,20 +43,16 @@ class SchedulerOutputProcessorMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
-            bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
-            result.bid,
         )

         if self.enable_overlap:
-            logits_output, next_token_ids = (
-                self.tp_worker.resolve_last_batch_result(
-                    launch_done,
-                )
+            logits_output, next_token_ids, _ = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
             )
         else:
             # Move next_token_ids and logprobs to cpu
@@ -85,6 +88,7 @@

                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
+                    req.time_stats.completion_time = time.time()
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)
@@ -151,10 +155,7 @@
                     )
                     logprob_pt += num_input_logprobs

-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         else:  # embedding or reward model
             embeddings, bid = result.embeddings, result.bid
@@ -187,16 +188,16 @@
         result: GenerationBatchResult,
         launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, bid = (
+        logits_output, next_token_ids, can_run_cuda_graph = (
             result.logits_output,
             result.next_token_ids,
-            result.bid,
+            result.can_run_cuda_graph,
         )
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
-                launch_done
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
             )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
@@ -235,6 +236,7 @@
             req.check_finished()
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                req.time_stats.completion_time = time.time()

             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
@@ -264,13 +266,8 @@
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()

-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
+        self.set_next_batch_sampling_info_done(batch)
         self.stream_output(batch.reqs, batch.return_logprob)
-
         self.token_to_kv_pool_allocator.free_group_end()

         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
@@ -278,7 +275,7 @@
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats(running_batch=batch)
+            self.log_decode_stats(can_run_cuda_graph, running_batch=batch)

     def add_input_logprob_return_values(
         self: Scheduler,
@@ -512,29 +509,47 @@
             if self.model_config.is_multimodal_gen and req.to_abort:
                 continue

-            if (
-                req.finished()
-                # If stream, follow the given stream_interval
-                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                # always increase one-by-one.
-                or (
-                    not req.stream
-                    and len(req.output_ids) % 50 == 0
-                    and not self.model_config.is_multimodal_gen
+            if req.finished():
+                if req.finished_output:
+                    # With the overlap schedule, a request will try to output twice and hit this line twice
+                    # because of the one additional delayed token. This "continue" prevented the dummy output.
+                    continue
+                req.finished_output = True
+                should_output = True
+            else:
+                if req.stream:
+                    stream_interval = (
+                        req.sampling_params.stream_interval or self.stream_interval
+                    )
+                    should_output = len(req.output_ids) % stream_interval == 0
+                else:
+                    should_output = (
+                        len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
+                        and not self.model_config.is_multimodal_gen
+                    )
+
+            if should_output:
+                send_token_offset = req.send_token_offset
+                send_output_token_logprobs_offset = (
+                    req.send_output_token_logprobs_offset
                 )
-            ):
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
                 )
                 decoded_texts.append(req.decoded_text)
                 decode_ids, read_offset = req.init_incremental_detokenize()
-                decode_ids_list.append(decode_ids)
+
+                if self.model_config.is_multimodal_gen:
+                    decode_ids_list.append(decode_ids)
+                else:
+                    decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
+
+                req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
                 if self.skip_tokenizer_init:
-                    output_ids.append(req.output_ids)
+                    output_ids.append(req.output_ids[send_token_offset:])
+                    req.send_token_offset = len(req.output_ids)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
@@ -548,36 +563,90 @@
                 spec_verify_ct.append(req.spec_verify_ct)

                 if return_logprob:
-                    input_token_logprobs_val.append(req.input_token_logprobs_val)
-                    input_token_logprobs_idx.append(req.input_token_logprobs_idx)
-                    output_token_logprobs_val.append(req.output_token_logprobs_val)
-                    output_token_logprobs_idx.append(req.output_token_logprobs_idx)
-                    input_top_logprobs_val.append(req.input_top_logprobs_val)
-                    input_top_logprobs_idx.append(req.input_top_logprobs_idx)
-                    output_top_logprobs_val.append(req.output_top_logprobs_val)
-                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                    input_token_ids_logprobs_val.append(
-                        req.input_token_ids_logprobs_val
-                    )
-                    input_token_ids_logprobs_idx.append(
-                        req.input_token_ids_logprobs_idx
-                    )
-                    output_token_ids_logprobs_val.append(
-                        req.output_token_ids_logprobs_val
-                    )
-                    output_token_ids_logprobs_idx.append(
-                        req.output_token_ids_logprobs_idx
-                    )
+                    if (
+                        req.return_logprob
+                        and not req.input_logprob_sent
+                        # Decode server does not send input logprobs
+                        and self.disaggregation_mode != DisaggregationMode.DECODE
+                    ):
+                        input_token_logprobs_val.append(req.input_token_logprobs_val)
+                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                        input_top_logprobs_val.append(req.input_top_logprobs_val)
+                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                        input_token_ids_logprobs_val.append(
+                            req.input_token_ids_logprobs_val
+                        )
+                        input_token_ids_logprobs_idx.append(
+                            req.input_token_ids_logprobs_idx
+                        )
+                        req.input_logprob_sent = True
+                    else:
+                        input_token_logprobs_val.append([])
+                        input_token_logprobs_idx.append([])
+                        input_top_logprobs_val.append([])
+                        input_top_logprobs_idx.append([])
+                        input_token_ids_logprobs_val.append([])
+                        input_token_ids_logprobs_idx.append([])
+
+                    if req.return_logprob:
+                        output_token_logprobs_val.append(
+                            req.output_token_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_logprobs_idx.append(
+                            req.output_token_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_top_logprobs_val.append(
+                            req.output_top_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_top_logprobs_idx.append(
+                            req.output_top_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_ids_logprobs_val.append(
+                            req.output_token_ids_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_ids_logprobs_idx.append(
+                            req.output_token_ids_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        req.send_output_token_logprobs_offset = len(
+                            req.output_token_logprobs_val
+                        )
+                    else:
+                        output_token_logprobs_val.append([])
+                        output_token_logprobs_idx.append([])
+                        output_top_logprobs_val.append([])
+                        output_top_logprobs_idx.append([])
+                        output_token_ids_logprobs_val.append([])
+                        output_token_ids_logprobs_idx.append([])

                 if req.return_hidden_states:
                     if output_hidden_states is None:
                         output_hidden_states = []
                     output_hidden_states.append(req.hidden_states)

+            if (
+                req.finished()
+                and self.tp_rank == 0
+                and self.server_args.enable_request_time_stats_logging
+            ):
+                req.log_time_stats()
+
         # Send to detokenizer
         if rids:
             if self.model_config.is_multimodal_gen:
                 return
+
             self.send_to_detokenizer.send_pyobj(
                 BatchTokenIDOut(
                     rids,
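
The hunk above replaces the old inline streaming condition with an explicit should_output decision: a finished request is emitted exactly once (the new finished_output flag guards against the extra delayed token under the overlap schedule), streaming requests honor a per-request sampling_params.stream_interval that falls back to the server-wide stream interval, and non-streaming requests are still flushed every DEFAULT_FORCE_STREAM_INTERVAL (50) tokens for incremental decoding. A minimal standalone sketch of that decision, using simplified stand-in arguments rather than the actual Req object:

from typing import Optional

DEFAULT_FORCE_STREAM_INTERVAL = 50  # mirrors the constant introduced above


def should_output(
    finished: bool,
    already_emitted_final: bool,
    stream: bool,
    num_output_ids: int,
    request_stream_interval: Optional[int],
    server_stream_interval: int,
    is_multimodal_gen: bool,
) -> bool:
    """Distilled form of the gating logic in the hunk above (sketch only)."""
    if finished:
        # Under the overlap schedule a finished request is processed twice;
        # only the first pass should produce output.
        return not already_emitted_final
    if stream:
        # Per-request interval wins; otherwise use the server-wide default.
        interval = request_stream_interval or server_stream_interval
        return num_output_ids % interval == 0
    # Non-streaming requests are force-flushed periodically, except for
    # multimodal generation models.
    return (
        num_output_ids % DEFAULT_FORCE_STREAM_INTERVAL == 0
        and not is_multimodal_gen
    )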
sglang/srt/managers/tokenizer_manager.py
@@ -125,10 +125,10 @@ logger = logging.getLogger(__name__)
 class ReqState:
     """Store the state a request."""

-    out_list: List
+    out_list: List[Dict[Any, Any]]
     finished: bool
     event: asyncio.Event
-    obj: Any
+    obj: Union[GenerateReqInput, EmbeddingReqInput]

     # For metrics
     created_time: float
@@ -139,6 +139,21 @@

     # For streaming output
     last_output_offset: int = 0
+    # For incremental state update.
+    text: str = ""
+    output_ids: List[int] = dataclasses.field(default_factory=list)
+    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


 class TokenizerManager:
@@ -288,6 +303,7 @@
                     ),
                     self._handle_batch_output,
                 ),
+                (AbortReq, self._handle_abort_req),
                 (OpenSessionReqOutput, self._handle_open_session_req_output),
                 (
                     UpdateWeightFromDiskReqOutput,
@@ -341,13 +357,14 @@
             ]
         )

+        # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        # for disaggregtion, start kv boostrap server on prefill
+        # Start kv boostrap server on prefill
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -482,6 +499,14 @@
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logits-processor` to enable this feature."
+            )

         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
@@ -570,9 +595,9 @@
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
-        self.send_to_scheduler.send_pyobj(tokenized_obj)

     async def _wait_one_response(
         self,
@@ -587,10 +612,11 @@
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                     )
                 continue

@@ -605,7 +631,6 @@
                     else:
                         msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                     logger.info(msg)
-                del self.rid_to_state[obj.rid]

                 # Check if this was an abort/error created by scheduler
                 if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -625,10 +650,11 @@
                 yield out
             else:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                     )

     async def _handle_batch_request(
@@ -728,7 +754,6 @@
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
-        del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)

@@ -737,12 +762,16 @@
         output_dir: Optional[str] = None,
         num_steps: Optional[int] = None,
         activities: Optional[List[str]] = None,
+        with_stack: Optional[bool] = None,
+        record_shapes: Optional[bool] = None,
     ):
         req = ProfileReq(
             type=ProfileReqType.START_PROFILE,
             output_dir=output_dir,
             num_steps=num_steps,
             activities=activities,
+            with_stack=with_stack,
+            record_shapes=record_shapes,
             profile_id=str(time.time()),
         )
         result = (await self.start_profile_communicator(req))[0]
@@ -909,12 +938,13 @@
     ):
         await self.send_to_scheduler.send_pyobj(obj)

-    async def get_internal_state(self) -> Dict[Any, Any]:
+    async def get_internal_state(self) -> List[Dict[Any, Any]]:
         req = GetInternalStateReq()
-        res: List[GetInternalStateReqOutput] = (
+        responses: List[GetInternalStateReqOutput] = (
             await self.get_internal_state_communicator(req)
         )
-        return res[0].internal_state
+        # Many DP ranks
+        return [res.internal_state for res in responses]

     def get_log_request_metadata(self):
         max_length = None
@@ -964,7 +994,7 @@
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(1)
+            await asyncio.sleep(2)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -1035,6 +1065,9 @@
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
+                logger.error(
+                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
+                )
                 continue

             # Build meta_info and return value
@@ -1047,9 +1080,11 @@
             if getattr(state.obj, "return_logprob", False):
                 self.convert_logprob_style(
                     meta_info,
+                    state,
                     state.obj.top_logprobs_num,
                     state.obj.token_ids_logprob,
-                    state.obj.return_text_in_logprobs,
+                    state.obj.return_text_in_logprobs
+                    and not self.server_args.skip_tokenizer_init,
                     recv_obj,
                     i,
                 )
@@ -1066,18 +1101,19 @@
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

             if isinstance(recv_obj, BatchStrOut):
+                state.text += recv_obj.output_strs[i]
                 out_dict = {
-                    "text": recv_obj.output_strs[i],
+                    "text": state.text,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 if self.server_args.stream_output and state.obj.stream:
-                    output_token_ids = recv_obj.output_ids[i][
-                        state.last_output_offset :
-                    ]
-                    state.last_output_offset = len(recv_obj.output_ids[i])
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids[state.last_output_offset :]
+                    state.last_output_offset = len(state.output_ids)
                 else:
-                    output_token_ids = recv_obj.output_ids[i]
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids

                 out_dict = {
                     "output_ids": output_token_ids,
@@ -1098,6 +1134,7 @@
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+                del self.rid_to_state[rid]

             state.out_list.append(out_dict)
             state.event.set()
@@ -1111,45 +1148,85 @@
     def convert_logprob_style(
         self,
         meta_info: dict,
+        state: ReqState,
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if len(recv_obj.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                recv_obj.input_token_logprobs_val[recv_obj_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                recv_obj.input_token_logprobs_idx[recv_obj_index]
+            )
+            state.output_token_logprobs_val.extend(
+                recv_obj.output_token_logprobs_val[recv_obj_index]
+            )
+            state.output_token_logprobs_idx.extend(
+                recv_obj.output_token_logprobs_idx[recv_obj_index]
+            )
         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.input_token_logprobs_val[recv_obj_index],
-            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
             return_text_in_logprobs,
         )
         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-            recv_obj.output_token_logprobs_val[recv_obj_index],
-            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
             return_text_in_logprobs,
         )

         if top_logprobs_num > 0:
+            if len(recv_obj.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    recv_obj.input_top_logprobs_val[recv_obj_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    recv_obj.input_top_logprobs_idx[recv_obj_index]
+                )
+                state.output_top_logprobs_val.extend(
+                    recv_obj.output_top_logprobs_val[recv_obj_index]
+                )
+                state.output_top_logprobs_idx.extend(
+                    recv_obj.output_top_logprobs_idx[recv_obj_index]
+                )
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_top_logprobs_val[recv_obj_index],
-                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.output_top_logprobs_val[recv_obj_index],
-                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
                 return_text_in_logprobs,
             )

         if token_ids_logprob is not None:
+            if len(recv_obj.input_token_ids_logprobs_val) > 0:
+                state.input_token_ids_logprobs_val.extend(
+                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.input_token_ids_logprobs_idx.extend(
+                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
+                )
+                state.output_token_ids_logprobs_val.extend(
+                    recv_obj.output_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.output_token_ids_logprobs_idx.extend(
+                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
+                )
             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
-                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_token_ids_logprobs"] = (
                 self.detokenize_top_logprobs_tokens(
-                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
-                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
                     return_text_in_logprobs,
                 )
             )
@@ -1216,11 +1293,18 @@
         state.last_completion_tokens = completion_tokens

         if state.finished:
+            has_grammar = (
+                state.obj.sampling_params.get("json_schema", None)
+                or state.obj.sampling_params.get("regex", None)
+                or state.obj.sampling_params.get("ebnf", None)
+                or state.obj.sampling_params.get("structural_tag", None)
+            )
             self.metrics_collector.observe_one_finished_request(
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
                 state.finished_time - state.created_time,
+                has_grammar,
             )

     def dump_requests(self, state: ReqState, out_dict: dict):
@@ -1246,6 +1330,9 @@
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def _handle_abort_req(self, recv_obj):
+        self.rid_to_state.pop(recv_obj.rid)
+
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
             recv_obj.session_id if recv_obj.success else None
@@ -1256,7 +1343,7 @@
             self.model_update_result.set_result(recv_obj)
         else:  # self.server_args.dp_size > 1
             self.model_update_tmp.append(recv_obj)
-            # set future if the all results are recevied
+            # set future if the all results are received
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)

@@ -1325,3 +1412,15 @@
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
+
+
+# Note: request abort handling logic
+# We should handle all of the following cases correctly.
+#
+# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
+# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
+# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
+# | http | yes | running | background task | fast api | del in _handle_batch_output |
+# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
+# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
+#