sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -391,6 +392,7 @@ class Scheduler(
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None
+        self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None

         # Init metrics stats
@@ -401,7 +403,7 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-                (
+                (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -484,9 +486,11 @@ class Scheduler(
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                tp_cache_group=self.
+                tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
+                hicache_size=server_args.hicache_size,
+                hicache_write_policy=server_args.hicache_write_policy,
             )
         else:
             self.tree_cache = RadixCache(
@@ -553,7 +557,7 @@ class Scheduler(

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
             )
@@ -568,7 +572,7 @@ class Scheduler(
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -597,7 +601,7 @@ class Scheduler(
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
                 transfer_backend=self.transfer_backend,
                 scheduler=self,
             )
@@ -664,70 +668,6 @@ class Scheduler(

         self.last_batch = batch

-    @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
-        """A normal scheduler loop for prefill worker in disaggregation mode."""
-
-        while True:
-            recv_reqs = self.recv_requests()
-            self.process_input_requests(recv_reqs)
-            self.waiting_queue.extend(
-                self.disagg_prefill_pending_queue.pop_bootstrapped()
-            )
-            self.process_prefill_chunk()
-            batch = self.get_new_batch_prefill()
-            self.cur_batch = batch
-
-            if batch:
-                result = self.run_batch(batch)
-                self.process_batch_result_disagg_prefill(batch, result)
-
-            if len(self.disagg_prefill_inflight_queue) > 0:
-                self.process_disagg_prefill_inflight_queue()
-
-            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-
-            self.last_batch = batch
-            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
-            # Otherwise, it hangs under high concurrency
-            self.running_batch.batch_is_full = False
-
-    @torch.no_grad()
-    def event_loop_normal_disagg_decode(self):
-        """A normal scheduler loop for decode worker in disaggregation mode."""
-
-        while True:
-            recv_reqs = self.recv_requests()
-            self.process_input_requests(recv_reqs)
-            # polling and allocating kv cache
-            self.process_decode_queue()
-            batch = self.get_next_disagg_decode_batch_to_run()
-            self.cur_batch = batch
-
-            if batch:
-                # Generate fake extend output.
-                if batch.forward_mode.is_extend():
-                    # Note: Logprobs should be handled on the prefill engine.
-                    self.stream_output(
-                        batch.reqs, [False for _ in range(len(batch.reqs))]
-                    )
-                else:
-                    result = self.run_batch(batch)
-                    self.process_batch_result(batch, result)
-
-            if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
-                + len(self.disagg_decode_prealloc_queue.queue)
-                == 0
-            ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-
-            self.last_batch = batch
-
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -1659,8 +1599,9 @@ class Scheduler(
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)

-    def flush_cache_wrapped(self, recv_req:
-        self.flush_cache()
+    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
+        success = self.flush_cache()
+        return FlushCacheReqOutput(success=success)

     def flush_cache(self):
         """Flush the memory pool and cache."""
@@ -1869,6 +1810,7 @@ class Scheduler(
                 recv_req.activities,
                 recv_req.with_stack,
                 recv_req.record_shapes,
+                recv_req.profile_id,
             )
         else:
             return self.stop_profile()
@@ -1880,6 +1822,7 @@ class Scheduler(
         activities: Optional[List[str]],
         with_stack: Optional[bool],
         record_shapes: Optional[bool],
+        profile_id: Optional[str],
     ) -> None:
         if self.profiler_activities:
             return ProfileReqOutput(
@@ -1894,9 +1837,11 @@ class Scheduler(

         self.torch_profiler_output_dir = output_dir
         self.profiler_activities = activities
+        self.profiler_id = profile_id
         logger.info(
-            "Profiling starts. Traces will be saved to: %s",
+            "Profiling starts. Traces will be saved to: %s (with id %s)",
             self.torch_profiler_output_dir,
+            self.profiler_id,
         )

         activity_map = {
@@ -1938,14 +1883,14 @@ class Scheduler(
             self.torch_profiler.export_chrome_trace(
                 os.path.join(
                     self.torch_profiler_output_dir,
-
+                    self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                 )
             )

         if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
                 self.torch_profiler_output_dir,
-
+                self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)
@@ -2069,9 +2014,15 @@ def run_scheduler_process(
         else:
             scheduler.event_loop_normal()
     elif disaggregation_mode == DisaggregationMode.PREFILL:
-        scheduler.
+        if scheduler.enable_overlap:
+            scheduler.event_loop_overlap_disagg_prefill()
+        else:
+            scheduler.event_loop_normal_disagg_prefill()
     elif disaggregation_mode == DisaggregationMode.DECODE:
-        scheduler.
+        if scheduler.enable_overlap:
+            scheduler.event_loop_overlap_disagg_decode()
+        else:
+            scheduler.event_loop_normal_disagg_decode()

 except Exception:
     traceback = get_exception_traceback()
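The scheduler hunks above replace the old fire-and-forget cache flush with a typed request/response pair: a `FlushCacheReqInput` is routed through the dispatcher to `flush_cache_wrapped`, which now reports the outcome as `FlushCacheReqOutput(success=...)`. Below is a minimal sketch of that dispatch-and-reply shape; the dataclasses and `TinyScheduler` are simplified stand-ins for illustration, not the real classes in `sglang.srt.managers.io_struct` and `sglang.srt.managers.scheduler`.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Type


@dataclass
class FlushCacheReqInput:
    pass


@dataclass
class FlushCacheReqOutput:
    success: bool


class TinyScheduler:
    def __init__(self) -> None:
        # Dispatch table in the spirit of the (request type, handler) pairs above.
        self._request_dispatcher: Dict[Type, Callable] = {
            FlushCacheReqInput: self.flush_cache_wrapped,
        }

    def flush_cache(self) -> bool:
        # Stand-in for clearing the radix tree and re-initializing the KV pools.
        return True

    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput) -> FlushCacheReqOutput:
        # The wrapper now returns a typed result instead of discarding it.
        return FlushCacheReqOutput(success=self.flush_cache())

    def process_input_request(self, req: object) -> object:
        return self._request_dispatcher[type(req)](req)


if __name__ == "__main__":
    out = TinyScheduler().process_input_request(FlushCacheReqInput())
    print(out)  # FlushCacheReqOutput(success=True)
```

The same pairing is what lets the TokenizerManager changes below await an acknowledgement per data-parallel rank instead of pushing a one-way message to the scheduler.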
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -264,6 +265,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.flush_cache_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -314,6 +318,10 @@ class TokenizerManager:
                 ResumeMemoryOccupationReqOutput,
                 self.resume_memory_occupation_communicator.handle_recv,
             ),
+            (
+                FlushCacheReqOutput,
+                self.flush_cache_communicator.handle_recv,
+            ),
             (
                 ProfileReqOutput,
                 self.start_profile_communicator.handle_recv,
@@ -415,38 +423,60 @@ class TokenizerManager:
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
-
-
-
-
-
-
-
-
+
+        self._validate_token_len(obj, input_ids)
+        return self._create_tokenized_object(
+            obj, input_text, input_ids, input_embeds, image_inputs
+        )
+
+    def _validate_token_len(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
+    ) -> None:
+        """Validates that the input token count and the requested token count doesn't exceed the model's context length."""

         input_token_num = len(input_ids) if input_ids is not None else 0
+        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
             raise ValueError(
                 f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )

+        # Check total tokens (input + max_new_tokens)
+        max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
-
-            and
-             >= self.context_len
+            max_new_tokens is not None
+            and (max_new_tokens + input_token_num) >= self.context_len
         ):
-
+            total_tokens = max_new_tokens + input_token_num
+            error_msg = (
                 f"Requested token count exceeds the model's maximum context length "
-                f"of {self.context_len} tokens. You requested a total of "
-                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                 f"tokens: {input_token_num} tokens from the input messages and "
-                f"{
-                f"
-
+                f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                f"of tokens in the input messages or the completion to fit within the limit."
+            )
+            raise ValueError(error_msg)
+
+    def _create_tokenized_object(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        input_text: str,
+        input_ids: List[int],
+        input_embeds: Optional[Union[List[float], None]] = None,
+        image_inputs: Optional[Dict] = None,
+    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+        """Create a tokenized request object from common parameters."""
+
+        if self.is_generation:
+            return_logprob = obj.return_logprob
+            logprob_start_len = obj.logprob_start_len
+            top_logprobs_num = obj.top_logprobs_num
+            token_ids_logprob = obj.token_ids_logprob
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
             )

-        # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
@@ -483,6 +513,50 @@ class TokenizerManager:

         return tokenized_obj

+    async def _batch_tokenize_and_process(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
+        """Handle batch tokenization for text inputs only."""
+        logger.debug(f"Starting batch tokenization for {batch_size} text requests")
+
+        # Collect requests and texts
+        requests = [obj[i] for i in range(batch_size)]
+        texts = [req.text for req in requests]
+
+        # Batch tokenize all texts
+        encoded = self.tokenizer(texts)
+        input_ids_list = encoded["input_ids"]
+
+        # Process all requests
+        tokenized_objs = []
+        for i, req in enumerate(requests):
+            self._validate_token_len(obj[i], input_ids_list[i])
+            tokenized_objs.append(
+                self._create_tokenized_object(
+                    req, req.text, input_ids_list[i], None, None
+                )
+            )
+        logger.debug(f"Completed batch processing for {batch_size} requests")
+        return tokenized_objs
+
+    def _validate_batch_tokenization_constraints(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> None:
+        """Validate constraints for batch tokenization processing."""
+        for i in range(batch_size):
+            if self.is_generation and obj[i].image_data:
+                raise ValueError(
+                    "For image input processing do not set `enable_tokenizer_batch_encode`."
+                )
+            if obj[i].input_ids is not None:
+                raise ValueError(
+                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
+                )
+            if obj[i].input_embeds is not None:
+                raise ValueError(
+                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
+                )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -560,14 +634,27 @@ class TokenizerManager:

         generators = []
         rids = []
+
         if getattr(obj, "parallel_sample_num", 1) == 1:
-
-
-
-
-            self.
-
-
+            if self.server_args.enable_tokenizer_batch_encode:
+                # Validate batch tokenization constraints
+                self._validate_batch_tokenization_constraints(batch_size, obj)
+
+                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
+
+                for i, tokenized_obj in enumerate(tokenized_objs):
+                    tmp_obj = obj[i]
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
+            else:
+                # Sequential tokenization and processing
+                for i in range(batch_size):
+                    tmp_obj = obj[i]
+                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                    generators.append(self._wait_one_response(tmp_obj, request))
+                    rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
             if batch_size > 128:
@@ -628,9 +715,8 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass

-    def flush_cache(self):
-
-        self.send_to_scheduler.send_pyobj(req)
+    async def flush_cache(self) -> FlushCacheReqOutput:
+        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
@@ -650,6 +736,7 @@ class TokenizerManager:
             output_dir=output_dir,
             num_steps=num_steps,
             activities=activities,
+            profile_id=str(time.time()),
         )
         result = (await self.start_profile_communicator(req))[0]
         if not result.success:
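The `TokenizerManager` hunks above add an opt-in batched tokenization path: when `server_args.enable_tokenizer_batch_encode` is set and every request in the batch is plain text (no image data, no pre-tokenized `input_ids`, no `input_embeds`), all prompts go through the tokenizer in a single call before being validated and dispatched per request. The sketch below contrasts the two paths; `DummyTokenizer` is a stand-in that only assumes what the diff relies on, namely that calling the tokenizer on a list of texts returns a dict with an `input_ids` list.

```python
from typing import Dict, List


class DummyTokenizer:
    """Stand-in tokenizer: called on a list of texts, returns {"input_ids": [...]}."""

    def __call__(self, texts: List[str]) -> Dict[str, List[List[int]]]:
        # One pseudo-token id per whitespace-separated word.
        return {"input_ids": [[len(w) for w in t.split()] for t in texts]}


def tokenize_batched(tokenizer: DummyTokenizer, texts: List[str]) -> List[List[int]]:
    # Single tokenizer call over the whole batch, as in _batch_tokenize_and_process.
    return tokenizer(texts)["input_ids"]


def tokenize_sequential(tokenizer: DummyTokenizer, texts: List[str]) -> List[List[int]]:
    # One tokenizer call per request, as in the sequential fallback path.
    return [tokenizer([t])["input_ids"][0] for t in texts]


if __name__ == "__main__":
    tok = DummyTokenizer()
    prompts = ["hello world", "flush the kv cache please"]
    assert tokenize_batched(tok, prompts) == tokenize_sequential(tok, prompts)
    print(tokenize_batched(tok, prompts))
```

The intent, as the helper's docstring suggests, is to amortize tokenizer overhead across the batch; per-request validation and dispatch are unchanged.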
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )

         # record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold =
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
             height += 1
         return height

-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
-            return
+            return 0

         return len(host_indices)

     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0

-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
@@ -143,28 +156,25 @@ class HiRadixCache(RadixCache):
         heapq.heapify(leaves)

         num_evicted = 0
-
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)

             if x.lock_ref > 0:
                 continue

-            if x.
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-
-
-
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-
-                        self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self.
+                num_evicted += self._evict_backuped(x)

             for child in x.parent.children.values():
-                if child in
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break
@@ -173,12 +183,12 @@ class HiRadixCache(RadixCache):
                 heapq.heappush(leaves, x.parent)

         if self.cache_controller.write_policy == "write_back":
-
-
-
-
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)

-    def
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0
@@ -186,7 +196,7 @@ class HiRadixCache(RadixCache):
         node.value = None
         return num_evicted

-    def
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
@@ -335,11 +345,13 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -365,7 +377,7 @@ class HiRadixCache(RadixCache):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-        if child.
+        if child.backuped:
             new_node.host_value = child.host_value[:split_len]
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
@@ -422,8 +434,8 @@ class HiRadixCache(RadixCache):
         node.children[child_key] = new_node
         self.evictable_size_ += len(value)

-        if self.cache_controller.write_policy
-            self.
+        if self.cache_controller.write_policy != "write_back":
+            self.inc_hit_count(new_node)
         return total_prefix_length

     def _collect_leaves_device(self):
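The `HiRadixCache` hunks drive host backups off a per-node hit counter: with the `write_through` policy a node is written to host memory on its first hit, other policies wait for three hits before calling `write_backup`, and `write_back` skips hit-count-driven backups entirely, creating host copies only when eviction needs them. A toy model of that decision follows, using a simplified `TreeNode` and a stand-in `write_backup` in place of the real radix-tree node and cache controller.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TreeNode:
    hit_count: int = 0
    host_value: Optional[List[int]] = None  # set once the node is written to host

    @property
    def backuped(self) -> bool:
        return self.host_value is not None


class HitCountWritePolicy:
    def __init__(self, write_policy: str) -> None:
        self.write_policy = write_policy
        # Mirrors the diff: back up on the first hit for write_through, else after 3 hits.
        self.write_through_threshold = 1 if write_policy == "write_through" else 3

    def write_backup(self, node: TreeNode) -> None:
        # Stand-in for cache_controller.write(...) copying KV data to host memory.
        node.host_value = []

    def inc_hit_count(self, node: TreeNode) -> None:
        if node.backuped or self.write_policy == "write_back":
            return  # already on host, or backups are deferred to eviction time
        node.hit_count += 1
        if node.hit_count >= self.write_through_threshold:
            self.write_backup(node)
            node.hit_count = 0


if __name__ == "__main__":
    policy = HitCountWritePolicy("write_through")
    node = TreeNode()
    policy.inc_hit_count(node)
    print(node.backuped)  # True: the first hit already triggers the host backup
```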