sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
@@ -285,6 +285,7 @@ class MultimodalInputs:
     num_image_tokens: Optional[int] = None
 
     # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
     # image
@@ -310,16 +311,12 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
 
-        assert len(ret.mm_items) != 0
-
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
         for item in ret.mm_items:
             item.set_pad_value()
 
         optional_args = [
+            "mrope_positions",
+            "mrope_position_delta",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -350,11 +347,6 @@ class MultimodalInputs:
         merge image inputs when requests are being merged
         """
 
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
-
         # args needed to be merged
         optional_args = [
             "mm_items",
@@ -364,6 +356,30 @@ class MultimodalInputs:
             self_arg = getattr(self, arg, None)
             if self_arg is not None:
                 setattr(self, arg, self_arg + getattr(other, arg))
+
+        mrope_positions = self.mrope_positions
+        if mrope_positions is not None:
+            if other.mrope_positions is None:
+                self.mrope_positions = mrope_positions
+            else:
+                self.mrope_positions = torch.cat(
+                    [self.mrope_positions, other.mrope_positions], dim=1
+                )
+
+        mrope_position_delta = self.mrope_position_delta
+        if mrope_position_delta is not None:
+            if other.mrope_position_delta is None:
+                self.mrope_position_delta = mrope_position_delta
+            else:
+                self.mrope_position_delta = torch.cat(
+                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
+                )
+
+        for key, val in other.__dict__.items():
+            if "_id" in key:
+                # set token_ids
+                if getattr(self, key, None) is None:
+                    setattr(self, key, getattr(other, key, None))
         # other args would be kept intact
 
 
@@ -388,6 +404,7 @@ class Req:
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
+        bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
@@ -523,6 +540,7 @@ class Req:
 
         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_port: Optional[int] = bootstrap_port
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None
 
@@ -539,6 +557,11 @@ class Req:
         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
 
+        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
+        # This is because kv is not ready in `process_prefill_chunk`.
+        # We use `tmp_end_idx` to store the end index of the kv cache to send.
+        self.tmp_end_idx: int = -1
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -571,6 +594,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
@@ -1437,7 +1468,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.model_config.is_encoder_decoder:
             self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
             self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
         self.req_pool_indices = torch.cat(
             [self.req_pool_indices, other.req_pool_indices]
         )
sglang/srt/managers/scheduler.py CHANGED
@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -402,7 +403,7 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-                (
+                (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -488,6 +489,8 @@ class Scheduler(
                 tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
+                hicache_size=server_args.hicache_size,
+                hicache_write_policy=server_args.hicache_write_policy,
             )
         else:
             self.tree_cache = RadixCache(
@@ -575,6 +578,10 @@ class Scheduler(
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                 transfer_backend=self.transfer_backend,
             )
+
+            # Metric for pre-allocation
+            self.num_tokens_pre_allocated = 0
+
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
@@ -590,7 +597,7 @@ class Scheduler(
             )
             metadata_buffers = [output_id_buffer]
 
-            self.
+            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
@@ -784,6 +791,7 @@ class Scheduler(
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
+            bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
         )
         req.tokenizer = self.tokenizer
@@ -898,7 +906,7 @@ class Scheduler(
    def _add_request_to_queue(self, req: Req):
        req.queue_time_start = time.time()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.
+            self.disagg_prefill_bootstrap_queue.add(req)
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
        else:
@@ -988,8 +996,15 @@ class Scheduler(
             f"#cached-token: {adder.log_hit_tokens}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue)}, "
         )
+
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
+        else:
+            f += f"#queue-req: {len(self.waiting_queue)}"
+
         logger.info(f)
 
         if self.enable_metrics:
@@ -1025,15 +1040,14 @@ class Scheduler(
             gap_latency / self.server_args.decode_log_interval
         )
 
+        msg = (
+            f"Decode batch. "
+            f"#running-req: {num_running_reqs}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+        )
+
         if self.spec_algorithm.is_none():
-            msg = (
-                f"Decode batch. "
-                f"#running-req: {num_running_reqs}, "
-                f"#token: {num_used}, "
-                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}, "
-            )
             spec_accept_length = 0
         else:
             spec_accept_length = (
@@ -1042,15 +1056,15 @@ class Scheduler(
             self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-            msg
-
-
-
-
-
-
-
-
+            msg += f"accept len: {spec_accept_length:.2f}, "
+
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+
+        msg += (
+            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )
 
         logger.info(msg)
         if self.enable_metrics:
@@ -1596,8 +1610,9 @@ class Scheduler(
             time.sleep(5)
             self.parent_process.send_signal(signal.SIGQUIT)
 
-    def flush_cache_wrapped(self, recv_req:
-        self.flush_cache()
+    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
+        success = self.flush_cache()
+        return FlushCacheReqOutput(success=success)
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
@@ -2010,9 +2025,15 @@ def run_scheduler_process(
             else:
                 scheduler.event_loop_normal()
         elif disaggregation_mode == DisaggregationMode.PREFILL:
-            scheduler.
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap_disagg_prefill()
+            else:
+                scheduler.event_loop_normal_disagg_prefill()
         elif disaggregation_mode == DisaggregationMode.DECODE:
-            scheduler.
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap_disagg_decode()
+            else:
+                scheduler.event_loop_normal_disagg_decode()
 
     except Exception:
         traceback = get_exception_traceback()
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -264,6 +265,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.flush_cache_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -314,6 +318,10 @@ class TokenizerManager:
                     ResumeMemoryOccupationReqOutput,
                     self.resume_memory_occupation_communicator.handle_recv,
                 ),
+                (
+                    FlushCacheReqOutput,
+                    self.flush_cache_communicator.handle_recv,
+                ),
                 (
                     ProfileReqOutput,
                     self.start_profile_communicator.handle_recv,
@@ -411,42 +419,67 @@ class TokenizerManager:
             input_ids = self.tokenizer.encode(input_text)
 
         image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            obj.image_data,
+            image_data=obj.image_data,
+            input_text=input_text or input_ids,
+            request_obj=obj,
+            max_req_input_len=self.max_req_input_len,
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
-
-
-
-
-
-
-
-
+
+        self._validate_token_len(obj, input_ids)
+        return self._create_tokenized_object(
+            obj, input_text, input_ids, input_embeds, image_inputs
+        )
+
+    def _validate_token_len(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
+    ) -> None:
+        """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
 
         input_token_num = len(input_ids) if input_ids is not None else 0
+        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
             raise ValueError(
                 f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
+        # Check total tokens (input + max_new_tokens)
+        max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
-
-            and
-            >= self.context_len
+            max_new_tokens is not None
+            and (max_new_tokens + input_token_num) >= self.context_len
         ):
-
+            total_tokens = max_new_tokens + input_token_num
+            error_msg = (
                 f"Requested token count exceeds the model's maximum context length "
-                f"of {self.context_len} tokens. You requested a total of "
-                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                 f"tokens: {input_token_num} tokens from the input messages and "
-                f"{
-                f"
-
+                f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                f"of tokens in the input messages or the completion to fit within the limit."
+            )
+            raise ValueError(error_msg)
+
+    def _create_tokenized_object(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        input_text: str,
+        input_ids: List[int],
+        input_embeds: Optional[Union[List[float], None]] = None,
+        image_inputs: Optional[Dict] = None,
+    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+        """Create a tokenized request object from common parameters."""
+
+        if self.is_generation:
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num
+            token_ids_logprob = obj.token_ids_logprob
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
             )
 
-        # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
@@ -465,6 +498,7 @@ class TokenizerManager:
             token_ids_logprob,
             obj.stream,
             bootstrap_host=obj.bootstrap_host,
+            bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
             lora_path=obj.lora_path,
             input_embeds=input_embeds,
@@ -483,6 +517,50 @@ class TokenizerManager:
 
         return tokenized_obj
 
+    async def _batch_tokenize_and_process(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
+        """Handle batch tokenization for text inputs only."""
+        logger.debug(f"Starting batch tokenization for {batch_size} text requests")
+
+        # Collect requests and texts
+        requests = [obj[i] for i in range(batch_size)]
+        texts = [req.text for req in requests]
+
+        # Batch tokenize all texts
+        encoded = self.tokenizer(texts)
+        input_ids_list = encoded["input_ids"]
+
+        # Process all requests
+        tokenized_objs = []
+        for i, req in enumerate(requests):
+            self._validate_token_len(obj[i], input_ids_list[i])
+            tokenized_objs.append(
+                self._create_tokenized_object(
+                    req, req.text, input_ids_list[i], None, None
+                )
+            )
+        logger.debug(f"Completed batch processing for {batch_size} requests")
+        return tokenized_objs
+
+    def _validate_batch_tokenization_constraints(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> None:
+        """Validate constraints for batch tokenization processing."""
+        for i in range(batch_size):
+            if self.is_generation and obj[i].image_data:
+                raise ValueError(
+                    "For image input processing do not set `enable_tokenizer_batch_encode`."
+                )
+            if obj[i].input_ids is not None:
+                raise ValueError(
+                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
+                )
+            if obj[i].input_embeds is not None:
+                raise ValueError(
+                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
+                )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -560,14 +638,27 @@ class TokenizerManager:
 
         generators = []
         rids = []
+
         if getattr(obj, "parallel_sample_num", 1) == 1:
-
-
-
-
-            self.
-
-
+            if self.server_args.enable_tokenizer_batch_encode:
+                # Validate batch tokenization constraints
+                self._validate_batch_tokenization_constraints(batch_size, obj)
+
+                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
+
+                for i, tokenized_obj in enumerate(tokenized_objs):
+                    tmp_obj = obj[i]
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+            else:
+                # Sequential tokenization and processing
+                for i in range(batch_size):
+                    tmp_obj = obj[i]
+                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
@@ -628,9 +719,8 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass
 
-    def flush_cache(self):
-
-        self.send_to_scheduler.send_pyobj(req)
+    async def flush_cache(self) -> FlushCacheReqOutput:
+        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
 
     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
sglang/srt/managers/tp_worker.py CHANGED