sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff compares the published contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED

@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
     ExpertDistributionReqOutput,
     FlushCacheReqInput,
     FlushCacheReqOutput,
+    FreezeGCReq,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -122,7 +123,9 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    configure_gc_warning,
     dataclass_to_string_truncated,
+    freeze_gc,
     get_bool_env_var,
     get_zmq_socket,
     kill_process_tree,
@@ -298,7 +301,7 @@ class TokenizerManager:
         # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
         # serves as the source of truth for available adapters and maps user-friendly LoRA names
         # to internally used unique LoRA IDs.
-        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths)
         # Lock to serialize LoRA update operations.
         # Please note that, unlike `model_update_lock`, this does not block inference, allowing
         # LoRA updates and inference to overlap.
@@ -352,6 +355,10 @@ class TokenizerManager:
             collect_tokens_histogram=self.server_args.collect_tokens_histogram,
         )
 
+        # Configure GC warning
+        if self.server_args.gc_warning_threshold_secs > 0.0:
+            configure_gc_warning(self.server_args.gc_warning_threshold_secs)
+
         # Communicators
         self.init_weights_update_group_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
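Note: configure_gc_warning lives in sglang.srt.utils; this hunk only wires it to the new gc_warning_threshold_secs server argument. As a rough idea of the mechanism, a GC-pause warning can be built on CPython's gc.callbacks. The sketch below is a hypothetical re-implementation for illustration, not the actual sglang code.

    import gc
    import logging
    import time

    logger = logging.getLogger(__name__)

    def configure_gc_warning_sketch(warn_threshold_secs: float) -> None:
        # Warn when a GC pause exceeds the threshold (illustrative only).
        start_time = 0.0

        def gc_callback(phase: str, info: dict) -> None:
            nonlocal start_time
            if phase == "start":
                start_time = time.monotonic()
            elif phase == "stop":
                elapsed = time.monotonic() - start_time
                if elapsed > warn_threshold_secs:
                    logger.warning(
                        "GC of generation %s took %.3fs (> %.3fs threshold), "
                        "stalling request handling.",
                        info.get("generation"), elapsed, warn_threshold_secs,
                    )

        gc.callbacks.append(gc_callback)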
@@ -446,6 +453,10 @@ class TokenizerManager:
                 ProfileReqOutput,
                 self.profile_communicator.handle_recv,
             ),
+            (
+                FreezeGCReq,
+                lambda x: None,
+            ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
             (
                 GetInternalStateReqOutput,
                 self.get_internal_state_communicator.handle_recv,
@@ -565,14 +576,24 @@ class TokenizerManager:
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
+        # FIXME: unify the length validation logic with the one in the scheduler.
+        _max_req_len = self.context_len
 
         input_token_num = len(input_ids) if input_ids is not None else 0
-        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
-            raise ValueError(
-                f"The input ({input_token_num} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens). "
+                    "Truncating the input."
+                )
+                del input_ids[_max_req_len:]
+                input_token_num = len(input_ids)
+            else:
+                raise ValueError(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -584,17 +605,27 @@ class TokenizerManager:
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
             max_new_tokens is not None
-            and (max_new_tokens + input_token_num) >= self.context_len
+            and (max_new_tokens + input_token_num) >= _max_req_len
         ):
-            total_tokens = max_new_tokens + input_token_num
-            error_msg = (
-                f"Requested token count exceeds the model's maximum context length "
-                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
-                f"tokens: {input_token_num} tokens from the input messages and "
-                f"{max_new_tokens} tokens for the completion. Please reduce the number "
-                f"of tokens in the input messages or the completion to fit within the limit."
-            )
-            raise ValueError(error_msg)
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
+                    f"exceeds the model's context length ({self.context_len} tokens). "
+                    "Truncating max_new_tokens."
+                )
+                obj.sampling_params["max_new_tokens"] = max(
+                    0, _max_req_len - input_token_num
+                )
+            else:
+                total_tokens = max_new_tokens + input_token_num
+                error_msg = (
+                    f"Requested token count exceeds the model's maximum context length "
+                    f"of {self.context_len} tokens. You requested a total of {total_tokens} "
+                    f"tokens: {input_token_num} tokens from the input messages and "
+                    f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                    f"of tokens in the input messages or the completion to fit within the limit."
+                )
+                raise ValueError(error_msg)
 
         if isinstance(obj, GenerateReqInput):
             if (
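Note: with allow_auto_truncate enabled, both length checks now clamp instead of raising. A small worked example of the clamping arithmetic (numbers are illustrative):

    context_len = 8192          # model's maximum context length (_max_req_len)
    input_token_num = 8000      # prompt tokens after tokenization
    max_new_tokens = 512        # requested completion budget

    # 8000 + 512 = 8512 >= 8192, so without truncation the request is rejected.
    # With allow_auto_truncate, max_new_tokens is clamped instead:
    max_new_tokens = max(0, context_len - input_token_num)   # -> 192
    assert input_token_num + max_new_tokens <= context_len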
@@ -782,15 +813,17 @@ class TokenizerManager:
             ):
                 raise ValueError(finish_reason["message"])
 
-            if (
-                finish_reason.get("type") == "abort"
-                and finish_reason.get("status_code")
-                == HTTPStatus.SERVICE_UNAVAILABLE
+            if finish_reason.get("type") == "abort" and finish_reason.get(
+                "status_code"
+            ) in (
+                HTTPStatus.SERVICE_UNAVAILABLE,
+                HTTPStatus.INTERNAL_SERVER_ERROR,
             ):
                 # This is an abort request initiated by scheduler.
                 # Delete the key to prevent resending abort request to the scheduler and
                 # to ensure aborted request state is cleaned up.
-                del self.rid_to_state[state.obj.rid]
+                if state.obj.rid in self.rid_to_state:
+                    del self.rid_to_state[state.obj.rid]
 
             # Mark ongoing LoRA request as finished.
             if self.server_args.enable_lora and state.obj.lora_path:
@@ -1337,6 +1370,12 @@ class TokenizerManager:
         logging.info(f"Config logging: {obj=}")
         self.log_request_metadata = self.get_log_request_metadata()
 
+    async def freeze_gc(self):
+        """Send a freeze_gc message to the scheduler first, then freeze locally."""
+        self.send_to_scheduler.send_pyobj(FreezeGCReq())
+        freeze_gc("Tokenizer Manager")
+        return None
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
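Note: freeze_gc here is sglang's wrapper (imported from sglang.srt.utils above). Under the hood the pattern rests on CPython's gc.freeze(), which moves every currently tracked object into a permanent generation that later collections skip, so long-lived server state stops inflating GC pauses. A minimal sketch of that pattern, assuming only the standard library (the real helper adds sglang-specific logging):

    import gc
    import logging

    logger = logging.getLogger(__name__)

    def freeze_gc_sketch(context: str) -> None:
        before = gc.get_freeze_count()
        gc.collect()   # reclaim garbage first so only live objects get frozen
        gc.freeze()    # move survivors to the permanent generation
        frozen = gc.get_freeze_count() - before
        logger.info("%s: froze %d objects; future GC passes skip them.", context, frozen)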
sglang/srt/managers/tp_worker.py
CHANGED
sglang/srt/managers/utils.py
CHANGED
@@ -1,9 +1,16 @@
+from __future__ import annotations
+
 import logging
 import multiprocessing as mp
 from http import HTTPStatus
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import GenerationBatchResult
 
 logger = logging.getLogger(__name__)
 
@@ -41,6 +48,57 @@ def validate_input_length(
     return None
 
 
+def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
+
+    logits_output = result.logits_output
+    assert logits_output is not None
+
+    return {
+        "extend_input_len_per_req": result.extend_input_len_per_req,
+        "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+        "next_token_logprobs": result.logits_output.next_token_logprobs,
+        "next_token_top_logprobs_val": result.logits_output.next_token_top_logprobs_val,
+        "next_token_top_logprobs_idx": result.logits_output.next_token_top_logprobs_idx,
+        "next_token_token_ids_logprobs_val": result.logits_output.next_token_token_ids_logprobs_val,
+        "next_token_token_ids_logprobs_idx": result.logits_output.next_token_token_ids_logprobs_idx,
+        "input_token_logprobs": result.logits_output.input_token_logprobs,
+        "input_top_logprobs_val": result.logits_output.input_top_logprobs_val,
+        "input_top_logprobs_idx": result.logits_output.input_top_logprobs_idx,
+        "input_token_ids_logprobs_val": result.logits_output.input_token_ids_logprobs_val,
+        "input_token_ids_logprobs_idx": result.logits_output.input_token_ids_logprobs_idx,
+    }
+
+
+def get_logprob_from_pp_outputs(
+    next_pp_outputs: PPProxyTensors,
+) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
+    logits_output = LogitsProcessorOutput(
+        # Do not send logits and hidden states because they are large
+        next_token_logits=None,
+        hidden_states=None,
+        next_token_logprobs=next_pp_outputs["next_token_logprobs"],
+        next_token_top_logprobs_val=next_pp_outputs["next_token_top_logprobs_val"],
+        next_token_top_logprobs_idx=next_pp_outputs["next_token_top_logprobs_idx"],
+        next_token_token_ids_logprobs_val=next_pp_outputs[
+            "next_token_token_ids_logprobs_val"
+        ],
+        next_token_token_ids_logprobs_idx=next_pp_outputs[
+            "next_token_token_ids_logprobs_idx"
+        ],
+        input_token_logprobs=next_pp_outputs["input_token_logprobs"],
+        input_top_logprobs_val=next_pp_outputs["input_top_logprobs_val"],
+        input_top_logprobs_idx=next_pp_outputs["input_top_logprobs_idx"],
+        input_token_ids_logprobs_val=next_pp_outputs["input_token_ids_logprobs_val"],
+        input_token_ids_logprobs_idx=next_pp_outputs["input_token_ids_logprobs_idx"],
+    )
+    extend_input_len_per_req = next_pp_outputs["extend_input_len_per_req"]
+    extend_logprob_start_len_per_req = next_pp_outputs[
+        "extend_logprob_start_len_per_req"
+    ]
+
+    return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
+
+
 class DPBalanceMeta:
     """
     This class will be use in scheduler and dp controller
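Note: the two new helpers are shaped as inverses for pipeline parallelism: the stage that owns the logits packs the logprob fields of a GenerationBatchResult into a flat dict, and the receiving stage rebuilds a LogitsProcessorOutput from the proxy tensors. A rough usage sketch, assuming PPProxyTensors wraps a dict of tensors the way its indexing above suggests:

    # Producer side (the pipeline stage that computed the logits):
    logprob_dict = get_logprob_dict_from_result(result)
    proxy = PPProxyTensors(logprob_dict)          # shipped across the PP boundary

    # Consumer side:
    logits_output, extend_lens, logprob_start_lens = get_logprob_from_pp_outputs(proxy)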
sglang/srt/mem_cache/allocator.py
CHANGED

@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         device: str,
         kvcache: KVCache,
         need_sort: bool,
-        max_num_extend_tokens: int,
     ):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
-        self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
-            max_num_extend_tokens
-        )
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
+        self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()
 
     def alloc(self, need_size: int):
@@ -480,6 +477,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             (last_loc + 1) % self.page_size == prefix_lens % self.page_size
         )
 
+        self.seen_max_num_extend_tokens_next_power_of_2 = max(
+            self.seen_max_num_extend_tokens_next_power_of_2,
+            next_power_of_2(extend_num_tokens),
+        )
+
         bs = len(prefix_lens)
         if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
             self.free_pages
@@ -498,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.ret_values,
             next_power_of_2(bs),
             self.page_size,
-            self.max_num_extend_tokens_next_power_of_2,
+            self.seen_max_num_extend_tokens_next_power_of_2,
         )
 
         if self.debug_mode:
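Note: the allocator no longer takes max_num_extend_tokens up front; it grows a running power-of-two bound from the batches it actually sees and passes that bound to the kernel. Because the bound only ever grows, previously specialized kernels stay valid. A minimal sketch of the tracking (next_power_of_2 is a stand-in for the helper sglang imports):

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n, for n >= 1.
        return 1 << (n - 1).bit_length()

    seen_bound = 1
    for extend_num_tokens in (7, 130, 64):   # per-batch extend sizes
        seen_bound = max(seen_bound, next_power_of_2(extend_num_tokens))
    # seen_bound grew 8 -> 256 and then stayed at 256.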
sglang/srt/mem_cache/allocator_ascend.py
CHANGED

@@ -66,17 +66,6 @@ def alloc_extend_kernel_ascend(
 
 class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
 
-    def __init__(
-        self,
-        size: int,
-        page_size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-        need_sort: bool,
-    ):
-        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)
-
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
sglang/srt/mem_cache/hicache_storage.py
CHANGED

@@ -13,6 +13,11 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 
 
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
|
@@ -101,11 +106,16 @@ class HiCacheStorage(ABC):
|
|
101
106
|
|
102
107
|
class HiCacheFile(HiCacheStorage):
|
103
108
|
|
104
|
-
def __init__(self, file_path: str = "/tmp/hicache"):
|
109
|
+
def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
|
105
110
|
self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
|
106
|
-
|
107
|
-
|
108
|
-
|
111
|
+
if is_dp_attention_enabled():
|
112
|
+
tp_rank = get_attention_tp_rank()
|
113
|
+
tp_size = get_attention_tp_size()
|
114
|
+
else:
|
115
|
+
tp_rank = get_tensor_model_parallel_rank()
|
116
|
+
tp_size = get_tensor_model_parallel_world_size()
|
117
|
+
|
118
|
+
self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
|
109
119
|
if not os.path.exists(self.file_path) and tp_rank == 0:
|
110
120
|
os.makedirs(self.file_path)
|
111
121
|
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
|
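Note: the is_mla flag drops the per-rank suffix because an MLA model's compressed KV cache is replicated across tensor-parallel ranks, so one on-disk copy can serve all of them, while MHA caches are sharded per rank and keep the suffix. A toy view of the resulting key shapes (values are hypothetical):

    def cache_key(key: str, tp_rank: int, tp_size: int, is_mla: bool) -> str:
        # Mirrors the tp_suffix rule from the diff.
        tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
        return f"{key}{tp_suffix}"

    print(cache_key("abc123", 1, 4, is_mla=False))  # abc123_1_4  (per-rank shard)
    print(cache_key("abc123", 1, 4, is_mla=True))   # abc123      (shared copy)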
sglang/srt/mem_cache/memory_pool.py
CHANGED

@@ -849,7 +849,7 @@ class MLATokenToKVPool(KVCache):
             cache_k_rope = cache_k_rope.view(self.store_dtype)
 
         set_mla_kv_buffer_triton(
-            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+            self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
         )
 
     def get_cpu_copy(self, indices):
|
@@ -951,7 +951,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|
951
951
|
cache_k = cache_k.to(self.dtype)
|
952
952
|
|
953
953
|
if self.store_dtype != self.dtype:
|
954
|
-
cache_k = cache_k.view(store_dtype)
|
954
|
+
cache_k = cache_k.view(self.store_dtype)
|
955
955
|
|
956
956
|
import torch_npu
|
957
957
|
|
@@ -1070,7 +1070,7 @@ def copy_all_layer_kv_cache(
     num_loop = tl.cdiv(stride, BLOCK_SIZE)
     for i in range(num_loop):
         copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None]
+        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
         value = tl.load(
             data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
         )
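Note: the old mask only guarded the row (location) dimension. On the last loop iteration copy_offset runs past stride whenever stride is not a multiple of BLOCK_SIZE, so loads and stores touched out-of-bounds memory; the added (copy_offset < stride) term masks the column dimension as well. A quick count of the offsets that need masking:

    import math

    stride, BLOCK_SIZE = 100, 64
    num_loop = math.ceil(stride / BLOCK_SIZE)        # tl.cdiv -> 2 iterations
    for i in range(num_loop):
        offsets = range(i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)
        oob = sum(1 for off in offsets if off >= stride)
        print(f"iteration {i}: {oob} offsets past the row end")
    # iteration 0: 0; iteration 1: 28 (offsets 100..127 must be masked)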
sglang/srt/mem_cache/memory_pool_host.py
CHANGED

@@ -7,6 +7,7 @@ from functools import wraps
 import psutil
 import torch
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
 
@@ -307,6 +308,9 @@ class MHATokenToKVPoolHost(HostKVCache):
 
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token() // 2
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
@@ -484,8 +488,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
-            key_list.append(f"{key_}_v")
+            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
+            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
             element_size = (
                 self.layer_num
                 * self.dtype.itemsize
@@ -496,6 +500,21 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        key_list = []
+        buf_list = []
+
+        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+            key_list.append(f"{key}-k")
+            buf_list.append(self.k_buffer[i : i + self.page_size])
+            key_list.append(f"{key}-v")
+            buf_list.append(self.v_buffer[i : i + self.page_size])
+
+        return key_list, buf_list
+
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
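Note: get_buffer_with_hash relies on the page_first layout, where a page's tokens sit contiguously in the host buffer, so each page can be handed out as one slice paired with per-page "-k"/"-v" keys. A toy sketch of the pairing, with 1-D lists standing in for the real tensors:

    page_size = 4
    indices = list(range(12))                 # three pages of token slots
    keys = ["h0", "h1", "h2"]                 # one hash per page
    k_buffer = [f"k{i}" for i in range(12)]   # stand-in for the host K buffer

    key_list, buf_list = [], []
    for key, i in zip(keys, range(0, len(indices), page_size)):
        key_list.append(f"{key}-k")
        buf_list.append(k_buffer[i : i + page_size])

    print(key_list[0], buf_list[0])           # h0-k ['k0', 'k1', 'k2', 'k3']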
@@ -538,6 +557,9 @@ class MLATokenToKVPoolHost(HostKVCache):
             * self.layer_num
         )
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token()
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (
@@ -704,3 +726,14 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
+
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        buf_list = []
+
+        for i in range(0, len(indices), self.page_size):
+            buf_list.append(self.kv_buffer[i : i + self.page_size])
+
+        return keys, buf_list
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
CHANGED

@@ -7,10 +7,15 @@ import signal
 import threading
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    is_dp_attention_enabled,
+)
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 
|
@@ -167,13 +172,20 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
167
172
|
|
168
173
|
@staticmethod
|
169
174
|
def from_env_config(
|
170
|
-
|
175
|
+
bytes_per_page: int, dtype: torch.dtype, rank: int = None
|
171
176
|
) -> "HiCacheHF3FS":
|
172
177
|
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
173
178
|
Hf3fsGlobalMetadataClient,
|
174
179
|
Hf3fsLocalMetadataClient,
|
175
180
|
)
|
176
181
|
|
182
|
+
if rank is None:
|
183
|
+
rank = (
|
184
|
+
get_attention_tp_rank()
|
185
|
+
if is_dp_attention_enabled()
|
186
|
+
else get_tensor_model_parallel_rank()
|
187
|
+
)
|
188
|
+
|
177
189
|
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
178
190
|
if not config_path:
|
179
191
|
return HiCacheHF3FS(
|
@@ -228,15 +240,23 @@ class HiCacheHF3FS(HiCacheStorage):
         )
 
     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
-        return self.batch_get([key], [target_location])[0]
+        return self.batch_get(
+            [key],
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )[0]
 
     @synchronized()
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         page_indices = self.metadata_client.get_page_indices(self.rank, keys)
 
@@ -246,9 +266,15 @@ class HiCacheHF3FS(HiCacheStorage):
             batch_indices.append(i)
             file_offsets.append(page_index * self.bytes_per_page)
 
-        file_results = [
-            torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
-        ]
+        if target_locations is not None:
+            for target_location in target_locations:
+                assert target_location.is_contiguous()
+            file_results = target_locations
+        else:
+            file_results = [
+                torch.empty(self.numel, dtype=self.dtype)
+                for _ in range(len(batch_indices))
+            ]
 
         futures = [
             self.executor.submit(
|
254
280
|
self.executor.submit(
|
@@ -273,10 +299,27 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
273
299
|
|
274
300
|
return results
|
275
301
|
|
276
|
-
def set(
|
277
|
-
|
302
|
+
def set(
|
303
|
+
self,
|
304
|
+
key: str,
|
305
|
+
value: Optional[Any] = None,
|
306
|
+
target_location: Optional[Any] = None,
|
307
|
+
target_sizes: Optional[Any] = None,
|
308
|
+
) -> bool:
|
309
|
+
return self.batch_set(
|
310
|
+
[key],
|
311
|
+
[value] if value is not None else None,
|
312
|
+
[target_location] if target_location is not None else None,
|
313
|
+
[target_sizes] if target_sizes is not None else None,
|
314
|
+
)
|
278
315
|
|
279
|
-
def batch_set(
|
316
|
+
def batch_set(
|
317
|
+
self,
|
318
|
+
keys: List[str],
|
319
|
+
values: Optional[Any] = None,
|
320
|
+
target_locations: Optional[Any] = None,
|
321
|
+
target_sizes: Optional[Any] = None,
|
322
|
+
) -> bool:
|
280
323
|
# Todo: Add prefix block's hash key
|
281
324
|
key_with_prefix = [(key, "") for key in keys]
|
282
325
|
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
@@ -292,7 +335,8 @@ class HiCacheHF3FS(HiCacheStorage):
 
             batch_indices.append(i)
             file_offsets.append(page_index * self.bytes_per_page)
-            file_values.append(value)
+            assert value.is_contiguous()
+            file_values.append(value)
 
         futures = [
             self.executor.submit(
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
CHANGED

@@ -19,14 +19,13 @@ logger = logging.getLogger(__name__)
 
 
 def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
     if prior_hash:
         prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
     current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
+    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
 
 
 @dataclass
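Note: get_hash_str_mooncake chains page hashes, so a page's key depends on its whole prefix; removing the local_rank suffix moves rank-awareness out of the hash and into key assembly (see exists() below). A runnable sketch of the chaining:

    import hashlib
    import numpy as np

    def hash_page(token_ids, prior_hash=None):
        prefix_str = ""
        if prior_hash:
            prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
        current = hashlib.sha256(np.array(token_ids).tobytes()).hexdigest()
        return f"{prefix_str}_{int(current[:16], 16)}"

    h0 = hash_page([1, 2, 3, 4])
    h1 = hash_page([5, 6, 7, 8], prior_hash=h0)   # key folds in the prefix
    assert h1 != hash_page([5, 6, 7, 8])          # same page, different prefix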
@@ -97,7 +96,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
-    def __init__(self):
+    def __init__(self, is_mla: bool = False):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -127,6 +126,7 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
+            self.is_mla = is_mla
 
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -223,11 +223,15 @@ class MooncakeStore(HiCacheStorage):
 
     def exists(self, keys) -> bool | dict:
         _keys = []
+        local_rank = get_tensor_model_parallel_rank()
         for key in keys:
             if key is None:
                 return None
 
-            _keys.append(f"{key}_k")
+            if self.is_mla:
+                _keys.append(f"{key}_k")
+            else:
+                _keys.append(f"{key}_{local_rank}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result
 
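Note: together with the is_mla changes above, Mooncake keys now carry the rank only for MHA-style caches: MLA's latent KV is replicated across tensor-parallel ranks, so one stored copy per page is probed, while MHA shards differ per rank and need distinct keys. A toy view of the probe-key assembly mirrored from exists():

    def probe_key(page_hash: str, local_rank: int, is_mla: bool) -> str:
        return f"{page_hash}_k" if is_mla else f"{page_hash}_{local_rank}_k"

    print(probe_key("9f3a42", local_rank=2, is_mla=True))    # 9f3a42_k
    print(probe_key("9f3a42", local_rank=2, is_mla=False))   # 9f3a42_2_k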