sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
--- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py (0.5.0rc0)
+++ sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py (0.5.0rc2)
@@ -5,9 +5,9 @@ import logging
 import os
 import signal
 import threading
-from collections import OrderedDict
+from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 logger = logging.getLogger(__name__)
 
 
+class Hf3fsMetadataInterface(ABC):
+    """Interface for HF3FS metadata operations."""
+
+    @abstractmethod
+    def initialize(self, rank: int, num_pages: int) -> None:
+        """Initialize the metadata service with the specified number of pages."""
+        pass
+
+    @abstractmethod
+    def reserve_and_allocate_page_indices(
+        self,
+        rank: int,
+        keys: List[Tuple[str, str]],
+    ) -> List[Tuple[bool, int]]:
+        """
+        Reserve and allocate page indices for the specified keys.
+        Args:
+            rank: The rank of the process.
+            keys: The keys to reserve and allocate page indices for. Each tuple contains a key and the key of its prefix block.
+        Returns:
+            List[Tuple[bool, int]]: A list of tuples, where each tuple contains a boolean indicating whether the key already exists and an integer indicating the allocated page index.
+        """
+        pass
+
+    @abstractmethod
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """
+        Confirm that key-value pairs have been successfully written to storage.
+        Args:
+            rank: The rank of the process.
+            written_keys_to_confirm: A list of tuples, where each tuple contains a key and its corresponding page index.
+            pages_to_release: A list of page indices to be released.
+        """
+        pass
+
+    @abstractmethod
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        """
+        Get page indices for the specified keys.
+        Args:
+            rank: The rank of the process.
+            keys: A list of keys.
+        Returns:
+            List[Optional[int]]: A list of integers representing the page indices for the specified keys.
+                If a key is not found, the corresponding index will be None.
+        """
+        pass
+
+    @abstractmethod
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        """Delete the specified keys and their associated pages."""
+        pass
+
+    @abstractmethod
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        """Check whether the specified keys exist."""
+        pass
+
+    @abstractmethod
+    def clear(self, rank: int) -> None:
+        """Clear all key-value pairs and page allocations for the specified rank."""
+        pass
+
+
 class AtomicCounter:
     def __init__(self, n: int):
         assert n > 0
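To make the contract above concrete, here is a minimal in-memory sketch that duck-types `Hf3fsMetadataInterface` for a single process. The class name and its no-eviction policy (allocation simply returns `-1` once pages run out) are illustrative assumptions, not the behavior of the real clients in `mini_3fs_metadata_server.py`:

```python
from typing import Dict, List, Optional, Tuple


class ToyMetadataClient:
    """Illustrative only: tracks key -> page mappings per rank in memory."""

    def __init__(self) -> None:
        self._num_pages: Dict[int, int] = {}                # rank -> total pages
        self._free: Dict[int, List[int]] = {}               # rank -> free page indices
        self._key_to_page: Dict[int, Dict[str, int]] = {}   # rank -> {key: page}

    def initialize(self, rank: int, num_pages: int) -> None:
        self._num_pages[rank] = num_pages
        self._free[rank] = list(range(num_pages))
        self._key_to_page[rank] = {}

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        out = []
        for key, _prefix_key in keys:
            if key in self._key_to_page[rank]:
                out.append((True, self._key_to_page[rank][key]))  # already stored
            elif self._free[rank]:
                out.append((False, self._free[rank].pop()))       # page reserved
            else:
                out.append((False, -1))                           # out of capacity
        return out

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        for key, page in written_keys_to_confirm:
            self._key_to_page[rank][key] = page
        self._free[rank].extend(pages_to_release)

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        return [self._key_to_page[rank].get(key) for key in keys]

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        for key in keys:
            page = self._key_to_page[rank].pop(key, None)
            if page is not None:
                self._free[rank].append(page)

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        return [key in self._key_to_page[rank] for key in keys]

    def clear(self, rank: int) -> None:
        self.initialize(rank, self._num_pages[rank])
```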
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):
 
     def __init__(
         self,
+        rank: int,
         file_path: str,
         file_size: int,
         numjobs: int,
         bytes_per_page: int,
         entries: int,
         dtype: torch.dtype,
+        metadata_client: Hf3fsMetadataInterface,
     ):
+        self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
         self.entries = entries
         self.dtype = dtype
+        self.metadata_client = metadata_client
 
         self.numel = self.bytes_per_page // self.dtype.itemsize
-
         self.num_pages = self.file_size // self.bytes_per_page
 
         logger.info(
-            "HiCacheHF3FS "
-            f"file_path = {self.file_path}, "
-            f"file_size = {self.file_size/(2**30):.2f} GB, "
-            f"numjobs = {self.numjobs}, "
-            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
-            f"entries = {self.entries}, "
-            f"num_pages = {self.num_pages}"
+            f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
+            f"file_path={self.file_path}, "
+            f"file_size={self.file_size / (2 ** 30):.2f} GB, "
+            f"num_pages={self.num_pages}"
         )
 
         self.ac = AtomicCounter(self.numjobs)
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(numjobs)
         ]
         self.executor = concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+            max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
        )
-
-        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
-        # through centralized metadata orchestration.
+        self.metadata_client.initialize(self.rank, self.num_pages)
         self.lock = threading.RLock()
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index = OrderedDict()
 
         atexit.register(self.close)
 
@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
     def from_env_config(
         rank: int, bytes_per_page: int, dtype: torch.dtype
     ) -> "HiCacheHF3FS":
+        from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+            Hf3fsGlobalMetadataClient,
+            Hf3fsLocalMetadataClient,
+        )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
+                rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
                 file_size=1 << 40,
                 numjobs=16,
                 bytes_per_page=bytes_per_page,
                 entries=8,
                 dtype=dtype,
+                metadata_client=Hf3fsLocalMetadataClient(),
             )
 
         try:
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
 
+        # Check required keys (metadata_server_url is now optional)
         required_keys = {
             "file_path_prefix",
             "file_size",
@@ -131,17 +204,31 @@ class HiCacheHF3FS(HiCacheStorage):
         if missing_keys:
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
+        # Choose metadata client based on configuration
+        if "metadata_server_url" in config and config["metadata_server_url"]:
+            # Use global metadata client to connect to metadata server
+            metadata_server_url = config["metadata_server_url"]
+            metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+            logger.info(
+                f"Using global metadata client with server url: {metadata_server_url}"
+            )
+        else:
+            # Use local metadata client for single-machine deployment
+            metadata_client = Hf3fsLocalMetadataClient()
+
         return HiCacheHF3FS(
+            rank=rank,
             file_path=f"{config['file_path_prefix']}.{rank}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
+            metadata_client=metadata_client,
         )
 
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
-        return self.batch_get([key], target_location)[0]
+        return self.batch_get([key], [target_location] if target_location else None)[0]
 
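A hypothetical configuration for `from_env_config`, for illustration only. The JSON keys mirror the `required_keys` check above, with `metadata_server_url` optional as of this release; the environment-variable name, paths, sizes, and URL below are made-up values, not shipped defaults:

```python
import json
import os
import tempfile

# "metadata_server_url" switches the store from Hf3fsLocalMetadataClient to
# Hf3fsGlobalMetadataClient; omit it (or leave it empty) for local mode.
config = {
    "file_path_prefix": "/data/hicache",   # per-rank file: /data/hicache.<rank>.bin
    "file_size": 1 << 40,                  # bytes per rank
    "numjobs": 16,
    "entries": 8,
    "metadata_server_url": "http://meta.example:18000",
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)

# HiCacheHF3FS.default_env_var holds the real variable name; this one is assumed.
os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = f.name
```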
@@ -148,17 +235,17 @@ class HiCacheHF3FS(HiCacheStorage):
     @synchronized()
     def batch_get(
         self,
         keys: List[str],
         target_locations: Optional[List[torch.Tensor]] = None,
     ) -> List[torch.Tensor | None]:
+        page_indices = self.metadata_client.get_page_indices(self.rank, keys)
+
         batch_indices, file_offsets = [], []
-        for i, key in enumerate(keys):
-            if key not in self.key_to_index:
-                continue
-            batch_indices.append(i)
-            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
-            self.key_to_index.move_to_end(key)
-        # TODO: target_locations
+        for i, page_index in enumerate(page_indices):
+            if page_index is not None:
+                batch_indices.append(i)
+                file_offsets.append(page_index * self.bytes_per_page)
+
         file_results = [
             torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
         ]
@@ -180,7 +267,9 @@ class HiCacheHF3FS(HiCacheStorage):
             if read_result == self.bytes_per_page:
                 results[batch_index] = file_result
             else:
-                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+                logger.error(
+                    f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
+                )
 
         return results
 
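The read path above reduces to "page index times page size equals file offset", with `None` page indices skipped as cache misses. A self-contained toy illustration (the file contents, keys, and page size are contrived, and `os.pread` stands in for the HF3FS client's batched reads):

```python
import os
import tempfile

BYTES_PER_PAGE = 4
page_indices = {"a": 0, "c": 2}  # stand-in for metadata_client.get_page_indices()

with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"AAAABBBBCCCC")     # three 4-byte pages on disk
    path = f.name

fd = os.open(path, os.O_RDONLY)
for key in ("a", "missing", "c"):
    page = page_indices.get(key)            # None == cache miss, skip the read
    if page is None:
        continue
    data = os.pread(fd, BYTES_PER_PAGE, page * BYTES_PER_PAGE)
    print(key, data)                        # a b'AAAA' / c b'CCCC'
os.close(fd)
```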
@@ -188,13 +277,21 @@ class HiCacheHF3FS(HiCacheStorage):
         return self.batch_set([key], [value])
 
     def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        indices = self.get_batch_set_indices(keys)
+        # Todo: Add prefix block's hash key
+        key_with_prefix = [(key, "") for key in keys]
+        indices = self.metadata_client.reserve_and_allocate_page_indices(
+            self.rank, key_with_prefix
+        )
+
         batch_indices, file_offsets, file_values = [], [], []
-        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
-            if is_written or index == -1:
+        pages_to_release = []
+
+        for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
+            if is_written or page_index == -1:
                 continue
+
             batch_indices.append(i)
-            file_offsets.append(index * self.bytes_per_page)
+            file_offsets.append(page_index * self.bytes_per_page)
             file_values.append(value.contiguous())
 
         futures = [
@@ -211,62 +308,37 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]
 
+        written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
             key = keys[batch_index]
-            index = indices[batch_index][1]
+            page_index = indices[batch_index][1]
             if write_result:
-                self.key_to_index[key] = index
-                self.key_to_index.move_to_end(key)
+                written_keys_to_confirm.append((key, page_index))
             else:
-                logger.error(f"HiCacheHF3FS set {key} failed")
-                self.free_pages.append(index)
+                logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
+                pages_to_release.append(page_index)
             results[batch_index] = write_result
-        return all(results)
-
-    @synchronized()
-    def get_batch_set_indices(self, keys: List[str]) -> list:
-        ionum = len(keys)
-        # results: tuples of (is_written: bool, page_idx: int)
-        # - is_written: True = hit (no I/O), False = write (miss)
-        # - page_idx: page storing data
-        results = [None] * min(ionum, self.num_pages)
-        if ionum > self.num_pages:
-            results.extend([(False, -1)] * (ionum - self.num_pages))
-
-        new_keys = []
-        for batch_index, key in enumerate(keys[: self.num_pages]):
-            if key in self.key_to_index:
-                results[batch_index] = (True, self.key_to_index[key])
-                self.key_to_index.move_to_end(key)
-            else:
-                new_keys.append((batch_index, key))
 
-        for batch_index, key in new_keys:
-            index = (
-                self.free_pages.pop()
-                if len(self.free_pages) > 0
-                else self.key_to_index.popitem(last=False)[1]
+        if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
+            self.metadata_client.confirm_write(
+                self.rank, written_keys_to_confirm, pages_to_release
             )
-            results[batch_index] = (False, index)
 
-        return results
+        return all(results)
 
     @synchronized()
     def delete(self, key: str) -> None:
-        if key not in self.key_to_index:
-            return
-        index = self.key_to_index.pop(key)
-        self.free_pages.append(index)
+        self.metadata_client.delete_keys(self.rank, [key])
 
     @synchronized()
     def exists(self, key: str) -> bool:
-        return key in self.key_to_index
+        result = self.metadata_client.exists(self.rank, [key])
+        return result[0] if result else False
 
     @synchronized()
     def clear(self) -> None:
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index.clear()
+        self.metadata_client.clear(self.rank)
 
     def close(self) -> None:
         try:
--- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py (0.5.0rc0)
+++ sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py (0.5.0rc2)
@@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB
 logger = logging.getLogger(__name__)
 
 
-def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
+def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
     local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
-    if prefix_block_key:
-        if len(prefix_block_key):
-            prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
-    current_token_ids_bytes = np.array(current_page_ids).tobytes()
+    if prior_hash:
+        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
+    current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
     return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
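A worked example of the chained hashing above, with the rank fixed at 0 for illustration. Each page's key embeds a digest of the previous page's key, so identical token pages under different prefixes still yield distinct keys:

```python
import hashlib

import numpy as np


def get_hash_str(token_ids, prior_hash=None, local_rank=0):
    # Mirrors get_hash_str_mooncake, minus the tensor-parallel rank lookup.
    prefix_str = ""
    if prior_hash:
        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
    current_hash_hex = hashlib.sha256(np.array(token_ids).tobytes()).hexdigest()
    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"


h1 = get_hash_str([1, 2, 3, 4])                  # first page: empty prefix
h2 = get_hash_str([5, 6, 7, 8], prior_hash=h1)   # chained to the first page
h3 = get_hash_str([5, 6, 7, 8])                  # same tokens, no prefix
assert h2 != h3   # position in the chain changes the key
```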
@@ -224,13 +223,11 @@ class MooncakeStore(HiCacheStorage):
 
     def exists(self, keys) -> bool | dict:
         _keys = []
-        local_rank = torch.cuda.current_device()
         for key in keys:
             if key is None:
                 return None
-            _keys.append(f"{key}_{local_rank}_k")
+
+            _keys.append(f"{key}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result
 
--- sglang/srt/model_executor/cuda_graph_runner.py (0.5.0rc0)
+++ sglang/srt/model_executor/cuda_graph_runner.py (0.5.0rc2)
@@ -33,7 +33,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    DpPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    set_dp_buffer_len,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
@@ -255,6 +260,9 @@ class CudaGraphRunner:
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
 
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         rank0_log(f"Capture cuda graph bs {self.capture_bs}")
@@ -342,30 +350,15 @@ class CudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         self.custom_mask = torch.ones(
             (
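The net effect of the hunk above: instead of preallocating a `(tokens, hidden_size)` gather tensor per runner, the capture path (next hunks) only records how many token slots the DP gather buffer needs. A rough sketch of that bookkeeping, with branch flags borrowed loosely from the diff and their exact semantics assumed:

```python
def compute_global_dp_buffer_len(num_tokens, dp_size,
                                 require_mlp_tp_gather, require_attn_tp_gather):
    # Mirrors the capture-path branches in the diff; flag semantics are assumed.
    if require_mlp_tp_gather:
        return num_tokens * dp_size   # every DP rank contributes num_tokens slots
    if require_attn_tp_gather:
        return num_tokens             # gather only across the attention TP group
    return None                       # DP attention disabled: no gather buffer


print(compute_global_dp_buffer_len(128, 4, True, False))   # 512
print(compute_global_dp_buffer_len(128, 4, False, True))   # 128
```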
@@ -549,7 +542,7 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -565,9 +558,9 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = self.get_spec_info(num_tokens)
         if self.capture_hidden_mode != CaptureHiddenMode.FULL:
@@ -600,8 +593,8 @@ class CudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -630,6 +623,7 @@ class CudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             kwargs = {}
             if (
@@ -729,10 +723,12 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
 
+        seq_lens_cpu = None
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         if pp_proxy_tensors:
             for key in self.pp_proxy_tensors.keys():
@@ -747,7 +743,17 @@ class CudaGraphRunner:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
-            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+            num_token_non_padded = forward_batch.num_token_non_padded
+            if self.require_gathered_buffer:
+                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
+                num_local_token_non_padded = torch.clamp(
+                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
+                    min=0,
+                    max=tokens_per_rank,
+                )
+                self.num_token_non_padded.copy_(num_local_token_non_padded)
+            else:
+                self.num_token_non_padded.copy_(num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
                 forward_mode=self.capture_forward_mode,
@@ -766,7 +772,7 @@ class CudaGraphRunner:
             self.encoder_lens[:bs] if self.is_encoder_decoder else None,
             self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu[:bs],
+            seq_lens_cpu=seq_lens_cpu,
         )
 
         # Store fields
--- sglang/srt/model_executor/forward_batch_info.py (0.5.0rc0)
+++ sglang/srt/model_executor/forward_batch_info.py (0.5.0rc2)
@@ -40,9 +40,10 @@ import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_dp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
@@ -274,13 +275,13 @@ class ForwardBatch:
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The padding mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-    gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
@@ -628,7 +629,7 @@ class ForwardBatch:
                 (global_num_tokens[i] - 1) // attn_tp_size + 1
             ) * attn_tp_size
 
-        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
@@ -642,23 +643,38 @@ class ForwardBatch:
         else:
             buffer_len = sum(global_num_tokens)
 
-        self.gathered_buffer = torch.zeros(
-            (buffer_len, model_runner.model_config.hidden_size),
-            dtype=model_runner.dtype,
-            device=model_runner.device,
-        )
-
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
-        self.batch_size = num_tokens
+        self.global_dp_buffer_len = buffer_len
+        set_dp_buffer_len(buffer_len, num_tokens)
 
         bs = self.batch_size
 
+        if self.forward_mode.is_decode():
+            if self.is_extend_in_batch and dp_padding_mode.is_max_len():
+                setattr(self, "_original_forward_mode", self.forward_mode)
+                self.forward_mode = ForwardMode.EXTEND
+                self.extend_num_tokens = bs
+                self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
+                self.extend_prefix_lens = self.seq_lens - 1
+                self.extend_start_loc = torch.arange(
+                    bs, dtype=torch.int32, device=self.seq_lens.device
+                )
+                self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
+                self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
+                self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
+        else:
+            setattr(self, "_original_batch_size", self.batch_size)
+            if self.spec_info is not None:
+                bs = self.batch_size = (
+                    num_tokens // self.spec_info.num_tokens_per_batch
+                )
+            else:
+                bs = self.batch_size = num_tokens
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
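A small sketch of the decode-to-extend conversion above, under the reading that max-len DP padding re-labels a decode batch as an extend batch in which each request extends by exactly one token and the rest of its sequence counts as prefix (the lengths below are illustrative):

```python
import torch

seq_lens = torch.tensor([17, 5, 42], dtype=torch.int32)   # running sequence lengths
bs = seq_lens.numel()

extend_num_tokens = bs                                    # one new token per request
extend_seq_lens = torch.full_like(seq_lens, 1)
extend_prefix_lens = seq_lens - 1                         # everything already cached
extend_start_loc = torch.arange(bs, dtype=torch.int32)

print(extend_seq_lens.tolist())      # [1, 1, 1]
print(extend_prefix_lens.tolist())   # [16, 4, 41]
print(extend_start_loc.tolist())     # [0, 1, 2]
```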
@@ -689,6 +705,7 @@ class ForwardBatch:
         if self.mrope_positions is not None:
             self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
 
+        # TODO: check if we need to pad other tensors
         if self.extend_seq_lens is not None:
             self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
 
@@ -712,7 +729,9 @@ class ForwardBatch:
 
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
 
-        bs = self.batch_size
+        self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
+        self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
+        bs = self.batch_size
 
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft