sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
CHANGED
@@ -11,12 +11,7 @@ from typing import Any, List, Optional, Tuple
 
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    is_dp_attention_enabled,
-)
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 
 logger = logging.getLogger(__name__)
@@ -130,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -139,9 +135,13 @@
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -172,19 +172,16 @@ class HiCacheHF3FS(HiCacheStorage):
 
     @staticmethod
     def from_env_config(
-        bytes_per_page: int, dtype: torch.dtype
+        bytes_per_page: int,
+        dtype: torch.dtype,
+        storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
-
-        rank = (
-            get_attention_tp_rank()
-            if is_dp_attention_enabled()
-            else get_tensor_model_parallel_rank()
-        )
+        rank = storage_config.tp_rank if storage_config is not None else 0
 
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
@@ -217,10 +214,14 @@
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
+        is_mla_model = False
         if "metadata_server_url" in config and config["metadata_server_url"]:
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
+            # Enable MLA optimization only when using the global metadata client
+            is_mla_model = storage_config.is_mla_model if storage_config else False
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
@@ -230,13 +231,15 @@
 
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
         )
 
     def get(
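
For orientation, the keys that `from_env_config` reads (`file_path_prefix`, `file_size`, `numjobs`, `entries`, and the optional `metadata_server_url`) imply a JSON config file along the following lines; the values and the file name here are hypothetical, and the real schema may contain more fields:

```python
import json

# Hypothetical config for HiCacheHF3FS.from_env_config; only the keys visible
# in the hunks above are assumed. The config path is taken from the
# environment variable named by HiCacheHF3FS.default_env_var.
config = {
    "file_path_prefix": "/hf3fs/hicache/kvcache",  # per-rank file becomes <prefix>.<rank>.bin
    "file_size": 1 << 37,  # bytes reserved for the backing cache file
    "numjobs": 16,  # parallel I/O jobs for the hf3fs client
    "entries": 8,  # entries per I/O batch
    "metadata_server_url": "http://127.0.0.1:18000",  # optional: enables the global metadata client
}

with open("hf3fs_config.json", "w") as f:
    json.dump(config, f, indent=2)
```

With `is_mla_model` set, every rank resolves to the same `<prefix>.0.bin` file, which is why non-zero ranks can skip the backup path entirely.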
@@ -320,6 +323,10 @@ class HiCacheHF3FS(HiCacheStorage):
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In MLA backend, only one rank needs to backup the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -371,18 +378,29 @@
 
         return all(results)
 
-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
 
-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False
 
-
-
-
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
+    def clear(self) -> bool:
+        try:
+            self.metadata_client.clear(self.rank)
+            logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+            return False
 
     def close(self) -> None:
         try:
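
Note the contract of the new `batch_exists`: it does not return a per-key mask but the length of the matching prefix, i.e. how many keys exist consecutively from the start of the list. A minimal sketch of that contract, using a plain dict as a hypothetical stand-in for the metadata client:

```python
from typing import Dict, List


def batch_exists_prefix(store: Dict[str, bytes], keys: List[str]) -> int:
    """Mirror of the new contract: count keys that exist, consecutively
    from the start; stop at the first miss."""
    for i, key in enumerate(keys):
        if key not in store:
            return i
    return len(keys)


store = {"page0": b"...", "page1": b"..."}
assert batch_exists_prefix(store, ["page0", "page1", "page2"]) == 2
assert batch_exists_prefix(store, ["page2", "page0"]) == 0
```

This shape matches how prefix caching consumes the result: only a contiguous leading run of cached pages is reusable.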
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
CHANGED
@@ -10,24 +10,14 @@ import numpy as np
 import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 
 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
-DEFAULT_LOCAL_BUFFER_SIZE =
+DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
 
 logger = logging.getLogger(__name__)
 
 
-def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    prefix_str = ""
-    if prior_hash:
-        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
-    current_token_ids_bytes = np.array(token_ids).tobytes()
-    current_hash_object = hashlib.sha256(current_token_ids_bytes)
-    current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
-
-
 @dataclass
 class MooncakeStoreConfig:
     local_hostname: str
@@ -54,9 +44,8 @@ class MooncakeStoreConfig:
             global_segment_size=config.get(
                 "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
             ),
-            local_buffer_size=config.get(
-                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
-            ),
+            # Zero copy interface does not need local buffer
+            local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
             protocol=config.get("protocol", "tcp"),
             device_name=config.get("device_name", "auto"),
             master_server_address=config.get("master_server_address"),
@@ -79,9 +68,8 @@ class MooncakeStoreConfig:
             global_segment_size=int(
                 os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
            ),
-            local_buffer_size=int(
-                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
-            ),
+            # Zero copy interface does not need local buffer
+            local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
             protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
             device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
             master_server_address=os.getenv("MOONCAKE_MASTER"),
@@ -96,7 +84,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
-    def __init__(self):
+    def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -126,7 +114,13 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
-
+
+            if storage_config is not None:
+                self.is_mla_backend = storage_config.is_mla_model
+                self.local_rank = storage_config.tp_rank
+            else:
+                self.is_mla_backend = False
+                self.local_rank = 0
 
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -137,12 +131,10 @@ class MooncakeStore(HiCacheStorage):
 
     def warmup(self):
         warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
-        #
-
-        self.store.put(warmup_key, warmup_value)
+        warmup_value = bytes(4 * 1024)  # 4 KB
+        assert self.store.put(warmup_key, warmup_value) == 0
         assert self.store.is_exist(warmup_key) == 1
-        self.store.get(warmup_key)
-        self.store.remove(warmup_key)
+        assert self.store.get(warmup_key) == warmup_value
 
     def register_buffer(self, buffer: torch.Tensor) -> None:
         try:
@@ -162,78 +154,96 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-
-        if len(key) == 0:
-            return
-
-        for i in range(len(key)):
-            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        self._put_batch_zero_copy_impl(key, target_location, target_sizes)
+        return self.batch_set([key], [value], [target_location], [target_sizes])
 
     def batch_set(
         self,
         keys: List[str],
-
+        values: Optional[List[torch.Tensor]] = None,
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
         assert len(keys) == len(target_location) == len(target_sizes)
         if len(keys) == 0:
-            return
+            return False
 
         for i in range(len(keys)):
             if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
+                return False
 
-        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
+        exist_result = self._batch_exist(keys)
+        set_keys = []
+        set_target_locations = []
+        set_target_sizes = []
+        set_indices = []
+        for i in range(len(keys)):
+            if exist_result[i] != 1:
+                set_keys.append(keys[i])
+                set_target_locations.append(target_location[i])
+                set_target_sizes.append(target_sizes[i])
+                set_indices.append(i)
+        # Only set non-existing keys to storage
+        put_result = self._put_batch_zero_copy_impl(
+            set_keys, set_target_locations, set_target_sizes
+        )
+        for i in range(len(set_indices)):
+            if put_result[i] == 0:
+                exist_result[set_indices[i]] = 1
+
+        success_count = 0
+        for i in range(len(keys)):
+            if exist_result[i] == 0:
+                break
+            success_count += 1
+        # TODO: return the number of consecutive successful operations from the start.
+        return success_count == len(keys)
 
     def get(
         self,
         key,
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) ->
-
-        if len(key) == 0:
-            return
-
-        for i in range(len(key)):
-            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
+    ) -> bool:
+        return self.batch_get([key], [target_location], [target_sizes]) == 1
 
     def batch_get(
         self,
         keys: List[str],
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) ->
+    ) -> int:
         assert len(keys) == len(target_location) == len(target_sizes)
         if len(keys) == 0:
-            return
-
+            return 0
+        get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
+        if self.is_mla_backend:
+            key_multiplier = 1
+        else:
+            key_multiplier = 2
         for i in range(len(keys)):
-            if
-                return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if get_result[i] < 0:
+                return i // key_multiplier
+        return len(keys) // key_multiplier
+
+    def exists(self, key) -> bool:
+        return self.batch_exists([key]) > 0
+
+    def batch_exists(self, keys) -> int:
+        if self.is_mla_backend:
+            query_keys = [f"{key}_k" for key in keys]
+            key_multiplier = 1
+        else:
+            query_keys = []
+            for key in keys:
+                query_keys.append(f"{key}_{self.local_rank}_k")
+                query_keys.append(f"{key}_{self.local_rank}_v")
+            key_multiplier = 2
+
+        exist_result = self._batch_exist(query_keys)
+        for i in range(len(query_keys)):
+            if exist_result[i] != 1:
+                return i // key_multiplier
+        return len(query_keys) // key_multiplier
 
     def delete(self, key) -> None:
         raise (NotImplementedError)
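
The `key_multiplier` arithmetic in `batch_get` and `batch_exists` follows from how logical pages map to physical store entries: MLA models share a single `*_k` entry across ranks, while non-MLA models store per-rank K and V entries. A small standalone sketch of that expansion, mirroring the suffix scheme in the hunk above:

```python
from typing import List, Tuple


def expand_query_keys(
    keys: List[str], is_mla_backend: bool, local_rank: int
) -> Tuple[List[str], int]:
    """Expand logical page keys into the physical keys probed in the store.

    MLA: one shared entry per page  -> key_multiplier = 1
    MHA: per-rank K and V entries   -> key_multiplier = 2
    """
    if is_mla_backend:
        return [f"{key}_k" for key in keys], 1
    query_keys = []
    for key in keys:
        query_keys.append(f"{key}_{local_rank}_k")
        query_keys.append(f"{key}_{local_rank}_v")
    return query_keys, 2


assert expand_query_keys(["p0"], True, 0) == (["p0_k"], 1)
keys, mult = expand_query_keys(["p0", "p1"], False, 3)
assert keys == ["p0_3_k", "p0_3_v", "p1_3_k", "p1_3_v"] and mult == 2
```

Dividing a physical-entry index by the multiplier converts a failure position back into a count of fully available logical pages.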
@@ -244,22 +254,17 @@ class MooncakeStore(HiCacheStorage):
         pass
 
     def clear(self) -> None:
-
+        self.store.remove_all()
 
     def _put_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-    ) ->
-        try:
-            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
-        except TypeError as err:
-            logger.error("Failed to put value to Mooncake Store: %s", err)
-            raise TypeError("Mooncake Store Put Type Error.") from err
+    ) -> List[int]:
+        return self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
 
     def _get_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-    ) ->
-        try:
-            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
-        except TypeError as err:
-            logger.error("Failed to get value from Mooncake Store: %s", err)
-            raise TypeError("Mooncake Store Get Type Error.") from err
+    ) -> List[int]:
+        return self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
+
+    def _batch_exist(self, key_strs: List[str]) -> List[int]:
+        return self.store.batch_is_exist(key_strs)
sglang/srt/mem_cache/swa_radix_cache.py
CHANGED
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
 
-    def cache_unfinished_req(self, req: Req) -> None:
+    def cache_unfinished_req(self, req: Req, chunked=False) -> None:
         """Cache request when it is unfinished."""
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -66,7 +66,6 @@ from sglang.srt.layers.quantization import (
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
@@ -121,6 +120,7 @@ from sglang.srt.utils import (
     is_hopper_with_cuda_12_3,
     is_no_spec_infer_or_topk_one,
     is_npu,
+    is_sm100_supported,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
@@ -307,7 +307,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
@@ -1440,14 +1443,12 @@ class ModelRunner:
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
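
For context, `HybridAttnBackend` now receives the `ModelRunner` itself in addition to the two wrapped backends, and the assertion rejecting speculative decoding with mixed backends is gone. Conceptually the wrapper picks a backend per batch; the sketch below illustrates that routing idea only, with an invented interface rather than sglang's actual one:

```python
# Illustrative sketch only: not sglang's actual HybridAttnBackend API.
class HybridAttnBackendSketch:
    """Route decode batches and prefill/extend batches to different backends."""

    def __init__(self, model_runner, decode_backend, prefill_backend):
        self.model_runner = model_runner  # now passed in, per the hunk above
        self.decode_backend = decode_backend
        self.prefill_backend = prefill_backend

    def _select(self, forward_batch):
        # Hypothetical dispatch rule: decode-mode batches go to the decode
        # backend, everything else to the prefill backend.
        if forward_batch.forward_mode.is_decode():
            return self.decode_backend
        return self.prefill_backend

    def forward(self, *args, forward_batch, **kwargs):
        backend = self._select(forward_batch)
        return backend.forward(*args, forward_batch=forward_batch, **kwargs)
```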
sglang/srt/model_loader/loader.py
CHANGED
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
             # random values to the weights.
             initialize_dummy_weights(model)
 
-            # Model weight loading consists of two stages:
-            # 1. Initial weight loading.
-            # 2. Post-processing of weights, including assigning specific member variables.
-            # For `dummy_init`, only the second stage is required.
-            if hasattr(model, "post_load_weights"):
-                if (
-                    model_config.hf_config.architectures[0]
-                    == "DeepseekV3ForCausalLMNextN"
-                ):
-                    model.post_load_weights(is_nextn=True)
-                else:
-                    model.post_load_weights()
+            post_load_weights(model, model_config)
 
         return model.eval()
 
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                     state_dict.pop(key)
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+            post_load_weights(model, model_config)
+
         return model.eval()
 
     @staticmethod
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                         f_key = f"{model_name}/files/{file_name}"
                         client.setstr(f_key, file_content)
 
-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
@@ -1460,6 +1451,8 @@ class RemoteModelLoader(BaseModelLoader):
         if state_dict:
             raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
 
+        post_load_weights(model, model_config)
+
     def _load_model_from_remote_fs(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
 
-            with create_remote_connector(model_weights) as client:
+            with create_remote_connector(
+                model_weights, device=device_config.device
+            ) as client:
                 connector_type = get_connector_type(client)
                 if connector_type == ConnectorType.KV:
-                    self._load_model_from_remote_kv(model, client)
+                    self._load_model_from_remote_kv(model, model_config, client)
                 elif connector_type == ConnectorType.FS:
                     self._load_model_from_remote_fs(
                         model, client, model_config, device_config
sglang/srt/model_loader/utils.py
CHANGED
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
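
The new helper centralizes what several loaders previously inlined: stage 2 of loading runs only if the model exposes a `post_load_weights` hook. A toy example of a model implementing the hook (the model class here is invented for illustration):

```python
import torch.nn as nn


class ToyModel(nn.Module):
    """Invented example of a model exposing the optional hook."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.w_scaled = None

    def post_load_weights(self):
        # Stage 2: derive member variables from the already-loaded weights.
        self.w_scaled = self.proj.weight.detach() * 2.0


model = ToyModel()
if hasattr(model, "post_load_weights"):
    model.post_load_weights()
```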
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -87,8 +87,8 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -114,6 +114,8 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
+    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -121,6 +123,7 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -994,7 +997,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.current_attention_backend = attention_backend
 
         if attention_backend == "ascend":
-            return AttnForwardMethod.MLA
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif (
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
@@ -1173,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
 
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
 
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-        )
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
 
         return q, k, v, forward_batch
 
@@ -1292,6 +1308,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
+            or self.current_attention_backend == "ascend"
         ):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
@@ -2397,18 +2414,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
 
         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
        for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]
 
-            for module in [
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]:
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
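
The new branch reflects the two DeepSeek attention layouts: with `q_lora_rank` set, the q/kv down-projections are fused into `fused_qkv_a_proj_with_mqa` plus `q_b_proj`; without it, the plain `q_proj` and `kv_a_proj_with_mqa` exist instead. Restated as a standalone helper (illustrative only, exercised here with stub attribute objects):

```python
from types import SimpleNamespace


def select_requant_modules(self_attn, q_lora_rank):
    """Pick the projection modules whose fp8 scales get requantized to UE8M0.

    Mirrors the hunk above: kv_b_proj/o_proj always participate; the q-side
    modules depend on whether the config uses a q LoRA low-rank projection.
    """
    modules = [self_attn.kv_b_proj, self_attn.o_proj]
    if q_lora_rank is not None:
        modules += [self_attn.fused_qkv_a_proj_with_mqa, self_attn.q_b_proj]
    else:
        modules += [self_attn.kv_a_proj_with_mqa, self_attn.q_proj]
    return modules


attn = SimpleNamespace(
    kv_b_proj="kv_b", o_proj="o", fused_qkv_a_proj_with_mqa="qkv_a",
    q_b_proj="q_b", kv_a_proj_with_mqa="kv_a", q_proj="q",
)
assert select_requant_modules(attn, q_lora_rank=None) == ["kv_b", "o", "kv_a", "q"]
assert select_requant_modules(attn, q_lora_rank=1536) == ["kv_b", "o", "qkv_a", "q_b"]
```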