sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py CHANGED

@@ -54,7 +54,11 @@ from sglang.srt.disaggregation.utils import (
     TransferBackend,
     get_kv_class,
 )
-from sglang.srt.hf_transformers_utils import
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -86,6 +90,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -119,10 +125,10 @@ logger = logging.getLogger(__name__)
 class ReqState:
     """Store the state a request."""

-    out_list: List
+    out_list: List[Dict[Any, Any]]
     finished: bool
     event: asyncio.Event
-    obj:
+    obj: Union[GenerateReqInput, EmbeddingReqInput]

     # For metrics
     created_time: float
@@ -133,6 +139,21 @@ class ReqState:

     # For streaming output
     last_output_offset: int = 0
+    # For incremental state update.
+    text: str = ""
+    output_ids: List[int] = dataclasses.field(default_factory=list)
+    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
+    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
+    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
+    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
+    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


 class TokenizerManager:
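The new `ReqState` fields above turn the per-request record into a cumulative accumulator: every `BatchStrOut` / `BatchTokenIDOut` chunk that arrives from the detokenizer is appended to the state, so the tokenizer manager can always hand back the full text or token-id list so far, while streaming callers only get the delta past `last_output_offset`. A minimal, self-contained sketch of that accumulation idea (illustrative names only, not the real sglang classes):

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class Chunk:
    """One incremental piece of output, as a detokenizer might emit it (illustrative)."""
    text: str
    output_ids: List[int]


@dataclasses.dataclass
class ReqStateSketch:
    """Cumulative per-request state, mirroring the accumulator idea above."""
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    last_output_offset: int = 0  # how much has already been streamed to the client

    def apply(self, chunk: Chunk, stream: bool) -> dict:
        # Accumulate the chunk into the per-request state.
        self.text += chunk.text
        self.output_ids.extend(chunk.output_ids)
        if stream:
            # Streaming callers only want the token ids since the last chunk.
            delta = self.output_ids[self.last_output_offset:]
            self.last_output_offset = len(self.output_ids)
            return {"text": self.text, "output_ids": delta}
        # Non-streaming callers get the cumulative result.
        return {"text": self.text, "output_ids": list(self.output_ids)}


state = ReqStateSketch()
print(state.apply(Chunk("Hel", [1, 2]), stream=True))  # delta ids [1, 2]
print(state.apply(Chunk("lo", [3]), stream=True))       # delta ids [3]
```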
@@ -161,17 +182,7 @@ class TokenizerManager:
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)

         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
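The repeated `ModelConfig(...)` construction is replaced by a `ModelConfig.from_server_args(...)` factory; the same refactor appears in `tp_worker.py` below, where the call site also passes `model_path=` and `is_draft_model=` overrides. A hedged sketch of what such a classmethod factory could look like, with the field names taken from the removed constructor call above (the real signature in sglang may differ):

```python
class ModelConfig:
    def __init__(self, model_path, trust_remote_code=False, revision=None,
                 context_length=None, model_override_args=None, is_embedding=False,
                 enable_multimodal=None, dtype="auto", quantization=None,
                 is_draft_model=False):
        self.model_path = model_path
        self.trust_remote_code = trust_remote_code
        # ... remaining fields stored the same way (elided in this sketch) ...

    @classmethod
    def from_server_args(cls, server_args, model_path=None, **kwargs):
        # Pull the usual fields from the server arguments once, and let
        # call sites override individual values (e.g. the draft model path).
        return cls(
            model_path=model_path or server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            **kwargs,
        )
```

Centralizing the mapping in one factory keeps the tokenizer manager, TP worker, and draft worker from drifting apart when a new model argument is added.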
@@ -199,7 +210,7 @@ class TokenizerManager:
                 self.tokenizer = self.processor = None
             else:
                 self.processor = _processor
-                self.tokenizer = self.processor
+                self.tokenizer = get_tokenizer_from_processor(self.processor)
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
         else:
             self.mm_processor = get_dummy_processor()
@@ -265,6 +276,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.slow_down_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -289,6 +303,7 @@ class TokenizerManager:
                     ),
                     self._handle_batch_output,
                 ),
+                (AbortReq, self._handle_abort_req),
                 (OpenSessionReqOutput, self._handle_open_session_req_output),
                 (
                     UpdateWeightFromDiskReqOutput,
@@ -318,6 +333,10 @@ class TokenizerManager:
                     ResumeMemoryOccupationReqOutput,
                     self.resume_memory_occupation_communicator.handle_recv,
                 ),
+                (
+                    SlowDownReqOutput,
+                    self.slow_down_communicator.handle_recv,
+                ),
                 (
                     FlushCacheReqOutput,
                     self.flush_cache_communicator.handle_recv,
@@ -338,13 +357,14 @@ class TokenizerManager:
             ]
         )

+        # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
         self.transfer_backend = TransferBackend(
             self.server_args.disaggregation_transfer_backend
         )
-        #
+        # Start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
             kv_bootstrap_server_class = get_kv_class(
@@ -479,6 +499,14 @@ class TokenizerManager:
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
+        if (
+            obj.custom_logit_processor
+            and not self.server_args.enable_custom_logit_processor
+        ):
+            raise ValueError(
+                "The server is not configured to enable custom logit processor. "
+                "Please set `--enable-custom-logits-processor` to enable this feature."
+            )

         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
@@ -567,9 +595,9 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
-        self.send_to_scheduler.send_pyobj(tokenized_obj)

     async def _wait_one_response(
         self,
@@ -584,10 +612,11 @@ class TokenizerManager:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                     )
                 continue

@@ -602,7 +631,6 @@ class TokenizerManager:
                     else:
                         msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                     logger.info(msg)
-                del self.rid_to_state[obj.rid]

                 # Check if this was an abort/error created by scheduler
                 if isinstance(out["meta_info"].get("finish_reason"), dict):
@@ -622,10 +650,11 @@ class TokenizerManager:
                 yield out
             else:
                 if request is not None and await request.is_disconnected():
+                    # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
+                    # Use exception to kill the whole call stack and asyncio task
                     raise ValueError(
-                        "Request is disconnected from the client side. "
-                        f"Abort request {obj.rid}"
+                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                     )

     async def _handle_batch_request(
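In the non-streaming wait path the manager now distinguishes a disconnect detected while the request is still queued ("type 1", caught on the 4-second event timeout) from one detected after an output chunk arrives ("type 3"); both abort the request on the scheduler and then raise to tear down the asyncio task (see the abort-handling table at the end of this file's diff). A standalone sketch of that wait-with-timeout-and-poll-disconnect loop, with a hypothetical `is_disconnected()` coroutine standing in for the FastAPI request object:

```python
import asyncio


async def wait_one_response(event: asyncio.Event, results: list,
                            is_disconnected, abort, rid: str):
    """Wait for the next result chunk, polling for client disconnects (sketch)."""
    while True:
        try:
            # Wake up every few seconds even if no output arrived yet,
            # so a vanished client does not leak a request forever.
            await asyncio.wait_for(event.wait(), timeout=4)
        except asyncio.TimeoutError:
            if await is_disconnected():
                abort(rid)  # request still waiting in the queue ("type 1")
                raise ValueError(f"Request is disconnected from the client side. Abort request {rid=}")
            continue
        event.clear()
        out = results[-1]
        if out.get("finished"):
            return out
        if await is_disconnected():
            abort(rid)  # request already running ("type 3")
            raise ValueError(f"Request is disconnected from the client side. Abort request {rid=}")
```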
@@ -725,7 +754,6 @@ class TokenizerManager:
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
-        del self.rid_to_state[rid]
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)

@@ -734,12 +762,16 @@ class TokenizerManager:
         output_dir: Optional[str] = None,
         num_steps: Optional[int] = None,
         activities: Optional[List[str]] = None,
+        with_stack: Optional[bool] = None,
+        record_shapes: Optional[bool] = None,
     ):
         req = ProfileReq(
             type=ProfileReqType.START_PROFILE,
             output_dir=output_dir,
             num_steps=num_steps,
             activities=activities,
+            with_stack=with_stack,
+            record_shapes=record_shapes,
             profile_id=str(time.time()),
         )
         result = (await self.start_profile_communicator(req))[0]
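`with_stack` and `record_shapes` are carried through the `ProfileReq` so the scheduler side can forward them to the profiler; both names match standard `torch.profiler.profile` arguments. A hedged sketch of how such flags typically reach a profiler session (generic `torch.profiler` usage, not sglang's scheduler code):

```python
import torch
from torch.profiler import ProfilerActivity, profile


def run_profile(step_fn, num_steps: int, output_dir: str,
                with_stack: bool = False, record_shapes: bool = False):
    """Profile a few steps and dump a Chrome trace (generic torch.profiler sketch)."""
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(
        activities=activities,
        with_stack=with_stack,        # capture Python/C++ stacks (larger traces)
        record_shapes=record_shapes,  # record operator input shapes
    ) as prof:
        for _ in range(num_steps):
            step_fn()
    prof.export_chrome_trace(f"{output_dir}/trace.json")


# Example: profile a tiny matmul loop.
# run_profile(lambda: torch.randn(64, 64) @ torch.randn(64, 64),
#             num_steps=3, output_dir=".", with_stack=True, record_shapes=True)
```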
@@ -876,6 +908,14 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         await self.resume_memory_occupation_communicator(obj)

+    async def slow_down(
+        self,
+        obj: SlowDownReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.slow_down_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -898,12 +938,13 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

-    async def get_internal_state(self) -> Dict[Any, Any]:
+    async def get_internal_state(self) -> List[Dict[Any, Any]]:
         req = GetInternalStateReq()
-
+        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
         )
-
+        # Many DP ranks
+        return [res.internal_state for res in responses]

     def get_log_request_metadata(self):
         max_length = None
@@ -953,7 +994,7 @@ class TokenizerManager:
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
-            await asyncio.sleep(
+            await asyncio.sleep(2)
             if obj.is_single:
                 self.abort_request(obj.rid)
             else:
@@ -1024,6 +1065,9 @@ class TokenizerManager:
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
+                logger.error(
+                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
+                )
                 continue

             # Build meta_info and return value
@@ -1036,9 +1080,11 @@ class TokenizerManager:
             if getattr(state.obj, "return_logprob", False):
                 self.convert_logprob_style(
                     meta_info,
+                    state,
                     state.obj.top_logprobs_num,
                     state.obj.token_ids_logprob,
-                    state.obj.return_text_in_logprobs
+                    state.obj.return_text_in_logprobs
+                    and not self.server_args.skip_tokenizer_init,
                     recv_obj,
                     i,
                 )
@@ -1055,18 +1101,19 @@ class TokenizerManager:
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

             if isinstance(recv_obj, BatchStrOut):
+                state.text += recv_obj.output_strs[i]
                 out_dict = {
-                    "text":
+                    "text": state.text,
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 if self.server_args.stream_output and state.obj.stream:
-
-
-
-                    state.last_output_offset = len(recv_obj.output_ids[i])
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids[state.last_output_offset :]
+                    state.last_output_offset = len(state.output_ids)
                 else:
-
+                    state.output_ids.extend(recv_obj.output_ids[i])
+                    output_token_ids = state.output_ids

                 out_dict = {
                     "output_ids": output_token_ids,
@@ -1087,6 +1134,7 @@ class TokenizerManager:
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                 state.finished_time = time.time()
                 meta_info["e2e_latency"] = state.finished_time - state.created_time
+                del self.rid_to_state[rid]

             state.out_list.append(out_dict)
             state.event.set()
@@ -1100,45 +1148,85 @@ class TokenizerManager:
     def convert_logprob_style(
         self,
         meta_info: dict,
+        state: ReqState,
         top_logprobs_num: int,
         token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if len(recv_obj.input_token_logprobs_val) > 0:
+            state.input_token_logprobs_val.extend(
+                recv_obj.input_token_logprobs_val[recv_obj_index]
+            )
+            state.input_token_logprobs_idx.extend(
+                recv_obj.input_token_logprobs_idx[recv_obj_index]
+            )
+        state.output_token_logprobs_val.extend(
+            recv_obj.output_token_logprobs_val[recv_obj_index]
+        )
+        state.output_token_logprobs_idx.extend(
+            recv_obj.output_token_logprobs_idx[recv_obj_index]
+        )
         meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-
-
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
             return_text_in_logprobs,
         )
         meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-
-
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
             return_text_in_logprobs,
         )

         if top_logprobs_num > 0:
+            if len(recv_obj.input_top_logprobs_val) > 0:
+                state.input_top_logprobs_val.extend(
+                    recv_obj.input_top_logprobs_val[recv_obj_index]
+                )
+                state.input_top_logprobs_idx.extend(
+                    recv_obj.input_top_logprobs_idx[recv_obj_index]
+                )
+            state.output_top_logprobs_val.extend(
+                recv_obj.output_top_logprobs_val[recv_obj_index]
+            )
+            state.output_top_logprobs_idx.extend(
+                recv_obj.output_top_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-
-
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-
-
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
                 return_text_in_logprobs,
             )

         if token_ids_logprob is not None:
+            if len(recv_obj.input_token_ids_logprobs_val) > 0:
+                state.input_token_ids_logprobs_val.extend(
+                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
+                )
+                state.input_token_ids_logprobs_idx.extend(
+                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
+                )
+            state.output_token_ids_logprobs_val.extend(
+                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
+            )
+            state.output_token_ids_logprobs_idx.extend(
+                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
+            )
             meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-
-
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
                 return_text_in_logprobs,
             )
             meta_info["output_token_ids_logprobs"] = (
                 self.detokenize_top_logprobs_tokens(
-
-
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
                     return_text_in_logprobs,
                 )
             )
@@ -1205,11 +1293,18 @@ class TokenizerManager:
             state.last_completion_tokens = completion_tokens

         if state.finished:
+            has_grammar = (
+                state.obj.sampling_params.get("json_schema", None)
+                or state.obj.sampling_params.get("regex", None)
+                or state.obj.sampling_params.get("ebnf", None)
+                or state.obj.sampling_params.get("structural_tag", None)
+            )
             self.metrics_collector.observe_one_finished_request(
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
                 state.finished_time - state.created_time,
+                has_grammar,
             )

     def dump_requests(self, state: ReqState, out_dict: dict):
@@ -1235,6 +1330,9 @@ class TokenizerManager:
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))

+    def _handle_abort_req(self, recv_obj):
+        self.rid_to_state.pop(recv_obj.rid)
+
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
             recv_obj.session_id if recv_obj.success else None
@@ -1245,7 +1343,7 @@ class TokenizerManager:
             self.model_update_result.set_result(recv_obj)
         else:  # self.server_args.dp_size > 1
             self.model_update_tmp.append(recv_obj)
-            # set future if the all results are
+            # set future if the all results are received
             if len(self.model_update_tmp) == self.server_args.dp_size:
                 self.model_update_result.set_result(self.model_update_tmp)

@@ -1314,3 +1412,15 @@ class _Communicator(Generic[T]):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
+
+
+# Note: request abort handling logic
+# We should handle all of the following cases correctly.
+#
+# | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
+# | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
+# | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
+# | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
+#
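The `_Communicator.handle_recv` tail shown above (append each reply, set the event once `_fan_out` replies have arrived) is the machinery behind the new `slow_down_communicator` and the now list-valued `get_internal_state`: one request is broadcast to the scheduler and the awaiting coroutine resumes only after every data-parallel rank has answered. A minimal, generic sketch of that fan-out request/response helper (illustrative, not sglang's exact class):

```python
import asyncio
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")


class FanOutCommunicator(Generic[T]):
    """Send one request, then wait until `fan_out` replies have been collected."""

    def __init__(self, send: Callable[[object], None], fan_out: int):
        self._send = send
        self._fan_out = fan_out          # e.g. the data-parallel world size
        self._result_values: List[T] = []
        self._result_event = asyncio.Event()

    async def __call__(self, req) -> List[T]:
        self._result_values = []
        self._result_event = asyncio.Event()
        self._send(req)                   # broadcast to the scheduler(s)
        await self._result_event.wait()   # resumes when the last rank replies
        return self._result_values

    def handle_recv(self, recv_obj: T):
        # Called from the output-handling loop once per rank's reply.
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()
```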
sglang/srt/managers/tp_worker.py CHANGED

@@ -20,8 +20,12 @@ from typing import Optional, Tuple, Union
 import torch

 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.distributed import get_pp_group,
-from sglang.srt.hf_transformers_utils import
+from sglang.srt.distributed import get_pp_group, get_world_group
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
@@ -61,20 +65,13 @@ class TpModelWorker:
         self.pp_rank = pp_rank

         # Init model and tokenizer
-        self.model_config = ModelConfig(
-
+        self.model_config = ModelConfig.from_server_args(
+            server_args,
+            model_path=(
                 server_args.model_path
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
             is_draft_model=is_draft_worker,
         )

@@ -102,7 +99,7 @@ class TpModelWorker:
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
-            self.tokenizer = self.processor
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
@@ -186,8 +183,11 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
+        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[
+    ) -> Tuple[
+        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
+    ]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         pp_proxy_tensors = None
@@ -199,11 +199,11 @@ class TpModelWorker:
         )

         if self.pp_group.is_last_rank:
-            logits_output = self.model_runner.forward(
+            logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if
-
+            if launch_done is not None:
+                launch_done.set()

             if skip_sample:
                 next_token_ids = None
@@ -212,17 +212,17 @@ class TpModelWorker:
                 logits_output, model_worker_batch
             )

-            return logits_output, next_token_ids
+            return logits_output, next_token_ids, can_run_cuda_graph
         else:
-            pp_proxy_tensors = self.model_runner.forward(
+            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None
+            return pp_proxy_tensors.tensors, None, can_run_cuda_graph

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
+        logits_output, _ = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings
         return embeddings
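`forward_batch_generation` now accepts an optional `launch_done` threading.Event and always returns a third `can_run_cuda_graph` flag; the overlap worker in the next file passes `model_worker_batch.launch_done` so the scheduler thread can block until the forward pass has actually been launched before resolving the previous batch. A small, simplified sketch of that producer/consumer handshake with `threading.Event` (illustrative names, not the real scheduler):

```python
import queue
import threading
import time

batches = queue.Queue()   # (batch, launch_done) pairs from the scheduler thread
results = queue.Queue()   # forward outputs back to the scheduler thread


def forward_worker():
    """Consumer thread: launch the forward pass, then signal launch_done."""
    while True:
        batch, launch_done = batches.get()
        if batch is None:
            return
        time.sleep(0.01)   # stand-in for launching the GPU forward pass
        launch_done.set()  # the kernels for this batch are now in flight
        results.put(f"logits({batch})")


threading.Thread(target=forward_worker, daemon=True).start()

launch_done = threading.Event()
batches.put(("batch-0", launch_done))

# Scheduler side: wait until the current batch has been launched before
# consuming the worker's outputs (mirrors resolve_last_batch_result).
launch_done.wait()
print(results.get())
batches.put((None, None))  # shut the worker down
```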
|
sglang/srt/managers/tp_worker_overlap_thread.py CHANGED

@@ -18,7 +18,7 @@ import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional
+from typing import Optional, Tuple

 import psutil
 import torch
@@ -127,10 +127,12 @@ class TpModelWorkerClient:
         batch_lists = [None] * 2

         while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break

+            sync_event.wait()
+
             # Keep a reference of model_worker_batch by storing it into a list.
             # Otherwise, the tensor members of model_worker_batch will be released
             # by pytorch and cause CUDA illegal memory access errors.
@@ -145,8 +147,10 @@ class TpModelWorkerClient:
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
-            logits_output, next_token_ids =
-
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.worker.forward_batch_generation(
+                    model_worker_batch, model_worker_batch.launch_done
+                )
             )

             # Update the future token ids map
@@ -171,14 +175,18 @@ class TpModelWorkerClient:
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()

-            self.output_queue.put(
+            self.output_queue.put(
+                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
+            )

     def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
         """
         This function is called to resolve the last batch result and
         wait for the current batch to be launched. Used in overlap mode.
         """
-        copy_done, logits_output, next_token_ids =
+        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
+            self.output_queue.get()
+        )

         if launch_done is not None:
             launch_done.wait()
@@ -193,9 +201,11 @@ class TpModelWorkerClient:
                 logits_output.input_token_logprobs.tolist()
             )
         next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids
+        return logits_output, next_token_ids, can_run_cuda_graph

-    def forward_batch_generation(
+    def forward_batch_generation(
+        self, model_worker_batch: ModelWorkerBatch
+    ) -> Tuple[None, torch.Tensor, bool]:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -206,10 +216,11 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        self.
+        sync_event = torch.get_device_module(self.device).Event()
+        sync_event.record(self.scheduler_stream)

         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
@@ -223,7 +234,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        return None, future_next_token_ids
+        return None, future_next_token_ids, False

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
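Instead of synchronizing the scheduler stream directly, `forward_batch_generation` now records a device event on the scheduler stream and hands it to the worker thread, which calls `sync_event.wait()` before running the forward pass; the forward stream orders itself after the scheduler's pending work without a CPU-blocking synchronize. A hedged sketch of that cross-stream ordering with `torch.cuda.Event` (requires a CUDA device; the matmuls are placeholders):

```python
import torch

assert torch.cuda.is_available(), "sketch requires a CUDA device"

scheduler_stream = torch.cuda.current_stream()
forward_stream = torch.cuda.Stream()

x = torch.randn(4096, 4096, device="cuda")

# Scheduler side: enqueue some work, then record an event instead of
# blocking the CPU with a full stream synchronize.
y = x @ x                           # runs on scheduler_stream
sync_event = torch.cuda.Event()
sync_event.record(scheduler_stream)

# Worker side: make the forward stream wait on the event, then launch.
sync_event.wait(forward_stream)     # GPU-side ordering, no CPU block
with torch.cuda.stream(forward_stream):
    z = y @ y                       # safe: ordered after y = x @ x

torch.cuda.synchronize()
print(z.shape)
```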
sglang/srt/mem_cache/chunk_cache.py CHANGED

@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = page_size

     def reset(self):
         pass
sglang/srt/mem_cache/memory_pool.py CHANGED

@@ -374,9 +374,9 @@ class MHATokenToKVPool(KVCache):
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
             with self.device_module.stream(self.alt_stream):
-                self.
-                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id - self.start_layer][loc] = cache_k
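After this change the K copy runs on the current stream while the V copy runs on the alternate stream, so the two scatter writes can overlap; the `wait_stream` calls before and after keep both copies ordered against the surrounding work. A self-contained sketch of the same two-stream overlap pattern, using plain `torch.cuda` calls and simplified buffers in place of the real KV pool:

```python
import torch

assert torch.cuda.is_available(), "sketch requires a CUDA device"

alt_stream = torch.cuda.Stream()
k_buffer = torch.empty(1024, 128, device="cuda")
v_buffer = torch.empty(1024, 128, device="cuda")

cache_k = torch.randn(256, 128, device="cuda")
cache_v = torch.randn(256, 128, device="cuda")
loc = torch.arange(256, device="cuda")

current_stream = torch.cuda.current_stream()
# The alternate stream waits for any work already queued on the current stream.
alt_stream.wait_stream(current_stream)

# K copy on the current stream, V copy on the alternate stream: the two
# writes can now run concurrently for small batches.
k_buffer[loc] = cache_k
with torch.cuda.stream(alt_stream):
    v_buffer[loc] = cache_v

# Re-join: the current stream waits for the V copy before anything reads it.
current_stream.wait_stream(alt_stream)
torch.cuda.synchronize()
```

The second memory_pool.py hunk below then mirrors the device pool's `start_layer` / `end_layer` range onto the host-side cache.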
@@ -762,6 +762,8 @@ class HostKVCache(abc.ABC):
         self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
+        self.start_layer = device_pool.start_layer
+        self.end_layer = device_pool.end_layer

         assert (
             self.size > device_pool.size