sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +14 -1
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +27 -15
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +14 -13
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +9 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
- sglang/srt/managers/scheduler.py +39 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +94 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +4 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +31 -10
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +4 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
CHANGED
@@ -113,6 +113,8 @@ def synchronized():
 
 
 class HiCacheHF3FS(HiCacheStorage):
+    """HiCache backend that stores KV cache pages in HF3FS files."""
+
     default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"
 
     def __init__(
@@ -125,6 +127,7 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -134,9 +137,13 @@ class HiCacheHF3FS(HiCacheStorage):
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -171,15 +178,32 @@ class HiCacheHF3FS(HiCacheStorage):
         dtype: torch.dtype,
         storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
+        """Create a HiCacheHF3FS instance from environment configuration.
+
+        Environment:
+        - Uses env var stored in `HiCacheHF3FS.default_env_var` to locate a JSON config.
+        - Falls back to a local single-machine config when the env var is not set.
+
+        Raises:
+            ValueError: If MLA Model is requested without global metadata server or required keys are missing.
+        """
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
-
+        if storage_config is not None:
+            rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model
+        else:
+            rank, is_mla_model = 0, False
+
+        mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
 
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
             return HiCacheHF3FS(
                 rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
@@ -209,26 +233,34 @@ class HiCacheHF3FS(HiCacheStorage):
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
-        if
+        if config.get("metadata_server_url"):
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
         else:
+            # Enable MLA optimization only when using the global metadata client
+            if is_mla_model:
+                raise ValueError(mla_unsupported_msg)
+
            # Use local metadata client for single-machine deployment
            metadata_client = Hf3fsLocalMetadataClient()
 
+        rank_for_path = 0 if is_mla_model else rank
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank_for_path}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
         )
 
     def get(
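The hunk above resolves its settings from the JSON file pointed to by `SGLANG_HICACHE_HF3FS_CONFIG_PATH`, falling back to a local single-machine setup when the variable is unset. Below is a minimal sketch of preparing such a config from Python; the key names mirror the ones read above (`file_path_prefix`, `file_size`, `numjobs`, `entries`, optional `metadata_server_url`), while the concrete values and paths are illustrative assumptions, not shipped defaults.

    import json
    import os

    # Hypothetical HF3FS HiCache config; keys mirror those read in the hunk above,
    # values are placeholders chosen for illustration only.
    hf3fs_config = {
        "file_path_prefix": "/data/hicache",   # per-rank file becomes <prefix>.<rank>.bin
        "file_size": 32 * (1 << 30),           # bytes reserved per rank
        "numjobs": 16,                         # IO concurrency
        "entries": 8,                          # pages per IO batch
        # Optional: enables the global metadata client (required for MLA models).
        "metadata_server_url": "http://meta-server:18000",
    }

    config_path = "/tmp/hf3fs_hicache_config.json"
    with open(config_path, "w") as f:
        json.dump(hf3fs_config, f)

    # HiCacheHF3FS.from_env_config reads this env var (default_env_var above).
    os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = config_path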
@@ -312,6 +344,10 @@ class HiCacheHF3FS(HiCacheStorage):
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In MLA backend, only one rank needs to backup the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
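Because MLA keeps a single latent KV cache that is identical across tensor-parallel ranks, only one rank needs to persist it; the constructor hunk earlier sets `skip_backup` on non-zero ranks and remaps their `rank` to 0, and `batch_set` then returns early. A tiny standalone sketch of that gate, using a stand-in class rather than the real `HiCacheHF3FS`:

    class _MlaBackupGate:
        """Stand-in illustrating the rank-0-only backup rule from the hunks above."""

        def __init__(self, rank: int, is_mla_model: bool):
            self.rank = rank
            self.skip_backup = False
            if is_mla_model and rank != 0:
                # Other ranks hold the same MLA pages; they read from rank 0's
                # file but never write their own copy.
                self.skip_backup = True
                self.rank = 0

        def batch_set(self, keys) -> bool:
            if self.skip_backup:
                return True  # report success without writing; rank 0 owns the backup
            return True      # the real class would write the pages here

    assert _MlaBackupGate(rank=3, is_mla_model=True).batch_set(["page0"]) is True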
@@ -363,18 +399,29 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return all(results)
 
-    @synchronized()
     def delete(self, key: str) -> None:
         self.metadata_client.delete_keys(self.rank, [key])
 
-    @synchronized()
     def exists(self, key: str) -> bool:
         result = self.metadata_client.exists(self.rank, [key])
         return result[0] if result else False
 
-
-
-
+    def batch_exists(self, keys: List[str]) -> int:
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
+
+    def clear(self) -> bool:
+        try:
+            self.metadata_client.clear(self.rank)
+            logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheHF3FS: {e}")
+            return False
 
     def close(self) -> None:
         try:
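Note the contract of the new `batch_exists`: it returns the number of leading keys that are present (the length of the matching prefix), not a per-key mask. A short sketch of how a caller might use that value, with a hypothetical in-memory stand-in for the storage backend:

    from typing import List

    class _FakeStorage:
        # Hypothetical stand-in: pretends only these pages were written earlier.
        _cached = {"page0", "page1"}

        def batch_exists(self, keys: List[str]) -> int:
            # Same prefix semantics as the method added above.
            for i, key in enumerate(keys):
                if key not in self._cached:
                    return i
            return len(keys)

    def split_cached_prefix(storage, keys: List[str]):
        """Split keys into the already-cached prefix and the suffix still to store."""
        hit_count = storage.batch_exists(keys)
        return keys[:hit_count], keys[hit_count:]

    hits, misses = split_cached_prefix(_FakeStorage(), ["page0", "page1", "page2"])
    assert hits == ["page0", "page1"] and misses == ["page2"]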
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
CHANGED
@@ -159,6 +159,7 @@ class MooncakeStore(HiCacheStorage):
     def batch_set(
         self,
         keys: List[str],
+        values: Optional[List[torch.Tensor]] = None,
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
@@ -253,7 +254,7 @@ class MooncakeStore(HiCacheStorage):
         pass
 
     def clear(self) -> None:
-
+        self.store.remove_all()
 
     def _put_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
sglang/srt/mem_cache/swa_radix_cache.py
CHANGED
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
 
-    def cache_unfinished_req(self, req: Req) -> None:
+    def cache_unfinished_req(self, req: Req, chunked=False) -> None:
         """Cache request when it is unfinished."""
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -307,7 +307,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
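The hunk above sizes the per-worker layer range with `max(num_hidden_layers, num_attention_layers)` rather than `num_hidden_layers` alone, presumably to cover configs where the two counts differ. A toy illustration of the expression with a hypothetical config object (the field names match the hunk; the numbers are made up):

    from dataclasses import dataclass

    @dataclass
    class _ToyModelConfig:
        num_hidden_layers: int = 61
        num_attention_layers: int = 62
        num_nextn_predict_layers: int = 1

    def resolve_num_layers(cfg, is_draft_worker: bool, has_mtp_layers: bool) -> int:
        # Mirrors the expression in the hunk above.
        return (
            cfg.num_nextn_predict_layers
            if is_draft_worker and has_mtp_layers
            else max(cfg.num_hidden_layers, cfg.num_attention_layers)
        )

    cfg = _ToyModelConfig()
    assert resolve_num_layers(cfg, is_draft_worker=False, has_mtp_layers=True) == 62
    assert resolve_num_layers(cfg, is_draft_worker=True, has_mtp_layers=True) == 1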
@@ -1440,14 +1443,12 @@ class ModelRunner:
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
sglang/srt/model_loader/loader.py
CHANGED
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
                 # random values to the weights.
                 initialize_dummy_weights(model)
 
-                # Model weight loading consists of two stages:
-                # 1. Initial weight loading.
-                # 2. Post-processing of weights, including assigning specific member variables.
-                # For `dummy_init`, only the second stage is required.
-                if hasattr(model, "post_load_weights"):
-                    if (
-                        model_config.hf_config.architectures[0]
-                        == "DeepseekV3ForCausalLMNextN"
-                    ):
-                        model.post_load_weights(is_nextn=True)
-                    else:
-                        model.post_load_weights()
+                post_load_weights(model, model_config)
 
         return model.eval()
 
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                             state_dict.pop(key)
                 if state_dict:
                     raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+            post_load_weights(model, model_config)
+
         return model.eval()
 
     @staticmethod
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                     f_key = f"{model_name}/files/{file_name}"
                     client.setstr(f_key, file_content)
 
-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
|
|
1460
1451
|
if state_dict:
|
1461
1452
|
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
1462
1453
|
|
1454
|
+
post_load_weights(model, model_config)
|
1455
|
+
|
1463
1456
|
def _load_model_from_remote_fs(
|
1464
1457
|
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
1465
1458
|
) -> nn.Module:
|
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
 
-            with create_remote_connector(model_weights) as client:
+            with create_remote_connector(
+                model_weights, device=device_config.device
+            ) as client:
                 connector_type = get_connector_type(client)
                 if connector_type == ConnectorType.KV:
-                    self._load_model_from_remote_kv(model, client)
+                    self._load_model_from_remote_kv(model, model_config, client)
                 elif connector_type == ConnectorType.FS:
                     self._load_model_from_remote_fs(
                         model, client, model_config, device_config
sglang/srt/model_loader/utils.py
CHANGED
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -122,6 +123,7 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -1181,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
 
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
 
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-        )
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
 
         return q, k, v, forward_batch
 
@@ -2177,6 +2185,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+        elif self.quant_config.get_name() == "w4afp8":
+            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2406,18 +2416,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
 
         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
         for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]
 
-            for module in [
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]:
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
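Which projections get the UE8M0 requant now depends on whether the checkpoint uses a low-rank Q projection: with `q_lora_rank` set, the fused `fused_qkv_a_proj_with_mqa`/`q_b_proj` pair exists, otherwise the plain `kv_a_proj_with_mqa`/`q_proj` modules do. A small sketch of that selection using attribute names from the hunk (the helper itself is hypothetical):

    def modules_to_requant(has_q_lora_rank: bool):
        # Attribute names mirror the hunk above; only the selection logic is sketched.
        modules = ["kv_b_proj", "o_proj"]
        if has_q_lora_rank:
            modules += ["fused_qkv_a_proj_with_mqa", "q_b_proj"]
        else:
            modules += ["kv_a_proj_with_mqa", "q_proj"]
        return modules

    assert "q_proj" in modules_to_requant(has_q_lora_rank=False)
    assert "q_b_proj" in modules_to_requant(has_q_lora_rank=True)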
@@ -2480,6 +2498,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        # Params for special naming rules in mixed-precision models, for example:
+        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
             expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                 num_experts=self.config.n_routed_experts
sglang/srt/models/gpt_oss.py
CHANGED
@@ -193,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
-def _enable_fused_set_kv_buffer():
-    return _is_cuda
+def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
 
 
 # TODO maybe move to a model-common utils
@@ -341,7 +342,7 @@ class GptOssAttention(nn.Module):
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer()
+                if _enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -355,7 +356,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(),
+            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
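Both call sites now pass `forward_batch` so the gate can check the KV-cache dtype at runtime, and they keep a single-writer rule: when the fused buffer write is enabled, the attention call is told not to save the KV cache again. A brief sketch of that rule using plain values rather than sglang objects:

    import torch

    def _fused_kv_gate_sketch(is_cuda: bool, kv_dtype: torch.dtype) -> bool:
        # Mirrors the gate above: fuse only on CUDA with a bfloat16 KV pool.
        return is_cuda and kv_dtype == torch.bfloat16

    fused = _fused_kv_gate_sketch(is_cuda=True, kv_dtype=torch.bfloat16)
    save_kv_cache_in_attn = not fused
    # Exactly one of the two paths writes the KV cache.
    assert fused != save_kv_cache_in_attn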
@@ -1029,10 +1030,6 @@ class GptOssForCausalLM(nn.Module):
         )
 
         params_dict = dict(self.named_parameters())
-        params_checker = {k: False for k, v in params_dict.items()}
-
-        for other_loaded_param_name in other_loaded_param_names:
-            params_checker[other_loaded_param_name] = True
 
         for name, loaded_weight in weights:
             loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
@@ -1069,7 +1066,6 @@ class GptOssForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
-                params_checker[name] = True
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -1092,7 +1088,6 @@ class GptOssForCausalLM(nn.Module):
                         name,
                         shard_id=shard_id,
                     )
-                    params_checker[name] = True
                     break
                 else:
                     if name.endswith(".bias") and name not in params_dict:
@@ -1111,17 +1106,9 @@ class GptOssForCausalLM(nn.Module):
                         param, "weight_loader", default_weight_loader
                     )
                     weight_loader(param, loaded_weight)
-                    params_checker[name] = True
                 else:
                     logger.warning(f"Parameter {name} not found in params_dict")
 
-        not_loaded_params = [k for k, v in params_checker.items() if not v]
-        if tp_rank == 0:
-            if len(not_loaded_params) > 0:
-                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
-            else:
-                logging.info("All parameters loaded successfully.")
-
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
 
sglang/srt/models/llama_eagle3.py
CHANGED
@@ -185,9 +185,13 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         )
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        self.load_lm_head_from_target = False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
+            if config.draft_vocab_size is None:
+                self.load_lm_head_from_target = True
+                config.draft_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
                 config.draft_vocab_size,
                 config.hidden_size,