sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to their public registries. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+import logging
 import threading
+import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
@@ -15,6 +18,10 @@ if TYPE_CHECKING:
         Scheduler,
     )
 
+logger = logging.getLogger(__name__)
+
+DEFAULT_FORCE_STREAM_INTERVAL = 50
+
 
 class SchedulerOutputProcessorMixin:
     """
@@ -36,20 +43,16 @@
                 next_token_ids,
                 extend_input_len_per_req,
                 extend_logprob_start_len_per_req,
-                bid,
             ) = (
                 result.logits_output,
                 result.next_token_ids,
                 result.extend_input_len_per_req,
                 result.extend_logprob_start_len_per_req,
-                result.bid,
             )
 
             if self.enable_overlap:
-                logits_output, next_token_ids = (
-                    self.tp_worker.resolve_last_batch_result(
-                        launch_done,
-                    )
+                logits_output, next_token_ids, _ = (
+                    self.tp_worker.resolve_last_batch_result(launch_done)
                 )
             else:
                 # Move next_token_ids and logprobs to cpu
@@ -85,6 +88,7 @@
 
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
+                    req.time_stats.completion_time = time.time()
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     # This updates radix so others can match
                     self.tree_cache.cache_unfinished_req(req)
@@ -151,10 +155,7 @@
                     )
                     logprob_pt += num_input_logprobs
 
-
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)
 
         else:  # embedding or reward model
             embeddings, bid = result.embeddings, result.bid
@@ -187,16 +188,16 @@
         result: GenerationBatchResult,
         launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, bid = (
+        logits_output, next_token_ids, can_run_cuda_graph = (
             result.logits_output,
             result.next_token_ids,
-            result.bid,
+            result.can_run_cuda_graph,
         )
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
-                launch_done
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
             )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
@@ -235,6 +236,7 @@
             req.check_finished()
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                req.time_stats.completion_time = time.time()
 
             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
@@ -264,13 +266,8 @@
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()
 
-
-        batch.next_batch_sampling_info.update_regex_vocab_mask()
-        self.current_stream.synchronize()
-        batch.next_batch_sampling_info.sampling_info_done.set()
-
+        self.set_next_batch_sampling_info_done(batch)
         self.stream_output(batch.reqs, batch.return_logprob)
-
         self.token_to_kv_pool_allocator.free_group_end()
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
@@ -278,7 +275,7 @@
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats(running_batch=batch)
+            self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
 
     def add_input_logprob_return_values(
         self: Scheduler,
@@ -512,29 +509,47 @@
             if self.model_config.is_multimodal_gen and req.to_abort:
                 continue
 
-            if (
-                req.
-
-
-
-
-
-
-
-
-
+            if req.finished():
+                if req.finished_output:
+                    # With the overlap schedule, a request will try to output twice and hit this line twice
+                    # because of the one additional delayed token. This "continue" prevented the dummy output.
+                    continue
+                req.finished_output = True
+                should_output = True
+            else:
+                if req.stream:
+                    stream_interval = (
+                        req.sampling_params.stream_interval or self.stream_interval
+                    )
+                    should_output = len(req.output_ids) % stream_interval == 0
+                else:
+                    should_output = (
+                        len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
+                        and not self.model_config.is_multimodal_gen
+                    )
+
+            if should_output:
+                send_token_offset = req.send_token_offset
+                send_output_token_logprobs_offset = (
+                    req.send_output_token_logprobs_offset
                 )
-            ):
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
                 )
                 decoded_texts.append(req.decoded_text)
                 decode_ids, read_offset = req.init_incremental_detokenize()
-
+
+                if self.model_config.is_multimodal_gen:
+                    decode_ids_list.append(decode_ids)
+                else:
+                    decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
+
+                req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
                 if self.skip_tokenizer_init:
-                    output_ids.append(req.output_ids)
+                    output_ids.append(req.output_ids[send_token_offset:])
+                    req.send_token_offset = len(req.output_ids)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
@@ -548,36 +563,90 @@
                     spec_verify_ct.append(req.spec_verify_ct)
 
                 if return_logprob:
-
-
-
-
-
-
-
-
-
-                        req.
-
-
-
-
-
-
-
-
-
-
+                    if (
+                        req.return_logprob
+                        and not req.input_logprob_sent
+                        # Decode server does not send input logprobs
+                        and self.disaggregation_mode != DisaggregationMode.DECODE
+                    ):
+                        input_token_logprobs_val.append(req.input_token_logprobs_val)
+                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                        input_top_logprobs_val.append(req.input_top_logprobs_val)
+                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                        input_token_ids_logprobs_val.append(
+                            req.input_token_ids_logprobs_val
+                        )
+                        input_token_ids_logprobs_idx.append(
+                            req.input_token_ids_logprobs_idx
+                        )
+                        req.input_logprob_sent = True
+                    else:
+                        input_token_logprobs_val.append([])
+                        input_token_logprobs_idx.append([])
+                        input_top_logprobs_val.append([])
+                        input_top_logprobs_idx.append([])
+                        input_token_ids_logprobs_val.append([])
+                        input_token_ids_logprobs_idx.append([])
+
+                    if req.return_logprob:
+                        output_token_logprobs_val.append(
+                            req.output_token_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_logprobs_idx.append(
+                            req.output_token_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_top_logprobs_val.append(
+                            req.output_top_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_top_logprobs_idx.append(
+                            req.output_top_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_ids_logprobs_val.append(
+                            req.output_token_ids_logprobs_val[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        output_token_ids_logprobs_idx.append(
+                            req.output_token_ids_logprobs_idx[
+                                send_output_token_logprobs_offset:
+                            ]
+                        )
+                        req.send_output_token_logprobs_offset = len(
+                            req.output_token_logprobs_val
+                        )
+                    else:
+                        output_token_logprobs_val.append([])
+                        output_token_logprobs_idx.append([])
+                        output_top_logprobs_val.append([])
+                        output_top_logprobs_idx.append([])
+                        output_token_ids_logprobs_val.append([])
+                        output_token_ids_logprobs_idx.append([])
 
                 if req.return_hidden_states:
                     if output_hidden_states is None:
                         output_hidden_states = []
                     output_hidden_states.append(req.hidden_states)
 
+            if (
+                req.finished()
+                and self.tp_rank == 0
+                and self.server_args.enable_request_time_stats_logging
+            ):
+                req.log_time_stats()
+
         # Send to detokenizer
         if rids:
             if self.model_config.is_multimodal_gen:
                 return
+
             self.send_to_detokenizer.send_pyobj(
                 BatchTokenIDOut(
                     rids,
sglang/srt/managers/tokenizer_manager.py

@@ -125,10 +125,10 @@ logger = logging.getLogger(__name__)
 class ReqState:
     """Store the state a request."""
 
-    out_list: List
+    out_list: List[Dict[Any, Any]]
     finished: bool
     event: asyncio.Event
-    obj:
+    obj: Union[GenerateReqInput, EmbeddingReqInput]
 
     # For metrics
     created_time: float
@@ -139,6 +139,21 @@ class ReqState:
 
     # For streaming output
    last_output_offset: int = 0
+    # For incremental state update.
+    text: str = ""
+    output_ids: List[int] = dataclasses.field(default_factory=list)
+    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
 class TokenizerManager:
@@ -288,6 +303,7 @@ class TokenizerManager:
                 ),
                 self._handle_batch_output,
             ),
+            (AbortReq, self._handle_abort_req),
             (OpenSessionReqOutput, self._handle_open_session_req_output),
             (
                 UpdateWeightFromDiskReqOutput,
@@ -341,13 +357,14 @@ class TokenizerManager:
             ]
         )
 
+        # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        #
+        # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -482,6 +499,14 @@ class TokenizerManager:
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logits-processor` to enable this feature."
+            )
 
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
@@ -570,9 +595,9 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
-        self.send_to_scheduler.send_pyobj(tokenized_obj)
 
     async def _wait_one_response(
         self,
@@ -587,10 +612,11 @@ class TokenizerManager:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                     )
                 continue
 
@@ -605,7 +631,6 @@ class TokenizerManager:
             else:
                 msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
             logger.info(msg)
-            del self.rid_to_state[obj.rid]
 
             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -625,10 +650,11 @@ class TokenizerManager:
                 yield out
             else:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, running)
                    self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                     )
 
     async def _handle_batch_request(
@@ -728,7 +754,6 @@ class TokenizerManager:
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
-        del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)
 
@@ -737,12 +762,16 @@ class TokenizerManager:
         output_dir: Optional[str] = None,
         num_steps: Optional[int] = None,
         activities: Optional[List[str]] = None,
+        with_stack: Optional[bool] = None,
+        record_shapes: Optional[bool] = None,
     ):
         req = ProfileReq(
             type=ProfileReqType.START_PROFILE,
             output_dir=output_dir,
             num_steps=num_steps,
             activities=activities,
+            with_stack=with_stack,
+            record_shapes=record_shapes,
             profile_id=str(time.time()),
         )
         result = (await self.start_profile_communicator(req))[0]
@@ -909,12 +938,13 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)
 
-    async def get_internal_state(self) -> Dict[Any, Any]:
+    async def get_internal_state(self) -> List[Dict[Any, Any]]:
         req = GetInternalStateReq()
-
+        responses: List[GetInternalStateReqOutput] = (
             await self.get_internal_state_communicator(req)
         )
-
+        # Many DP ranks
+        return [res.internal_state for res in responses]
 
     def get_log_request_metadata(self):
         max_length = None
@@ -964,7 +994,7 @@ class TokenizerManager:
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(
+            await asyncio.sleep(2)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -1035,6 +1065,9 @@ class TokenizerManager:
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
+                logger.error(
+                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
+                )
                 continue
 
             # Build meta_info and return value
@@ -1047,9 +1080,11 @@ class TokenizerManager:
             if getattr(state.obj, "return_logprob", False):
                 self.convert_logprob_style(
                     meta_info,
+                    state,
                     state.obj.top_logprobs_num,
                     state.obj.token_ids_logprob,
-                    state.obj.return_text_in_logprobs,
+                    state.obj.return_text_in_logprobs
+                    and not self.server_args.skip_tokenizer_init,
                     recv_obj,
                     i,
                 )
@@ -1066,18 +1101,19 @@ class TokenizerManager:
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
 
             if isinstance(recv_obj, BatchStrOut):
+                state.text += recv_obj.output_strs[i]
                 out_dict = {
-                    "text":
+                    "text": state.text,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 if self.server_args.stream_output and state.obj.stream:
-
-
-
-                    state.last_output_offset = len(recv_obj.output_ids[i])
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids[state.last_output_offset :]
+                    state.last_output_offset = len(state.output_ids)
                 else:
-
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids
 
                 out_dict = {
                     "output_ids": output_token_ids,
@@ -1098,6 +1134,7 @@ class TokenizerManager:
                     meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+                del self.rid_to_state[rid]
 
             state.out_list.append(out_dict)
             state.event.set()
@@ -1111,45 +1148,85 @@ class TokenizerManager:
     def convert_logprob_style(
         self,
         meta_info: dict,
+        state: ReqState,
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if len(recv_obj.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                recv_obj.input_token_logprobs_val[recv_obj_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                recv_obj.input_token_logprobs_idx[recv_obj_index]
+            )
+            state.output_token_logprobs_val.extend(
+                recv_obj.output_token_logprobs_val[recv_obj_index]
+            )
+            state.output_token_logprobs_idx.extend(
+                recv_obj.output_token_logprobs_idx[recv_obj_index]
+            )
         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-
-
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
             return_text_in_logprobs,
         )
         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-
-
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
             return_text_in_logprobs,
         )
 
         if top_logprobs_num > 0:
+            if len(recv_obj.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    recv_obj.input_top_logprobs_val[recv_obj_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    recv_obj.input_top_logprobs_idx[recv_obj_index]
+                )
+                state.output_top_logprobs_val.extend(
+                    recv_obj.output_top_logprobs_val[recv_obj_index]
+                )
+                state.output_top_logprobs_idx.extend(
+                    recv_obj.output_top_logprobs_idx[recv_obj_index]
+                )
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-
-
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-
-
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
                 return_text_in_logprobs,
             )
 
         if token_ids_logprob is not None:
+            if len(recv_obj.input_token_ids_logprobs_val) > 0:
+                state.input_token_ids_logprobs_val.extend(
+                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.input_token_ids_logprobs_idx.extend(
+                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
+                )
+                state.output_token_ids_logprobs_val.extend(
+                    recv_obj.output_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.output_token_ids_logprobs_idx.extend(
+                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
+                )
             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-
-
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_token_ids_logprobs"] = (
                 self.detokenize_top_logprobs_tokens(
-
-
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
                     return_text_in_logprobs,
                 )
             )
@@ -1216,11 +1293,18 @@ class TokenizerManager:
             state.last_completion_tokens = completion_tokens
 
         if state.finished:
+            has_grammar = (
+                state.obj.sampling_params.get("json_schema", None)
+                or state.obj.sampling_params.get("regex", None)
+                or state.obj.sampling_params.get("ebnf", None)
+                or state.obj.sampling_params.get("structural_tag", None)
+            )
             self.metrics_collector.observe_one_finished_request(
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
                 state.finished_time - state.created_time,
+                has_grammar,
             )
 
     def dump_requests(self, state: ReqState, out_dict: dict):
@@ -1246,6 +1330,9 @@ class TokenizerManager:
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))
 
+    def _handle_abort_req(self, recv_obj):
+        self.rid_to_state.pop(recv_obj.rid)
+
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
             recv_obj.session_id if recv_obj.success else None
@@ -1256,7 +1343,7 @@ class TokenizerManager:
             self.model_update_result.set_result(recv_obj)
         else:  # self.server_args.dp_size > 1
             self.model_update_tmp.append(recv_obj)
-            # set future if the all results are
+            # set future if the all results are received
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)
 
@@ -1325,3 +1412,15 @@ class _Communicator(Generic[T]):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
+
+
+# Note: request abort handling logic
+# We should handle all of the following cases correctly.
+#
+# | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
+# | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
+# | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
+# | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
+#