sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -5,9 +5,9 @@ import logging
|
|
5
5
|
import os
|
6
6
|
import signal
|
7
7
|
import threading
|
8
|
-
from
|
8
|
+
from abc import ABC, abstractmethod
|
9
9
|
from functools import wraps
|
10
|
-
from typing import List, Optional
|
10
|
+
from typing import List, Optional, Tuple
|
11
11
|
|
12
12
|
import torch
|
13
13
|
|
@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
|
|
17
17
|
logger = logging.getLogger(__name__)
|
18
18
|
|
19
19
|
|
20
|
+
class Hf3fsMetadataInterface(ABC):
|
21
|
+
"""Interface for HF3FS metadata operations."""
|
22
|
+
|
23
|
+
@abstractmethod
|
24
|
+
def initialize(self, rank: int, num_pages: int) -> None:
|
25
|
+
"""Initialize the metadata service with specified number of pages."""
|
26
|
+
pass
|
27
|
+
|
28
|
+
@abstractmethod
|
29
|
+
def reserve_and_allocate_page_indices(
|
30
|
+
self,
|
31
|
+
rank: int,
|
32
|
+
keys: List[Tuple[str, str]],
|
33
|
+
) -> List[Tuple[bool, int]]:
|
34
|
+
"""
|
35
|
+
Reserve and allocate page indices for the specified keys.
|
36
|
+
Args:
|
37
|
+
rank: The rank of the process.
|
38
|
+
keys: The keys to reserve and allocate page indices for. Each tuple contains a key and the key of its prefix block.
|
39
|
+
Returns:
|
40
|
+
List[Tuple[bool, int]]: A list of tuples, where each tuple contains a boolean indicating whether the key has existed and an integer indicating the allocated page index.
|
41
|
+
"""
|
42
|
+
pass
|
43
|
+
|
44
|
+
@abstractmethod
|
45
|
+
def confirm_write(
|
46
|
+
self,
|
47
|
+
rank: int,
|
48
|
+
written_keys_to_confirm: List[Tuple[str, int]],
|
49
|
+
pages_to_release: List[int],
|
50
|
+
) -> None:
|
51
|
+
"""
|
52
|
+
Confirm that key-value pairs have been successfully written to storage.
|
53
|
+
Args:
|
54
|
+
rank: The rank of the process.
|
55
|
+
written_keys_to_confirm: A list of tuples, where each tuple contains a key and its corresponding page index.
|
56
|
+
pages_to_release: A list of page indices to be released.
|
57
|
+
"""
|
58
|
+
pass
|
59
|
+
|
60
|
+
@abstractmethod
|
61
|
+
def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
|
62
|
+
"""
|
63
|
+
Get page indices for the specified keys.
|
64
|
+
Args:
|
65
|
+
rank: The rank of the process.
|
66
|
+
keys: A list of keys.
|
67
|
+
Returns:
|
68
|
+
List[Optional[int]]: A list of integers representing the page indices for the specified keys.
|
69
|
+
If a key is not found, the corresponding index will be None.
|
70
|
+
"""
|
71
|
+
pass
|
72
|
+
|
73
|
+
@abstractmethod
|
74
|
+
def delete_keys(self, rank: int, keys: List[str]) -> None:
|
75
|
+
"""Delete specified keys and their associated pages."""
|
76
|
+
pass
|
77
|
+
|
78
|
+
@abstractmethod
|
79
|
+
def exists(self, rank: int, keys: List[str]) -> List[bool]:
|
80
|
+
"""Check if the specified keys exist."""
|
81
|
+
pass
|
82
|
+
|
83
|
+
@abstractmethod
|
84
|
+
def clear(self, rank: int) -> None:
|
85
|
+
"""Clear all key-value pairs and page allocations for the specified rank."""
|
86
|
+
pass
|
87
|
+
|
88
|
+
|
20
89
|
class AtomicCounter:
|
21
90
|
def __init__(self, n: int):
|
22
91
|
assert n > 0
|
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
48
117
|
|
49
118
|
def __init__(
|
50
119
|
self,
|
120
|
+
rank: int,
|
51
121
|
file_path: str,
|
52
122
|
file_size: int,
|
53
123
|
numjobs: int,
|
54
124
|
bytes_per_page: int,
|
55
125
|
entries: int,
|
56
126
|
dtype: torch.dtype,
|
127
|
+
metadata_client: Hf3fsMetadataInterface,
|
57
128
|
):
|
129
|
+
self.rank = rank
|
58
130
|
self.file_path = file_path
|
59
131
|
self.file_size = file_size
|
60
132
|
self.numjobs = numjobs
|
61
133
|
self.bytes_per_page = bytes_per_page
|
62
134
|
self.entries = entries
|
63
135
|
self.dtype = dtype
|
136
|
+
self.metadata_client = metadata_client
|
64
137
|
|
65
138
|
self.numel = self.bytes_per_page // self.dtype.itemsize
|
66
|
-
|
67
139
|
self.num_pages = self.file_size // self.bytes_per_page
|
68
140
|
|
69
141
|
logger.info(
|
70
|
-
"HiCacheHF3FS "
|
71
|
-
f"file_path
|
72
|
-
f"file_size
|
73
|
-
f"
|
74
|
-
f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
|
75
|
-
f"entries = {self.entries}, "
|
76
|
-
f"num_pages = {self.num_pages}"
|
142
|
+
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
|
143
|
+
f"file_path={self.file_path}, "
|
144
|
+
f"file_size={self.file_size / (2 ** 30):.2f} GB, "
|
145
|
+
f"num_pages={self.num_pages}"
|
77
146
|
)
|
78
147
|
|
79
148
|
self.ac = AtomicCounter(self.numjobs)
|
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
84
153
|
for _ in range(numjobs)
|
85
154
|
]
|
86
155
|
self.executor = concurrent.futures.ThreadPoolExecutor(
|
87
|
-
max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
|
156
|
+
max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
|
88
157
|
)
|
89
158
|
|
90
|
-
|
91
|
-
# Future iterations may adopt a global KVCache manager to coordinate external cache instances
|
92
|
-
# through centralized metadata orchestration.
|
159
|
+
self.metadata_client.initialize(self.rank, self.num_pages)
|
93
160
|
self.lock = threading.RLock()
|
94
|
-
self.free_pages = list(range(self.num_pages))
|
95
|
-
self.key_to_index = OrderedDict()
|
96
161
|
|
97
162
|
atexit.register(self.close)
|
98
163
|
|
@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
104
169
|
def from_env_config(
|
105
170
|
rank: int, bytes_per_page: int, dtype: torch.dtype
|
106
171
|
) -> "HiCacheHF3FS":
|
172
|
+
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
173
|
+
Hf3fsGlobalMetadataClient,
|
174
|
+
Hf3fsLocalMetadataClient,
|
175
|
+
)
|
176
|
+
|
107
177
|
config_path = os.getenv(HiCacheHF3FS.default_env_var)
|
108
178
|
if not config_path:
|
109
179
|
return HiCacheHF3FS(
|
180
|
+
rank=rank,
|
110
181
|
file_path=f"/data/hicache.{rank}.bin",
|
111
182
|
file_size=1 << 40,
|
112
183
|
numjobs=16,
|
113
184
|
bytes_per_page=bytes_per_page,
|
114
185
|
entries=8,
|
115
186
|
dtype=dtype,
|
187
|
+
metadata_client=Hf3fsLocalMetadataClient(),
|
116
188
|
)
|
117
189
|
|
118
190
|
try:
|
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
121
193
|
except Exception as e:
|
122
194
|
raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
|
123
195
|
|
196
|
+
# Check required keys (metadata_server_url is now optional)
|
124
197
|
required_keys = {
|
125
198
|
"file_path_prefix",
|
126
199
|
"file_size",
|
@@ -131,19 +204,33 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
131
204
|
if missing_keys:
|
132
205
|
raise ValueError(f"Missing required keys in config: {missing_keys}")
|
133
206
|
|
207
|
+
# Choose metadata client based on configuration
|
208
|
+
if "metadata_server_url" in config and config["metadata_server_url"]:
|
209
|
+
# Use global metadata client to connect to metadata server
|
210
|
+
metadata_server_url = config["metadata_server_url"]
|
211
|
+
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
|
212
|
+
logger.info(
|
213
|
+
f"Using global metadata client with server url: {metadata_server_url}"
|
214
|
+
)
|
215
|
+
else:
|
216
|
+
# Use local metadata client for single-machine deployment
|
217
|
+
metadata_client = Hf3fsLocalMetadataClient()
|
218
|
+
|
134
219
|
return HiCacheHF3FS(
|
220
|
+
rank=rank,
|
135
221
|
file_path=f"{config['file_path_prefix']}.{rank}.bin",
|
136
222
|
file_size=int(config["file_size"]),
|
137
223
|
numjobs=int(config["numjobs"]),
|
138
224
|
bytes_per_page=bytes_per_page,
|
139
225
|
entries=int(config["entries"]),
|
140
226
|
dtype=dtype,
|
227
|
+
metadata_client=metadata_client,
|
141
228
|
)
|
142
229
|
|
143
230
|
def get(
|
144
231
|
self, key: str, target_location: Optional[torch.Tensor] = None
|
145
232
|
) -> torch.Tensor | None:
|
146
|
-
return self.batch_get([key], target_location)[0]
|
233
|
+
return self.batch_get([key], [target_location] if target_location else None)[0]
|
147
234
|
|
148
235
|
@synchronized()
|
149
236
|
def batch_get(
|
@@ -151,14 +238,14 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
151
238
|
keys: List[str],
|
152
239
|
target_locations: Optional[List[torch.Tensor]] = None,
|
153
240
|
) -> List[torch.Tensor | None]:
|
241
|
+
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
|
242
|
+
|
154
243
|
batch_indices, file_offsets = [], []
|
155
|
-
for i,
|
156
|
-
if
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
self.key_to_index.move_to_end(key)
|
161
|
-
# TODO: target_locations
|
244
|
+
for i, page_index in enumerate(page_indices):
|
245
|
+
if page_index is not None:
|
246
|
+
batch_indices.append(i)
|
247
|
+
file_offsets.append(page_index * self.bytes_per_page)
|
248
|
+
|
162
249
|
file_results = [
|
163
250
|
torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
|
164
251
|
]
|
@@ -180,7 +267,9 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
180
267
|
if read_result == self.bytes_per_page:
|
181
268
|
results[batch_index] = file_result
|
182
269
|
else:
|
183
|
-
logger.error(
|
270
|
+
logger.error(
|
271
|
+
f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
|
272
|
+
)
|
184
273
|
|
185
274
|
return results
|
186
275
|
|
@@ -188,13 +277,21 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
188
277
|
return self.batch_set([key], [value])
|
189
278
|
|
190
279
|
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
191
|
-
|
280
|
+
# Todo: Add prefix block's hash key
|
281
|
+
key_with_prefix = [(key, "") for key in keys]
|
282
|
+
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
283
|
+
self.rank, key_with_prefix
|
284
|
+
)
|
285
|
+
|
192
286
|
batch_indices, file_offsets, file_values = [], [], []
|
193
|
-
|
194
|
-
|
287
|
+
pages_to_release = []
|
288
|
+
|
289
|
+
for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
|
290
|
+
if is_written or page_index == -1:
|
195
291
|
continue
|
292
|
+
|
196
293
|
batch_indices.append(i)
|
197
|
-
file_offsets.append(
|
294
|
+
file_offsets.append(page_index * self.bytes_per_page)
|
198
295
|
file_values.append(value.contiguous())
|
199
296
|
|
200
297
|
futures = [
|
@@ -211,62 +308,37 @@ class HiCacheHF3FS(HiCacheStorage):
|
|
211
308
|
for result in future.result()
|
212
309
|
]
|
213
310
|
|
311
|
+
written_keys_to_confirm = []
|
214
312
|
results = [index[0] for index in indices]
|
215
313
|
for batch_index, write_result in zip(batch_indices, write_results):
|
216
314
|
key = keys[batch_index]
|
217
|
-
|
315
|
+
page_index = indices[batch_index][1]
|
218
316
|
if write_result:
|
219
|
-
|
220
|
-
self.key_to_index.move_to_end(key)
|
317
|
+
written_keys_to_confirm.append((key, page_index))
|
221
318
|
else:
|
222
|
-
logger.error(f"HiCacheHF3FS set {key} failed")
|
223
|
-
|
319
|
+
logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
|
320
|
+
pages_to_release.append(page_index)
|
224
321
|
results[batch_index] = write_result
|
225
|
-
return all(results)
|
226
|
-
|
227
|
-
@synchronized()
|
228
|
-
def get_batch_set_indices(self, keys: List[str]) -> list:
|
229
|
-
ionum = len(keys)
|
230
|
-
# results: tuples of (is_written: bool, page_idx: int)
|
231
|
-
# - is_written: True = hit (no I/O), False = write (miss)
|
232
|
-
# - page_idx: page storing data
|
233
|
-
results = [None] * min(ionum, self.num_pages)
|
234
|
-
if ionum > self.num_pages:
|
235
|
-
results.extend([(False, -1)] * (ionum - self.num_pages))
|
236
|
-
|
237
|
-
new_keys = []
|
238
|
-
for batch_index, key in enumerate(keys[: self.num_pages]):
|
239
|
-
if key in self.key_to_index:
|
240
|
-
results[batch_index] = (True, self.key_to_index[key])
|
241
|
-
self.key_to_index.move_to_end(key)
|
242
|
-
else:
|
243
|
-
new_keys.append((batch_index, key))
|
244
322
|
|
245
|
-
|
246
|
-
|
247
|
-
self.
|
248
|
-
if len(self.free_pages) > 0
|
249
|
-
else self.key_to_index.popitem(last=False)[1]
|
323
|
+
if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
|
324
|
+
self.metadata_client.confirm_write(
|
325
|
+
self.rank, written_keys_to_confirm, pages_to_release
|
250
326
|
)
|
251
|
-
results[batch_index] = (False, index)
|
252
327
|
|
253
|
-
return results
|
328
|
+
return all(results)
|
254
329
|
|
255
330
|
@synchronized()
|
256
331
|
def delete(self, key: str) -> None:
|
257
|
-
|
258
|
-
return
|
259
|
-
index = self.key_to_index.pop(key)
|
260
|
-
self.free_pages.append(index)
|
332
|
+
self.metadata_client.delete_keys(self.rank, [key])
|
261
333
|
|
262
334
|
@synchronized()
|
263
335
|
def exists(self, key: str) -> bool:
|
264
|
-
|
336
|
+
result = self.metadata_client.exists(self.rank, [key])
|
337
|
+
return result[0] if result else False
|
265
338
|
|
266
339
|
@synchronized()
|
267
340
|
def clear(self) -> None:
|
268
|
-
self.
|
269
|
-
self.key_to_index.clear()
|
341
|
+
self.metadata_client.clear(self.rank)
|
270
342
|
|
271
343
|
def close(self) -> None:
|
272
344
|
try:
|
@@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB
|
|
18
18
|
logger = logging.getLogger(__name__)
|
19
19
|
|
20
20
|
|
21
|
-
def get_hash_str_mooncake(
|
21
|
+
def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
|
22
22
|
local_rank = get_tensor_model_parallel_rank()
|
23
23
|
prefix_str = ""
|
24
|
-
if
|
25
|
-
|
26
|
-
|
27
|
-
current_token_ids_bytes = np.array(current_page_ids).tobytes()
|
24
|
+
if prior_hash:
|
25
|
+
prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
|
26
|
+
current_token_ids_bytes = np.array(token_ids).tobytes()
|
28
27
|
current_hash_object = hashlib.sha256(current_token_ids_bytes)
|
29
28
|
current_hash_hex = current_hash_object.hexdigest()
|
30
29
|
return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
|
@@ -224,13 +223,11 @@ class MooncakeStore(HiCacheStorage):
|
|
224
223
|
|
225
224
|
def exists(self, keys) -> bool | dict:
|
226
225
|
_keys = []
|
227
|
-
local_rank = torch.cuda.current_device()
|
228
226
|
for key in keys:
|
229
227
|
if key is None:
|
230
228
|
return None
|
231
|
-
|
232
|
-
|
233
|
-
_keys.append(f"{key}_{local_rank}_k")
|
229
|
+
|
230
|
+
_keys.append(f"{key}_k")
|
234
231
|
result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
|
235
232
|
return result
|
236
233
|
|
@@ -33,7 +33,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
|
33
33
|
set_graph_pool_id,
|
34
34
|
)
|
35
35
|
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
|
36
|
-
from sglang.srt.layers.dp_attention import
|
36
|
+
from sglang.srt.layers.dp_attention import (
|
37
|
+
DPPaddingMode,
|
38
|
+
get_attention_tp_rank,
|
39
|
+
get_attention_tp_size,
|
40
|
+
)
|
37
41
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
38
42
|
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
39
43
|
from sglang.srt.model_executor.forward_batch_info import (
|
@@ -255,6 +259,9 @@ class CudaGraphRunner:
|
|
255
259
|
self.dp_size = model_runner.server_args.dp_size
|
256
260
|
self.pp_size = model_runner.server_args.pp_size
|
257
261
|
|
262
|
+
self.attn_tp_size = get_attention_tp_size()
|
263
|
+
self.attn_tp_rank = get_attention_tp_rank()
|
264
|
+
|
258
265
|
# Batch sizes to capture
|
259
266
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
260
267
|
rank0_log(f"Capture cuda graph bs {self.capture_bs}")
|
@@ -576,11 +583,11 @@ class CudaGraphRunner:
|
|
576
583
|
)
|
577
584
|
|
578
585
|
if self.model_runner.server_args.enable_lora:
|
579
|
-
# It is safe to capture CUDA graph using empty LoRA
|
580
|
-
# `--enable-lora` is set to True (and return immediately if the LoRA
|
581
|
-
|
586
|
+
# It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
|
587
|
+
# `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
|
588
|
+
lora_ids = [None] * bs
|
582
589
|
else:
|
583
|
-
|
590
|
+
lora_ids = None
|
584
591
|
|
585
592
|
forward_batch = ForwardBatch(
|
586
593
|
forward_mode=self.capture_forward_mode,
|
@@ -589,6 +596,7 @@ class CudaGraphRunner:
|
|
589
596
|
req_pool_indices=req_pool_indices,
|
590
597
|
seq_lens=seq_lens,
|
591
598
|
next_token_logits_buffer=next_token_logits_buffer,
|
599
|
+
orig_seq_lens=seq_lens,
|
592
600
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
593
601
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
594
602
|
attn_backend=self.model_runner.attn_backend,
|
@@ -607,11 +615,11 @@ class CudaGraphRunner:
|
|
607
615
|
capture_hidden_mode=self.capture_hidden_mode,
|
608
616
|
num_token_non_padded=self.num_token_non_padded,
|
609
617
|
global_forward_mode=self.capture_forward_mode,
|
610
|
-
|
618
|
+
lora_ids=lora_ids,
|
611
619
|
)
|
612
620
|
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
|
613
621
|
|
614
|
-
if
|
622
|
+
if lora_ids is not None:
|
615
623
|
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
|
616
624
|
|
617
625
|
# Attention backend
|
@@ -728,10 +736,12 @@ class CudaGraphRunner:
|
|
728
736
|
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
729
737
|
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
730
738
|
|
739
|
+
seq_lens_cpu = None
|
731
740
|
if forward_batch.seq_lens_cpu is not None:
|
732
741
|
if bs != raw_bs:
|
733
742
|
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
734
743
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
744
|
+
seq_lens_cpu = self.seq_lens_cpu[:bs]
|
735
745
|
|
736
746
|
if pp_proxy_tensors:
|
737
747
|
for key in self.pp_proxy_tensors.keys():
|
@@ -746,7 +756,17 @@ class CudaGraphRunner:
|
|
746
756
|
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
|
747
757
|
self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
|
748
758
|
if enable_num_token_non_padded(self.model_runner.server_args):
|
749
|
-
|
759
|
+
num_token_non_padded = forward_batch.num_token_non_padded
|
760
|
+
if self.require_gathered_buffer:
|
761
|
+
tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
|
762
|
+
num_local_token_non_padded = torch.clamp(
|
763
|
+
num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
|
764
|
+
min=0,
|
765
|
+
max=tokens_per_rank,
|
766
|
+
)
|
767
|
+
self.num_token_non_padded.copy_(num_local_token_non_padded)
|
768
|
+
else:
|
769
|
+
self.num_token_non_padded.copy_(num_token_non_padded)
|
750
770
|
if self.enable_two_batch_overlap:
|
751
771
|
self.tbo_plugin.replay_prepare(
|
752
772
|
forward_mode=self.capture_forward_mode,
|
@@ -765,7 +785,7 @@ class CudaGraphRunner:
|
|
765
785
|
self.encoder_lens[:bs] if self.is_encoder_decoder else None,
|
766
786
|
self.capture_forward_mode,
|
767
787
|
forward_batch.spec_info,
|
768
|
-
seq_lens_cpu=
|
788
|
+
seq_lens_cpu=seq_lens_cpu,
|
769
789
|
)
|
770
790
|
|
771
791
|
# Store fields
|
@@ -180,6 +180,9 @@ class ForwardBatch:
|
|
180
180
|
# The sum of all sequence lengths
|
181
181
|
seq_lens_sum: int
|
182
182
|
|
183
|
+
# The original sequence length without being chunked. Qwen-1M related.
|
184
|
+
orig_seq_lens: Optional[torch.Tensor] = None
|
185
|
+
|
183
186
|
# Optional seq_lens on cpu
|
184
187
|
seq_lens_cpu: Optional[torch.Tensor] = None
|
185
188
|
|
@@ -248,7 +251,7 @@ class ForwardBatch:
|
|
248
251
|
encoder_out_cache_loc: Optional[torch.Tensor] = None
|
249
252
|
|
250
253
|
# For LoRA
|
251
|
-
|
254
|
+
lora_ids: Optional[List[str]] = None
|
252
255
|
|
253
256
|
# For input embeddings
|
254
257
|
input_embeds: Optional[torch.Tensor] = None
|
@@ -321,13 +324,14 @@ class ForwardBatch:
|
|
321
324
|
encoder_out_cache_loc=batch.encoder_out_cache_loc,
|
322
325
|
seq_lens_sum=batch.seq_lens_sum,
|
323
326
|
seq_lens_cpu=batch.seq_lens_cpu,
|
327
|
+
orig_seq_lens=batch.orig_seq_lens,
|
324
328
|
return_logprob=batch.return_logprob,
|
325
329
|
top_logprobs_nums=batch.top_logprobs_nums,
|
326
330
|
token_ids_logprobs=batch.token_ids_logprobs,
|
327
331
|
is_extend_in_batch=batch.is_extend_in_batch,
|
328
332
|
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
329
333
|
global_forward_mode=batch.global_forward_mode,
|
330
|
-
|
334
|
+
lora_ids=batch.lora_ids,
|
331
335
|
sampling_info=batch.sampling_info,
|
332
336
|
req_to_token_pool=model_runner.req_to_token_pool,
|
333
337
|
token_to_kv_pool=model_runner.token_to_kv_pool,
|
@@ -420,16 +424,12 @@ class ForwardBatch:
|
|
420
424
|
batch.extend_prefix_lens, dtype=torch.int32
|
421
425
|
).to(device, non_blocking=True)
|
422
426
|
ret.extend_num_tokens = batch.extend_num_tokens
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
else:
|
430
|
-
positions, ret.extend_start_loc = compute_position_torch(
|
431
|
-
ret.extend_prefix_lens, ret.extend_seq_lens
|
432
|
-
)
|
427
|
+
positions, ret.extend_start_loc = compute_position(
|
428
|
+
model_runner.server_args.attention_backend,
|
429
|
+
ret.extend_prefix_lens,
|
430
|
+
ret.extend_seq_lens,
|
431
|
+
ret.extend_num_tokens,
|
432
|
+
)
|
433
433
|
if ret.positions is None:
|
434
434
|
ret.positions = positions
|
435
435
|
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
@@ -632,8 +632,10 @@ class ForwardBatch:
|
|
632
632
|
self.dp_padding_mode = dp_padding_mode
|
633
633
|
|
634
634
|
if dp_padding_mode.is_max_len():
|
635
|
-
# when DP gather mode is all gather, we will use
|
636
|
-
#
|
635
|
+
# when DP gather mode is all gather, we will use
|
636
|
+
# all_gather_into_tensor to gather hidden states, where transferred
|
637
|
+
# tokens should be padded to the same length. We will also use
|
638
|
+
# reduce-scatter instead of all-reduce after MLP.
|
637
639
|
max_num_tokens = max(global_num_tokens)
|
638
640
|
global_num_tokens = [max_num_tokens] * sync_group_size
|
639
641
|
buffer_len = max_num_tokens * sync_group_size
|
@@ -651,12 +653,30 @@ class ForwardBatch:
|
|
651
653
|
else:
|
652
654
|
num_tokens = global_num_tokens[0]
|
653
655
|
|
654
|
-
if self.forward_mode.is_decode():
|
655
|
-
setattr(self, "raw_bs", self.batch_size)
|
656
|
-
self.batch_size = num_tokens
|
657
|
-
|
658
656
|
bs = self.batch_size
|
659
657
|
|
658
|
+
if self.forward_mode.is_decode():
|
659
|
+
if self.is_extend_in_batch and dp_padding_mode.is_max_len():
|
660
|
+
setattr(self, "_original_forward_mode", self.forward_mode)
|
661
|
+
self.forward_mode = ForwardMode.EXTEND
|
662
|
+
self.extend_num_tokens = bs
|
663
|
+
self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
|
664
|
+
self.extend_prefix_lens = self.seq_lens - 1
|
665
|
+
self.extend_start_loc = torch.arange(
|
666
|
+
bs, dtype=torch.int32, device=self.seq_lens.device
|
667
|
+
)
|
668
|
+
self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
|
669
|
+
self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
|
670
|
+
self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
|
671
|
+
else:
|
672
|
+
setattr(self, "_original_batch_size", self.batch_size)
|
673
|
+
if self.spec_info is not None:
|
674
|
+
bs = self.batch_size = (
|
675
|
+
num_tokens // self.spec_info.num_tokens_per_batch
|
676
|
+
)
|
677
|
+
else:
|
678
|
+
bs = self.batch_size = num_tokens
|
679
|
+
|
660
680
|
# padding
|
661
681
|
self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
|
662
682
|
self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
|
@@ -687,6 +707,7 @@ class ForwardBatch:
|
|
687
707
|
if self.mrope_positions is not None:
|
688
708
|
self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
|
689
709
|
|
710
|
+
# TODO: check if we need to pad other tensors
|
690
711
|
if self.extend_seq_lens is not None:
|
691
712
|
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
|
692
713
|
|
@@ -710,7 +731,9 @@ class ForwardBatch:
|
|
710
731
|
|
711
732
|
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
|
712
733
|
|
713
|
-
|
734
|
+
self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
|
735
|
+
self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
|
736
|
+
bs = self.batch_size
|
714
737
|
|
715
738
|
if self.spec_info is not None:
|
716
739
|
if self.forward_mode.is_decode(): # draft
|
@@ -882,6 +905,25 @@ class PPProxyTensors:
|
|
882
905
|
return f"PPProxyTensors(tensors={self.tensors})"
|
883
906
|
|
884
907
|
|
908
|
+
def compute_position(
|
909
|
+
attn_backend: str,
|
910
|
+
extend_prefix_lens: torch.Tensor,
|
911
|
+
extend_seq_lens: torch.Tensor,
|
912
|
+
extend_seq_lens_sum: int,
|
913
|
+
):
|
914
|
+
if support_triton(attn_backend):
|
915
|
+
positions, extend_start_loc = compute_position_triton(
|
916
|
+
extend_prefix_lens,
|
917
|
+
extend_seq_lens,
|
918
|
+
extend_seq_lens_sum,
|
919
|
+
)
|
920
|
+
else:
|
921
|
+
positions, extend_start_loc = compute_position_torch(
|
922
|
+
extend_prefix_lens, extend_seq_lens
|
923
|
+
)
|
924
|
+
return positions, extend_start_loc
|
925
|
+
|
926
|
+
|
885
927
|
def compute_position_triton(
|
886
928
|
extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
|
887
929
|
):
|