sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +12 -6
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +43 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +36 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +75 -68
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +2 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +26 -10
- sglang/srt/models/gpt_oss.py +0 -14
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +14 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@ from functools import wraps
 import psutil
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
 
@@ -464,8 +463,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
-        local_rank = get_tensor_model_parallel_rank()
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
@@ -704,7 +702,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
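Note: the two `get_buffer_meta` hunks above belong to `sglang/srt/mem_cache/memory_pool_host.py` (+2 -4 in the file list). The host KV pools no longer look up the tensor-parallel rank themselves; callers pass `local_rank` in. A minimal, self-contained sketch of that dependency-injection pattern; the class and variable names below are illustrative, not the real sglang types:

```python
# Sketch only: mirrors the new calling convention, not the real implementation.
class HostKVPoolSketch:
    def get_buffer_meta(self, keys, indices, local_rank):
        # Rank-dependent buffer metadata is derived from the injected rank
        # instead of an internal get_tensor_model_parallel_rank() call.
        return [(key, idx, local_rank) for key, idx in zip(keys, indices)]

pool = HostKVPoolSketch()
print(pool.get_buffer_meta(["page-0", "page-1"], [0, 1], local_rank=0))
```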
sglang/srt/mem_cache/radix_cache.py
CHANGED
@@ -62,7 +62,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -195,7 +194,7 @@ class RadixCache(BasePrefixCache):
             last_host_node=last_node,
         )
 
-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
         if self.disable:
             return 0
 
@@ -240,7 +239,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -261,7 +260,9 @@ class RadixCache(BasePrefixCache):
         page_aligned_token_ids = token_ids[:page_aligned_len]
 
         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
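The `chunked=False` keyword added to `insert` and `cache_unfinished_req` above (and to the C++/SWA variants below) keeps existing call sites source-compatible while letting callers, presumably the scheduler's chunked-prefill path, mark a request as chunked. A self-contained sketch of how such a default-off flag threads through, with hypothetical names:

```python
from typing import List, Optional

class RadixCacheSketch:
    """Hypothetical stand-in; only shows how the new flag is forwarded."""

    def insert(self, key: List[int], value: Optional[List[int]] = None, chunked: bool = False) -> int:
        # A real implementation walks the radix tree; here we just surface the flag.
        print(f"insert(len={len(key)}, chunked={chunked})")
        return 0

    def cache_unfinished_req(self, token_ids: List[int], chunked: bool = False) -> None:
        # Forward the flag to insert(), mirroring the diff above.
        self.insert(token_ids, token_ids, chunked=chunked)

cache = RadixCacheSketch()
cache.cache_unfinished_req([1, 2, 3])                # old call sites keep working
cache.cache_unfinished_req([1, 2, 3], chunked=True)  # new chunked path
```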
sglang/srt/mem_cache/radix_cache_cpp.py
CHANGED
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
         self.dec_lock_ref(req.last_node)
         self.req_to_token_pool.free(req.req_pool_idx)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         assert req.req_pool_idx is not None
         token_ids = req.fill_ids
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
CHANGED
@@ -125,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage):
         entries: int,
         dtype: torch.dtype,
         metadata_client: Hf3fsMetadataInterface,
+        is_mla_model: bool = False,
     ):
         self.rank = rank
         self.file_path = file_path
@@ -134,9 +135,13 @@
         self.entries = entries
         self.dtype = dtype
         self.metadata_client = metadata_client
-
+        self.is_mla_model = is_mla_model
         self.numel = self.bytes_per_page // self.dtype.itemsize
         self.num_pages = self.file_size // self.bytes_per_page
+        self.skip_backup = False
+        if self.is_mla_model and self.rank != 0:
+            self.skip_backup = True
+            self.rank = 0
 
         logger.info(
             f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
@@ -209,10 +214,14 @@
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
         # Choose metadata client based on configuration
+        is_mla_model = False
         if "metadata_server_url" in config and config["metadata_server_url"]:
             # Use global metadata client to connect to metadata server
             metadata_server_url = config["metadata_server_url"]
             metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+
+            # Enable MLA optimization only when using the global metadata client
+            is_mla_model = storage_config.is_mla_model if storage_config else False
             logger.info(
                 f"Using global metadata client with server url: {metadata_server_url}"
             )
@@ -222,13 +231,15 @@
 
         return HiCacheHF3FS(
             rank=rank,
-            file_path=f"{config['file_path_prefix']}.{rank}.bin",
+            # Let all ranks use the same file path for MLA model
+            file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
             metadata_client=metadata_client,
+            is_mla_model=is_mla_model,
         )
 
     def get(
@@ -312,6 +323,10 @@
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
+        # In MLA backend, only one rank needs to backup the KV cache
+        if self.skip_backup:
+            return True
+
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
|
|
363
378
|
|
364
379
|
return all(results)
|
365
380
|
|
366
|
-
@synchronized()
|
367
381
|
def delete(self, key: str) -> None:
|
368
382
|
self.metadata_client.delete_keys(self.rank, [key])
|
369
383
|
|
370
|
-
@synchronized()
|
371
384
|
def exists(self, key: str) -> bool:
|
372
385
|
result = self.metadata_client.exists(self.rank, [key])
|
373
386
|
return result[0] if result else False
|
374
387
|
|
375
|
-
|
376
|
-
|
377
|
-
|
388
|
+
def batch_exists(self, keys: List[str]) -> int:
|
389
|
+
results = self.metadata_client.exists(self.rank, keys)
|
390
|
+
for i in range(len(keys)):
|
391
|
+
if not results[i]:
|
392
|
+
return i
|
393
|
+
|
394
|
+
return len(keys)
|
395
|
+
|
396
|
+
def clear(self) -> bool:
|
397
|
+
try:
|
398
|
+
self.metadata_client.clear(self.rank)
|
399
|
+
logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
|
400
|
+
return True
|
401
|
+
except Exception as e:
|
402
|
+
logger.error(f"Failed to clear HiCacheHF3FS: {e}")
|
403
|
+
return False
|
378
404
|
|
379
405
|
def close(self) -> None:
|
380
406
|
try:
|
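Taken together, the `storage_hf3fs.py` hunks above make MLA models back up the KV cache from a single rank: non-zero ranks set `skip_backup` (so `batch_set` returns immediately) and all ranks share the rank-0 file path. The new `batch_exists` reports how many keys from the front of the list are already stored. A small self-contained sketch of that prefix-count contract, using a plain dict in place of the metadata client:

```python
from typing import Dict, List

def batch_exists_sketch(stored: Dict[str, bool], keys: List[str]) -> int:
    """Return the length of the leading run of keys that already exist."""
    for i, key in enumerate(keys):
        if not stored.get(key, False):
            return i
    return len(keys)

stored = {"k0": True, "k1": True, "k2": False}
print(batch_exists_sketch(stored, ["k0", "k1", "k2", "k3"]))  # -> 2: only the leading ["k0", "k1"] exist
```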
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
CHANGED
@@ -159,6 +159,7 @@ class MooncakeStore(HiCacheStorage):
     def batch_set(
         self,
         keys: List[str],
+        values: Optional[List[torch.Tensor]] = None,
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
@@ -253,7 +254,7 @@
         pass
 
     def clear(self) -> None:
-
+        self.store.remove_all()
 
     def _put_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
sglang/srt/mem_cache/swa_radix_cache.py
CHANGED
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
 
-    def cache_unfinished_req(self, req: Req) -> None:
+    def cache_unfinished_req(self, req: Req, chunked=False) -> None:
         """Cache request when it is unfinished."""
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -307,7 +307,10 @@ class ModelRunner:
         model_num_layers = (
             self.model_config.num_nextn_predict_layers
             if self.is_draft_worker and model_has_mtp_layers
-            else self.model_config.num_hidden_layers
+            else max(
+                self.model_config.num_hidden_layers,
+                self.model_config.num_attention_layers,
+            )
         )
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(self.model, "end_layer", model_num_layers)
@@ -1440,14 +1443,12 @@
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
sglang/srt/model_loader/loader.py
CHANGED
@@ -42,6 +42,7 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
+    post_load_weights,
     set_default_torch_dtype,
 )
 from sglang.srt.model_loader.weight_utils import (
@@ -600,18 +601,7 @@ class DummyModelLoader(BaseModelLoader):
             # random values to the weights.
             initialize_dummy_weights(model)
 
-
-            # 1. Initial weight loading.
-            # 2. Post-processing of weights, including assigning specific member variables.
-            # For `dummy_init`, only the second stage is required.
-            if hasattr(model, "post_load_weights"):
-                if (
-                    model_config.hf_config.architectures[0]
-                    == "DeepseekV3ForCausalLMNextN"
-                ):
-                    model.post_load_weights(is_nextn=True)
-                else:
-                    model.post_load_weights()
+            post_load_weights(model, model_config)
 
         return model.eval()
 
@@ -751,6 +741,9 @@ class ShardedStateLoader(BaseModelLoader):
                     state_dict.pop(key)
             if state_dict:
                 raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+            post_load_weights(model, model_config)
+
             return model.eval()
 
     @staticmethod
@@ -1421,18 +1414,16 @@ class RemoteModelLoader(BaseModelLoader):
                 # ignore hidden files
                 if file_name.startswith("."):
                     continue
-                if os.path.splitext(file_name)[1] not in (
-                    ".bin",
-                    ".pt",
-                    ".safetensors",
-                ):
+                if os.path.splitext(file_name)[1] in (".json", ".py"):
                     file_path = os.path.join(root, file_name)
                     with open(file_path, encoding="utf-8") as file:
                         file_content = file.read()
                     f_key = f"{model_name}/files/{file_name}"
                     client.setstr(f_key, file_content)
 
-    def _load_model_from_remote_kv(self, model: nn.Module, client):
+    def _load_model_from_remote_kv(
+        self, model: nn.Module, model_config: ModelConfig, client
+    ):
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
|
|
1460
1451
|
if state_dict:
|
1461
1452
|
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
1462
1453
|
|
1454
|
+
post_load_weights(model, model_config)
|
1455
|
+
|
1463
1456
|
def _load_model_from_remote_fs(
|
1464
1457
|
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
1465
1458
|
) -> nn.Module:
|
@@ -1501,15 +1494,13 @@ class RemoteModelLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config)
-                for _, module in model.named_modules():
-                    quant_method = getattr(module, "quant_method", None)
-                    if quant_method is not None:
-                        quant_method.process_weights_after_loading(module)
 
-            with create_remote_connector(
+            with create_remote_connector(
+                model_weights, device=device_config.device
+            ) as client:
                 connector_type = get_connector_type(client)
                 if connector_type == ConnectorType.KV:
-                    self._load_model_from_remote_kv(model, client)
+                    self._load_model_from_remote_kv(model, model_config, client)
                 elif connector_type == ConnectorType.FS:
                     self._load_model_from_remote_fs(
                         model, client, model_config, device_config
sglang/srt/model_loader/utils.py
CHANGED
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
 def get_architecture_class_name(model_config: ModelConfig) -> str:
     return get_model_architecture(model_config)[1]
+
+
+def post_load_weights(model: nn.Module, model_config: ModelConfig):
+    # Model weight loading consists of two stages:
+    # 1. Initial weight loading.
+    # 2. Post-processing of weights, including assigning specific member variables.
+    # For `dummy_init`, only the second stage is required.
+    if hasattr(model, "post_load_weights"):
+        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
+            model.post_load_weights(is_nextn=True)
+        else:
+            model.post_load_weights()
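The new `post_load_weights` helper above centralizes the second stage of weight loading that `DummyModelLoader` previously inlined and that `ShardedStateLoader` / `RemoteModelLoader` now also call (see the loader.py hunks earlier). A minimal sketch of its dispatch behavior, using stand-in classes rather than the real sglang types:

```python
class FakeHFConfig:
    architectures = ["DeepseekV3ForCausalLMNextN"]

class FakeModelConfig:
    hf_config = FakeHFConfig()

class FakeModel:
    def post_load_weights(self, is_nextn: bool = False):
        print(f"post_load_weights(is_nextn={is_nextn})")

def post_load_weights_sketch(model, model_config):
    # Mirrors the helper: call the hook only if the model defines it, and
    # pass is_nextn=True for the DeepSeek-V3 NextN draft architecture.
    if hasattr(model, "post_load_weights"):
        if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
            model.post_load_weights(is_nextn=True)
        else:
            model.post_load_weights()

post_load_weights_sketch(FakeModel(), FakeModelConfig())  # prints is_nextn=True
```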
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -122,6 +123,7 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -1181,13 +1183,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
 
-        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
 
-        # Save latent cache
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-        )
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
 
         return q, k, v, forward_batch
 
@@ -2406,18 +2414,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
 
         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
         for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]
 
-            for module in [
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]:
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
sglang/srt/models/gpt_oss.py
CHANGED
@@ -1029,10 +1029,6 @@ class GptOssForCausalLM(nn.Module):
         )
 
         params_dict = dict(self.named_parameters())
-        params_checker = {k: False for k, v in params_dict.items()}
-
-        for other_loaded_param_name in other_loaded_param_names:
-            params_checker[other_loaded_param_name] = True
 
         for name, loaded_weight in weights:
             loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
@@ -1069,7 +1065,6 @@ class GptOssForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
-                params_checker[name] = True
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -1092,7 +1087,6 @@ class GptOssForCausalLM(nn.Module):
                         name,
                         shard_id=shard_id,
                     )
-                    params_checker[name] = True
                     break
                 else:
                     if name.endswith(".bias") and name not in params_dict:
@@ -1111,17 +1105,9 @@ class GptOssForCausalLM(nn.Module):
                         param, "weight_loader", default_weight_loader
                     )
                     weight_loader(param, loaded_weight)
-                    params_checker[name] = True
                 else:
                     logger.warning(f"Parameter {name} not found in params_dict")
 
-        not_loaded_params = [k for k, v in params_checker.items() if not v]
-        if tp_rank == 0:
-            if len(not_loaded_params) > 0:
-                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
-            else:
-                logging.info("All parameters loaded successfully.")
-
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
 
sglang/srt/models/llama_eagle3.py
CHANGED
@@ -185,9 +185,13 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         )
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        self.load_lm_head_from_target = False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
+            if config.draft_vocab_size is None:
+                self.load_lm_head_from_target = True
+                config.draft_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
                 config.draft_vocab_size,
                 config.hidden_size,