sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
|
|
16
16
|
import asyncio
|
17
17
|
import copy
|
18
18
|
import dataclasses
|
19
|
+
import json
|
19
20
|
import logging
|
20
21
|
import os
|
21
22
|
import pickle
|
@@ -90,6 +91,8 @@ from sglang.srt.managers.io_struct import (
|
|
90
91
|
ResumeMemoryOccupationReqInput,
|
91
92
|
ResumeMemoryOccupationReqOutput,
|
92
93
|
SessionParams,
|
94
|
+
SetInternalStateReq,
|
95
|
+
SetInternalStateReqOutput,
|
93
96
|
SlowDownReqInput,
|
94
97
|
SlowDownReqOutput,
|
95
98
|
TokenizedEmbeddingReqInput,
|
@@ -125,10 +128,10 @@ logger = logging.getLogger(__name__)
|
|
125
128
|
class ReqState:
|
126
129
|
"""Store the state a request."""
|
127
130
|
|
128
|
-
out_list: List
|
131
|
+
out_list: List[Dict[Any, Any]]
|
129
132
|
finished: bool
|
130
133
|
event: asyncio.Event
|
131
|
-
obj:
|
134
|
+
obj: Union[GenerateReqInput, EmbeddingReqInput]
|
132
135
|
|
133
136
|
# For metrics
|
134
137
|
created_time: float
|
@@ -139,6 +142,21 @@ class ReqState:
|
|
139
142
|
|
140
143
|
# For streaming output
|
141
144
|
last_output_offset: int = 0
|
145
|
+
# For incremental state update.
|
146
|
+
text: str = ""
|
147
|
+
output_ids: List[int] = dataclasses.field(default_factory=list)
|
148
|
+
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
149
|
+
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
150
|
+
output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
151
|
+
output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
152
|
+
input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
|
153
|
+
input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
|
154
|
+
output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
|
155
|
+
output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
|
156
|
+
input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
|
157
|
+
input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
158
|
+
output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
|
159
|
+
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
|
142
160
|
|
143
161
|
|
144
162
|
class TokenizerManager:
|
@@ -154,6 +172,11 @@ class TokenizerManager:
|
|
154
172
|
self.enable_metrics = server_args.enable_metrics
|
155
173
|
self.log_requests = server_args.log_requests
|
156
174
|
self.log_requests_level = server_args.log_requests_level
|
175
|
+
self.preferred_sampling_params = (
|
176
|
+
json.loads(server_args.preferred_sampling_params)
|
177
|
+
if server_args.preferred_sampling_params
|
178
|
+
else None
|
179
|
+
)
|
157
180
|
|
158
181
|
# Init inter-process communication
|
159
182
|
context = zmq.asyncio.Context(2)
|
@@ -213,6 +236,7 @@ class TokenizerManager:
|
|
213
236
|
# Store states
|
214
237
|
self.no_create_loop = False
|
215
238
|
self.rid_to_state: Dict[str, ReqState] = {}
|
239
|
+
self.health_check_failed = False
|
216
240
|
self.gracefully_exit = False
|
217
241
|
self.last_receive_tstamp = 0
|
218
242
|
self.dump_requests_folder = "" # By default do not dump
|
@@ -240,6 +264,10 @@ class TokenizerManager:
|
|
240
264
|
"model_name": self.server_args.served_model_name,
|
241
265
|
# TODO: Add lora name/path in the future,
|
242
266
|
},
|
267
|
+
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
|
268
|
+
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
|
269
|
+
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
|
270
|
+
collect_tokens_histogram=self.server_args.collect_tokens_histogram,
|
243
271
|
)
|
244
272
|
|
245
273
|
# Communicators
|
@@ -267,12 +295,16 @@ class TokenizerManager:
|
|
267
295
|
self.flush_cache_communicator = _Communicator(
|
268
296
|
self.send_to_scheduler, server_args.dp_size
|
269
297
|
)
|
270
|
-
self.
|
298
|
+
self.profile_communicator = _Communicator(
|
271
299
|
self.send_to_scheduler, server_args.dp_size
|
272
300
|
)
|
301
|
+
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
|
273
302
|
self.get_internal_state_communicator = _Communicator(
|
274
303
|
self.send_to_scheduler, server_args.dp_size
|
275
304
|
)
|
305
|
+
self.set_internal_state_communicator = _Communicator(
|
306
|
+
self.send_to_scheduler, server_args.dp_size
|
307
|
+
)
|
276
308
|
self.expert_distribution_communicator = _Communicator(
|
277
309
|
self.send_to_scheduler, server_args.dp_size
|
278
310
|
)
|
@@ -288,6 +320,7 @@ class TokenizerManager:
|
|
288
320
|
),
|
289
321
|
self._handle_batch_output,
|
290
322
|
),
|
323
|
+
(AbortReq, self._handle_abort_req),
|
291
324
|
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
292
325
|
(
|
293
326
|
UpdateWeightFromDiskReqOutput,
|
@@ -327,12 +360,16 @@ class TokenizerManager:
|
|
327
360
|
),
|
328
361
|
(
|
329
362
|
ProfileReqOutput,
|
330
|
-
self.
|
363
|
+
self.profile_communicator.handle_recv,
|
331
364
|
),
|
332
365
|
(
|
333
366
|
GetInternalStateReqOutput,
|
334
367
|
self.get_internal_state_communicator.handle_recv,
|
335
368
|
),
|
369
|
+
(
|
370
|
+
SetInternalStateReqOutput,
|
371
|
+
self.set_internal_state_communicator.handle_recv,
|
372
|
+
),
|
336
373
|
(
|
337
374
|
ExpertDistributionReqOutput,
|
338
375
|
self.expert_distribution_communicator.handle_recv,
|
@@ -341,13 +378,14 @@ class TokenizerManager:
|
|
341
378
|
]
|
342
379
|
)
|
343
380
|
|
381
|
+
# For pd disaggregtion
|
344
382
|
self.disaggregation_mode = DisaggregationMode(
|
345
383
|
self.server_args.disaggregation_mode
|
346
384
|
)
|
347
385
|
self.transfer_backend = TransferBackend(
|
348
386
|
self.server_args.disaggregation_transfer_backend
|
349
387
|
)
|
350
|
-
#
|
388
|
+
# Start kv boostrap server on prefill
|
351
389
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
352
390
|
# only start bootstrap server on prefill tm
|
353
391
|
kv_bootstrap_server_class = get_kv_class(
|
@@ -421,14 +459,16 @@ class TokenizerManager:
|
|
421
459
|
)
|
422
460
|
input_ids = self.tokenizer.encode(input_text)
|
423
461
|
|
424
|
-
image_inputs: Dict =
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
462
|
+
image_inputs: Optional[Dict] = None
|
463
|
+
if obj.contains_mm_input():
|
464
|
+
image_inputs = await self.mm_processor.process_mm_data_async(
|
465
|
+
image_data=obj.image_data,
|
466
|
+
input_text=input_text or input_ids,
|
467
|
+
request_obj=obj,
|
468
|
+
max_req_input_len=self.max_req_input_len,
|
469
|
+
)
|
470
|
+
if image_inputs and "input_ids" in image_inputs:
|
471
|
+
input_ids = image_inputs["input_ids"]
|
432
472
|
|
433
473
|
self._validate_token_len(obj, input_ids)
|
434
474
|
return self._create_tokenized_object(
|
@@ -482,8 +522,23 @@ class TokenizerManager:
|
|
482
522
|
session_params = (
|
483
523
|
SessionParams(**obj.session_params) if obj.session_params else None
|
484
524
|
)
|
525
|
+
if (
|
526
|
+
obj.custom_logit_processor
|
527
|
+
and not self.server_args.enable_custom_logit_processor
|
528
|
+
):
|
529
|
+
raise ValueError(
|
530
|
+
"The server is not configured to enable custom logit processor. "
|
531
|
+
"Please set `--enable-custom-logits-processor` to enable this feature."
|
532
|
+
)
|
485
533
|
|
486
|
-
|
534
|
+
# Parse sampling parameters
|
535
|
+
# Note: if there are preferred sampling params, we use them if they are not
|
536
|
+
# explicitly passed in sampling_params
|
537
|
+
if self.preferred_sampling_params:
|
538
|
+
sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
|
539
|
+
else:
|
540
|
+
sampling_kwargs = obj.sampling_params
|
541
|
+
sampling_params = SamplingParams(**sampling_kwargs)
|
487
542
|
sampling_params.normalize(self.tokenizer)
|
488
543
|
sampling_params.verify()
|
489
544
|
|
@@ -570,9 +625,9 @@ class TokenizerManager:
|
|
570
625
|
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
571
626
|
created_time: Optional[float] = None,
|
572
627
|
):
|
628
|
+
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
573
629
|
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
574
630
|
self.rid_to_state[obj.rid] = state
|
575
|
-
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
576
631
|
|
577
632
|
async def _wait_one_response(
|
578
633
|
self,
|
@@ -587,10 +642,11 @@ class TokenizerManager:
|
|
587
642
|
await asyncio.wait_for(state.event.wait(), timeout=4)
|
588
643
|
except asyncio.TimeoutError:
|
589
644
|
if request is not None and await request.is_disconnected():
|
645
|
+
# Abort the request for disconnected requests (non-streaming, waiting queue)
|
590
646
|
self.abort_request(obj.rid)
|
647
|
+
# Use exception to kill the whole call stack and asyncio task
|
591
648
|
raise ValueError(
|
592
|
-
"Request is disconnected from the client side. "
|
593
|
-
f"Abort request {obj.rid}"
|
649
|
+
f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
|
594
650
|
)
|
595
651
|
continue
|
596
652
|
|
@@ -605,7 +661,6 @@ class TokenizerManager:
|
|
605
661
|
else:
|
606
662
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
607
663
|
logger.info(msg)
|
608
|
-
del self.rid_to_state[obj.rid]
|
609
664
|
|
610
665
|
# Check if this was an abort/error created by scheduler
|
611
666
|
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
@@ -625,10 +680,11 @@ class TokenizerManager:
|
|
625
680
|
yield out
|
626
681
|
else:
|
627
682
|
if request is not None and await request.is_disconnected():
|
683
|
+
# Abort the request for disconnected requests (non-streaming, running)
|
628
684
|
self.abort_request(obj.rid)
|
685
|
+
# Use exception to kill the whole call stack and asyncio task
|
629
686
|
raise ValueError(
|
630
|
-
"Request is disconnected from the client side. "
|
631
|
-
f"Abort request {obj.rid}"
|
687
|
+
f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
|
632
688
|
)
|
633
689
|
|
634
690
|
async def _handle_batch_request(
|
@@ -641,7 +697,6 @@ class TokenizerManager:
|
|
641
697
|
|
642
698
|
generators = []
|
643
699
|
rids = []
|
644
|
-
|
645
700
|
if getattr(obj, "parallel_sample_num", 1) == 1:
|
646
701
|
if self.server_args.enable_tokenizer_batch_encode:
|
647
702
|
# Validate batch tokenization constraints
|
@@ -728,7 +783,6 @@ class TokenizerManager:
|
|
728
783
|
def abort_request(self, rid: str):
|
729
784
|
if rid not in self.rid_to_state:
|
730
785
|
return
|
731
|
-
del self.rid_to_state[rid]
|
732
786
|
req = AbortReq(rid)
|
733
787
|
self.send_to_scheduler.send_pyobj(req)
|
734
788
|
|
@@ -737,30 +791,42 @@ class TokenizerManager:
|
|
737
791
|
output_dir: Optional[str] = None,
|
738
792
|
num_steps: Optional[int] = None,
|
739
793
|
activities: Optional[List[str]] = None,
|
794
|
+
with_stack: Optional[bool] = None,
|
795
|
+
record_shapes: Optional[bool] = None,
|
740
796
|
):
|
797
|
+
self.auto_create_handle_loop()
|
741
798
|
req = ProfileReq(
|
742
799
|
type=ProfileReqType.START_PROFILE,
|
743
800
|
output_dir=output_dir,
|
744
801
|
num_steps=num_steps,
|
745
802
|
activities=activities,
|
803
|
+
with_stack=with_stack,
|
804
|
+
record_shapes=record_shapes,
|
746
805
|
profile_id=str(time.time()),
|
747
806
|
)
|
748
|
-
|
807
|
+
return await self._execute_profile(req)
|
808
|
+
|
809
|
+
async def stop_profile(self):
|
810
|
+
self.auto_create_handle_loop()
|
811
|
+
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
812
|
+
return await self._execute_profile(req)
|
813
|
+
|
814
|
+
async def _execute_profile(self, req: ProfileReq):
|
815
|
+
result = (await self.profile_communicator(req))[0]
|
749
816
|
if not result.success:
|
750
817
|
raise RuntimeError(result.message)
|
751
818
|
return result
|
752
819
|
|
753
|
-
def stop_profile(self):
|
754
|
-
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
755
|
-
self.send_to_scheduler.send_pyobj(req)
|
756
|
-
|
757
820
|
async def start_expert_distribution_record(self):
|
821
|
+
self.auto_create_handle_loop()
|
758
822
|
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
|
759
823
|
|
760
824
|
async def stop_expert_distribution_record(self):
|
825
|
+
self.auto_create_handle_loop()
|
761
826
|
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
|
762
827
|
|
763
828
|
async def dump_expert_distribution_record(self):
|
829
|
+
self.auto_create_handle_loop()
|
764
830
|
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
765
831
|
|
766
832
|
async def update_weights_from_disk(
|
@@ -827,8 +893,8 @@ class TokenizerManager:
|
|
827
893
|
) -> Tuple[bool, str]:
|
828
894
|
self.auto_create_handle_loop()
|
829
895
|
assert (
|
830
|
-
self.server_args.dp_size == 1
|
831
|
-
), "dp_size must be for update weights from distributed"
|
896
|
+
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
897
|
+
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
|
832
898
|
|
833
899
|
# This means that weight sync
|
834
900
|
# cannot run while requests are in progress.
|
@@ -843,8 +909,8 @@ class TokenizerManager:
|
|
843
909
|
) -> Tuple[bool, str]:
|
844
910
|
self.auto_create_handle_loop()
|
845
911
|
assert (
|
846
|
-
self.server_args.dp_size == 1
|
847
|
-
), "dp_size must be 1 for update weights from
|
912
|
+
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
913
|
+
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
|
848
914
|
|
849
915
|
# This means that weight sync
|
850
916
|
# cannot run while requests are in progress.
|
@@ -909,12 +975,21 @@ class TokenizerManager:
|
|
909
975
|
):
|
910
976
|
await self.send_to_scheduler.send_pyobj(obj)
|
911
977
|
|
912
|
-
async def get_internal_state(self) -> Dict[Any, Any]:
|
978
|
+
async def get_internal_state(self) -> List[Dict[Any, Any]]:
|
913
979
|
req = GetInternalStateReq()
|
914
|
-
|
980
|
+
responses: List[GetInternalStateReqOutput] = (
|
915
981
|
await self.get_internal_state_communicator(req)
|
916
982
|
)
|
917
|
-
|
983
|
+
# Many DP ranks
|
984
|
+
return [res.internal_state for res in responses]
|
985
|
+
|
986
|
+
async def set_internal_state(
|
987
|
+
self, obj: SetInternalStateReq
|
988
|
+
) -> SetInternalStateReqOutput:
|
989
|
+
responses: List[SetInternalStateReqOutput] = (
|
990
|
+
await self.set_internal_state_communicator(obj)
|
991
|
+
)
|
992
|
+
return [res.internal_state for res in responses]
|
918
993
|
|
919
994
|
def get_log_request_metadata(self):
|
920
995
|
max_length = None
|
@@ -964,7 +1039,7 @@ class TokenizerManager:
|
|
964
1039
|
def create_abort_task(self, obj: GenerateReqInput):
|
965
1040
|
# Abort the request if the client is disconnected.
|
966
1041
|
async def abort_request():
|
967
|
-
await asyncio.sleep(
|
1042
|
+
await asyncio.sleep(2)
|
968
1043
|
if obj.is_single:
|
969
1044
|
self.abort_request(obj.rid)
|
970
1045
|
else:
|
@@ -985,11 +1060,17 @@ class TokenizerManager:
|
|
985
1060
|
loop.create_task(print_exception_wrapper(self.handle_loop))
|
986
1061
|
)
|
987
1062
|
|
1063
|
+
self.event_loop = loop
|
1064
|
+
|
988
1065
|
# We cannot add signal handler when the tokenizer manager is not in
|
989
1066
|
# the main thread due to the CPython limitation.
|
990
1067
|
if threading.current_thread() is threading.main_thread():
|
991
1068
|
signal_handler = SignalHandler(self)
|
992
|
-
loop.add_signal_handler(signal.SIGTERM, signal_handler.
|
1069
|
+
loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
|
1070
|
+
# Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
|
1071
|
+
loop.add_signal_handler(
|
1072
|
+
signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
|
1073
|
+
)
|
993
1074
|
else:
|
994
1075
|
logger.warning(
|
995
1076
|
"Signal handler is not added because the tokenizer manager is "
|
@@ -1007,6 +1088,15 @@ class TokenizerManager:
|
|
1007
1088
|
# Drain requests
|
1008
1089
|
while True:
|
1009
1090
|
remain_num_req = len(self.rid_to_state)
|
1091
|
+
|
1092
|
+
if self.health_check_failed:
|
1093
|
+
# if health check failed, we should exit immediately
|
1094
|
+
logger.error(
|
1095
|
+
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
1096
|
+
remain_num_req,
|
1097
|
+
)
|
1098
|
+
break
|
1099
|
+
|
1010
1100
|
logger.info(
|
1011
1101
|
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
1012
1102
|
)
|
@@ -1035,6 +1125,9 @@ class TokenizerManager:
|
|
1035
1125
|
for i, rid in enumerate(recv_obj.rids):
|
1036
1126
|
state = self.rid_to_state.get(rid, None)
|
1037
1127
|
if state is None:
|
1128
|
+
logger.error(
|
1129
|
+
f"Received output for {rid=} but the state was deleted in TokenizerManager."
|
1130
|
+
)
|
1038
1131
|
continue
|
1039
1132
|
|
1040
1133
|
# Build meta_info and return value
|
@@ -1047,9 +1140,11 @@ class TokenizerManager:
|
|
1047
1140
|
if getattr(state.obj, "return_logprob", False):
|
1048
1141
|
self.convert_logprob_style(
|
1049
1142
|
meta_info,
|
1143
|
+
state,
|
1050
1144
|
state.obj.top_logprobs_num,
|
1051
1145
|
state.obj.token_ids_logprob,
|
1052
|
-
state.obj.return_text_in_logprobs
|
1146
|
+
state.obj.return_text_in_logprobs
|
1147
|
+
and not self.server_args.skip_tokenizer_init,
|
1053
1148
|
recv_obj,
|
1054
1149
|
i,
|
1055
1150
|
)
|
@@ -1066,25 +1161,35 @@ class TokenizerManager:
|
|
1066
1161
|
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
1067
1162
|
|
1068
1163
|
if isinstance(recv_obj, BatchStrOut):
|
1164
|
+
state.text += recv_obj.output_strs[i]
|
1069
1165
|
out_dict = {
|
1070
|
-
"text":
|
1166
|
+
"text": state.text,
|
1071
1167
|
"meta_info": meta_info,
|
1072
1168
|
}
|
1073
1169
|
elif isinstance(recv_obj, BatchTokenIDOut):
|
1074
1170
|
if self.server_args.stream_output and state.obj.stream:
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
state.last_output_offset = len(recv_obj.output_ids[i])
|
1171
|
+
state.output_ids.extend(recv_obj.output_ids[i])
|
1172
|
+
output_token_ids = state.output_ids[state.last_output_offset :]
|
1173
|
+
state.last_output_offset = len(state.output_ids)
|
1079
1174
|
else:
|
1080
|
-
|
1175
|
+
state.output_ids.extend(recv_obj.output_ids[i])
|
1176
|
+
output_token_ids = state.output_ids
|
1081
1177
|
|
1082
1178
|
out_dict = {
|
1083
1179
|
"output_ids": output_token_ids,
|
1084
1180
|
"meta_info": meta_info,
|
1085
1181
|
}
|
1086
1182
|
elif isinstance(recv_obj, BatchMultimodalOut):
|
1087
|
-
|
1183
|
+
if isinstance(recv_obj.outputs[i], str):
|
1184
|
+
out_dict = {
|
1185
|
+
"text": recv_obj.outputs[i],
|
1186
|
+
"meta_info": meta_info,
|
1187
|
+
}
|
1188
|
+
else:
|
1189
|
+
out_dict = {
|
1190
|
+
"outputs": json.dumps(recv_obj.outputs[i]),
|
1191
|
+
"meta_info": meta_info,
|
1192
|
+
}
|
1088
1193
|
else:
|
1089
1194
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
1090
1195
|
out_dict = {
|
@@ -1098,6 +1203,7 @@ class TokenizerManager:
|
|
1098
1203
|
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
1099
1204
|
state.finished_time = time.time()
|
1100
1205
|
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
1206
|
+
del self.rid_to_state[rid]
|
1101
1207
|
|
1102
1208
|
state.out_list.append(out_dict)
|
1103
1209
|
state.event.set()
|
@@ -1111,45 +1217,85 @@ class TokenizerManager:
|
|
1111
1217
|
def convert_logprob_style(
|
1112
1218
|
self,
|
1113
1219
|
meta_info: dict,
|
1220
|
+
state: ReqState,
|
1114
1221
|
top_logprobs_num: int,
|
1115
1222
|
token_ids_logprob: List[int],
|
1116
1223
|
return_text_in_logprobs: bool,
|
1117
1224
|
recv_obj: BatchStrOut,
|
1118
1225
|
recv_obj_index: int,
|
1119
1226
|
):
|
1227
|
+
if len(recv_obj.input_token_logprobs_val) > 0:
|
1228
|
+
state.input_token_logprobs_val.extend(
|
1229
|
+
recv_obj.input_token_logprobs_val[recv_obj_index]
|
1230
|
+
)
|
1231
|
+
state.input_token_logprobs_idx.extend(
|
1232
|
+
recv_obj.input_token_logprobs_idx[recv_obj_index]
|
1233
|
+
)
|
1234
|
+
state.output_token_logprobs_val.extend(
|
1235
|
+
recv_obj.output_token_logprobs_val[recv_obj_index]
|
1236
|
+
)
|
1237
|
+
state.output_token_logprobs_idx.extend(
|
1238
|
+
recv_obj.output_token_logprobs_idx[recv_obj_index]
|
1239
|
+
)
|
1120
1240
|
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
|
1121
|
-
|
1122
|
-
|
1241
|
+
state.input_token_logprobs_val,
|
1242
|
+
state.input_token_logprobs_idx,
|
1123
1243
|
return_text_in_logprobs,
|
1124
1244
|
)
|
1125
1245
|
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
|
1126
|
-
|
1127
|
-
|
1246
|
+
state.output_token_logprobs_val,
|
1247
|
+
state.output_token_logprobs_idx,
|
1128
1248
|
return_text_in_logprobs,
|
1129
1249
|
)
|
1130
1250
|
|
1131
1251
|
if top_logprobs_num > 0:
|
1252
|
+
if len(recv_obj.input_top_logprobs_val) > 0:
|
1253
|
+
state.input_top_logprobs_val.extend(
|
1254
|
+
recv_obj.input_top_logprobs_val[recv_obj_index]
|
1255
|
+
)
|
1256
|
+
state.input_top_logprobs_idx.extend(
|
1257
|
+
recv_obj.input_top_logprobs_idx[recv_obj_index]
|
1258
|
+
)
|
1259
|
+
state.output_top_logprobs_val.extend(
|
1260
|
+
recv_obj.output_top_logprobs_val[recv_obj_index]
|
1261
|
+
)
|
1262
|
+
state.output_top_logprobs_idx.extend(
|
1263
|
+
recv_obj.output_top_logprobs_idx[recv_obj_index]
|
1264
|
+
)
|
1132
1265
|
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
1133
|
-
|
1134
|
-
|
1266
|
+
state.input_top_logprobs_val,
|
1267
|
+
state.input_top_logprobs_idx,
|
1135
1268
|
return_text_in_logprobs,
|
1136
1269
|
)
|
1137
1270
|
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
1138
|
-
|
1139
|
-
|
1271
|
+
state.output_top_logprobs_val,
|
1272
|
+
state.output_top_logprobs_idx,
|
1140
1273
|
return_text_in_logprobs,
|
1141
1274
|
)
|
1142
1275
|
|
1143
1276
|
if token_ids_logprob is not None:
|
1277
|
+
if len(recv_obj.input_token_ids_logprobs_val) > 0:
|
1278
|
+
state.input_token_ids_logprobs_val.extend(
|
1279
|
+
recv_obj.input_token_ids_logprobs_val[recv_obj_index]
|
1280
|
+
)
|
1281
|
+
state.input_token_ids_logprobs_idx.extend(
|
1282
|
+
recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
|
1283
|
+
)
|
1284
|
+
state.output_token_ids_logprobs_val.extend(
|
1285
|
+
recv_obj.output_token_ids_logprobs_val[recv_obj_index]
|
1286
|
+
)
|
1287
|
+
state.output_token_ids_logprobs_idx.extend(
|
1288
|
+
recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
|
1289
|
+
)
|
1144
1290
|
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
|
1145
|
-
|
1146
|
-
|
1291
|
+
state.input_token_ids_logprobs_val,
|
1292
|
+
state.input_token_ids_logprobs_idx,
|
1147
1293
|
return_text_in_logprobs,
|
1148
1294
|
)
|
1149
1295
|
meta_info["output_token_ids_logprobs"] = (
|
1150
1296
|
self.detokenize_top_logprobs_tokens(
|
1151
|
-
|
1152
|
-
|
1297
|
+
state.output_token_ids_logprobs_val,
|
1298
|
+
state.output_token_ids_logprobs_idx,
|
1153
1299
|
return_text_in_logprobs,
|
1154
1300
|
)
|
1155
1301
|
)
|
@@ -1216,11 +1362,18 @@ class TokenizerManager:
|
|
1216
1362
|
state.last_completion_tokens = completion_tokens
|
1217
1363
|
|
1218
1364
|
if state.finished:
|
1365
|
+
has_grammar = (
|
1366
|
+
state.obj.sampling_params.get("json_schema", None)
|
1367
|
+
or state.obj.sampling_params.get("regex", None)
|
1368
|
+
or state.obj.sampling_params.get("ebnf", None)
|
1369
|
+
or state.obj.sampling_params.get("structural_tag", None)
|
1370
|
+
)
|
1219
1371
|
self.metrics_collector.observe_one_finished_request(
|
1220
1372
|
recv_obj.prompt_tokens[i],
|
1221
1373
|
completion_tokens,
|
1222
1374
|
recv_obj.cached_tokens[i],
|
1223
1375
|
state.finished_time - state.created_time,
|
1376
|
+
has_grammar,
|
1224
1377
|
)
|
1225
1378
|
|
1226
1379
|
def dump_requests(self, state: ReqState, out_dict: dict):
|
@@ -1246,6 +1399,9 @@ class TokenizerManager:
|
|
1246
1399
|
# Schedule the task to run in the background without awaiting it
|
1247
1400
|
asyncio.create_task(asyncio.to_thread(background_task))
|
1248
1401
|
|
1402
|
+
def _handle_abort_req(self, recv_obj):
|
1403
|
+
self.rid_to_state.pop(recv_obj.rid)
|
1404
|
+
|
1249
1405
|
def _handle_open_session_req_output(self, recv_obj):
|
1250
1406
|
self.session_futures[recv_obj.session_id].set_result(
|
1251
1407
|
recv_obj.session_id if recv_obj.success else None
|
@@ -1256,7 +1412,7 @@ class TokenizerManager:
|
|
1256
1412
|
self.model_update_result.set_result(recv_obj)
|
1257
1413
|
else: # self.server_args.dp_size > 1
|
1258
1414
|
self.model_update_tmp.append(recv_obj)
|
1259
|
-
# set future if the all results are
|
1415
|
+
# set future if the all results are received
|
1260
1416
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
1261
1417
|
self.model_update_result.set_result(self.model_update_tmp)
|
1262
1418
|
|
@@ -1279,12 +1435,18 @@ class SignalHandler:
|
|
1279
1435
|
def __init__(self, tokenizer_manager: TokenizerManager):
|
1280
1436
|
self.tokenizer_manager = tokenizer_manager
|
1281
1437
|
|
1282
|
-
def
|
1438
|
+
def sigterm_handler(self, signum=None, frame=None):
|
1283
1439
|
logger.warning(
|
1284
1440
|
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
1285
1441
|
)
|
1286
1442
|
self.tokenizer_manager.gracefully_exit = True
|
1287
1443
|
|
1444
|
+
def running_phase_sigquit_handler(self, signum=None, frame=None):
|
1445
|
+
logger.error(
|
1446
|
+
"Received sigquit from a child process. It usually means the child failed."
|
1447
|
+
)
|
1448
|
+
kill_process_tree(os.getpid())
|
1449
|
+
|
1288
1450
|
|
1289
1451
|
T = TypeVar("T")
|
1290
1452
|
|
@@ -1325,3 +1487,15 @@ class _Communicator(Generic[T]):
|
|
1325
1487
|
self._result_values.append(recv_obj)
|
1326
1488
|
if len(self._result_values) == self._fan_out:
|
1327
1489
|
self._result_event.set()
|
1490
|
+
|
1491
|
+
|
1492
|
+
# Note: request abort handling logic
|
1493
|
+
# We should handle all of the following cases correctly.
|
1494
|
+
#
|
1495
|
+
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
|
1496
|
+
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
|
1497
|
+
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
|
1498
|
+
# | http | yes | running | background task | fast api | del in _handle_batch_output |
|
1499
|
+
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
|
1500
|
+
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
|
1501
|
+
#
|