sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +59 -11
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +24 -9
- sglang/srt/entrypoints/http_server.py +8 -2
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +29 -12
- sglang/srt/managers/scheduler.py +31 -20
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +11 -24
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +18 -8
- sglang/srt/server_args.py +15 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +2 -1
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +36 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
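The new `flatten_nested_list` import is used by the reworked mrope code below to turn a possibly nested `mrope_position_delta` list into a flat list of deltas. A minimal sketch of what such a helper does (an illustrative implementation, not necessarily the one in `sglang.srt.utils`):

```python
from typing import Any, List


def flatten_nested_list(nested: List[Any]) -> List[Any]:
    """Recursively flatten arbitrarily nested lists into a single flat list."""
    flat: List[Any] = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten_nested_list(item))
        else:
            flat.append(item)
    return flat


assert flatten_nested_list([[1, 2], [3, [4]]]) == [1, 2, 3, 4]
```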
@@ -364,23 +364,23 @@ class ForwardBatch:

     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
         """
-        Merge all
+        Merge all multimodal inputs in the batch into a single MultiModalInputs object.

         Returns:
-            if none, current batch contains no
+            if none, current batch contains no multimodal input

         """
         if not self.mm_inputs or all(x is None for x in self.mm_inputs):
             return None
-
         # Filter out None values
         valid_inputs = [x for x in self.mm_inputs if x is not None]

-        #
-
+        # TODO: is it expensive?
+        # a workaround to avoid importing `MultimodalInputs`
+        merged = valid_inputs[0].__class__(mm_items=[])

         # Merge remaining inputs
-        for mm_input in valid_inputs
+        for mm_input in valid_inputs:
             merged.merge(mm_input)

         return merged
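The rewritten `merge_mm_inputs` instantiates an empty object of the first valid input's class and folds every valid input into it via `merge`. A toy illustration of that accumulate-by-merge pattern, using a hypothetical `Items` class instead of the real `MultimodalInputs`:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Items:
    """Stand-in for MultimodalInputs: holds a list of items and supports merge()."""
    mm_items: List[str] = field(default_factory=list)

    def merge(self, other: "Items") -> None:
        self.mm_items.extend(other.mm_items)


def merge_all(inputs: List[Optional[Items]]) -> Optional[Items]:
    valid = [x for x in inputs if x is not None]
    if not valid:
        return None
    merged = valid[0].__class__(mm_items=[])  # avoids importing the concrete class
    for x in valid:
        merged.merge(x)
    return merged


print(merge_all([Items(["a"]), None, Items(["b", "c"])]))  # Items(mm_items=['a', 'b', 'c'])
```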
@@ -407,104 +407,60 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-
-
-        mrope_positions_list = [
-
-
-
-
-
-
-
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
+        # batch_size * [3 * seq_len]
+        batch_size = self.seq_lens.shape[0]
+        mrope_positions_list = [[]] * batch_size
+        for batch_idx in range(batch_size):
+            mm_input = batch.multimodal_inputs[batch_idx]
+            if self.forward_mode.is_decode():
+                mrope_position_deltas = (
+                    [0]
+                    if mm_input is None
+                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
                 )
-
-
-
-
-
-
-
+                next_input_positions = []
+                for mrope_position_delta in mrope_position_deltas:
+                    # batched deltas needs to be processed separately
+                    # Convert list of lists to tensor with shape [3, seq_len]
+                    next_input_positions += [
+                        MRotaryEmbedding.get_next_input_positions(
+                            mrope_position_delta,
+                            int(self.seq_lens[batch_idx]) - 1,
+                            int(self.seq_lens[batch_idx]),
+                        )
+                    ]
+                # 3 * N
+                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+            elif self.forward_mode.is_extend():
+                extend_seq_len, extend_prefix_len = (
+                    batch.extend_seq_lens[batch_idx],
+                    batch.extend_prefix_lens[batch_idx],
                 )
                 if mm_input is None:
                     # text only
-                    mrope_positions =
+                    mrope_positions = torch.tensor(
                         [
-
-
-
-
+                            [
+                                pos
+                                for pos in range(
+                                    extend_prefix_len,
+                                    extend_prefix_len + extend_seq_len,
+                                )
+                            ]
                         ]
-
-                else:
-                    image_grid_thws_list = [
-                        item.image_grid_thws
-                        for item in mm_input.mm_items
-                        if item.image_grid_thws is not None
-                    ]
-                    image_grid_thw = (
-                        None
-                        if len(image_grid_thws_list) == 0
-                        else torch.cat(image_grid_thws_list, dim=0)
-                    )
-
-                    video_grid_thws_list = [
-                        item.video_grid_thws
-                        for item in mm_input.mm_items
-                        if item.video_grid_thws is not None
-                    ]
-                    video_grid_thw = (
-                        None
-                        if len(video_grid_thws_list) == 0
-                        else torch.cat(video_grid_thws_list, dim=0)
+                        * 3
                     )
-
-
-
-
-                        if item.second_per_grid_ts is not None
+                else:
+                    mrope_positions = mm_input.mrope_positions[
+                        :,
+                        extend_prefix_len : extend_prefix_len + extend_seq_len,
                     ]
-
-                        None
-                        if len(second_per_grid_ts_list) == 0
-                        else torch.cat(second_per_grid_ts_list, dim=0)
-                    )
-
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
-                            image_grid_thw=image_grid_thw,
-                            video_grid_thw=video_grid_thw,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                            seq_len=len(self.input_ids),
-                            second_per_grid_ts=second_per_grid_ts,
-                            tokens_per_second=getattr(
-                                hf_config.vision_config, "tokens_per_second", None
-                            ),
-                        )
-                    )
-                    batch.multimodal_inputs[i].mrope_position_delta = (
-                        mrope_position_delta
-                    )
-                mrope_positions_list[i] = mrope_positions
+                mrope_positions_list[batch_idx] = mrope_positions

         self.mrope_positions = torch.cat(
-            [
-
-        )
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
+            [pos.to(device=model_runner.device) for pos in mrope_positions_list],
+            dim=1,
+        ).to(dtype=torch.int64, device=model_runner.device)

     def get_max_chunk_capacity(self):
         # Maximum number of tokens in each chunk
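In decode mode, the new code produces one `[3, k]` position block per request (one column per delta) and concatenates everything along dim 1 into `self.mrope_positions`. A simplified sketch of the shape handling, with a stand-in for `MRotaryEmbedding.get_next_input_positions` that is assumed to return a `[3, seq_len - context_len]` tensor:

```python
import torch


def get_next_input_positions(delta: int, context_len: int, seq_len: int) -> torch.Tensor:
    # Stand-in: the three mrope channels share the same next positions, shifted by delta.
    pos = torch.arange(context_len, seq_len) + delta
    return pos.unsqueeze(0).repeat(3, 1)  # shape [3, seq_len - context_len]


seq_lens = torch.tensor([5, 9])
deltas_per_request = [[0], [2, 2]]  # flattened per-request mrope_position_delta lists

per_request = []
for batch_idx, deltas in enumerate(deltas_per_request):
    blocks = [
        get_next_input_positions(d, int(seq_lens[batch_idx]) - 1, int(seq_lens[batch_idx]))
        for d in deltas
    ]
    per_request.append(torch.cat(blocks, dim=1))  # [3, num_deltas]

mrope_positions = torch.cat(per_request, dim=1).to(torch.int64)
print(mrope_positions.shape)  # torch.Size([3, 3])
```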
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -91,11 +91,14 @@ from sglang.srt.utils import (
     set_cuda_arch,
 )

-
-
+# Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+
+# Detect stragger ranks in model loading
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

+logger = logging.getLogger(__name__)
+

 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
@@ -177,7 +180,7 @@ class ModelRunner:
         if _ENABLE_JIT_DEEPGEMM:
             update_deep_gemm_config(gpu_id, server_args)

-        # If it is a draft model tp_group can be different
+        # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)

     def initialize(self, min_per_gpu_memory: float):
@@ -230,7 +233,8 @@ class ModelRunner:

         if server_args.attention_backend is None:
             """
-
+            Auto select the fastest attention backend.
+
             1. Models with MHA Architecture (e.g: Llama, QWen)
             1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
             1.2 In other cases, we will use flashinfer if available, otherwise use triton.
@@ -240,6 +244,7 @@
             """

            if not self.use_mla_backend:
+                # MHA architecture
                if (
                    is_hopper_with_cuda_12_3()
                    and is_no_spec_infer_or_topk_one(server_args)
@@ -251,6 +256,7 @@
                        "flashinfer" if is_flashinfer_available() else "triton"
                    )
            else:
+                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
                else:
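The added docstring and comments spell out the auto-selection policy: MHA models get FA3 on Hopper with CUDA 12.3 unless speculative decoding uses topk > 1 or page_size > 1, otherwise flashinfer if available, else triton; MLA models get FA3 on Hopper. A condensed sketch of that policy as a pure function; the non-Hopper MLA fallback shown here is an assumption, since that branch falls outside the quoted hunks:

```python
def select_attention_backend(
    use_mla: bool,
    on_hopper_cuda123: bool,
    flashinfer_available: bool,
    spec_topk_is_one: bool,
    page_size_is_one: bool,
) -> str:
    """Condensed restatement of the auto-selection policy described in the diff."""
    if not use_mla:
        # MHA architecture
        if on_hopper_cuda123 and spec_topk_is_one and page_size_is_one:
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    # MLA architecture
    if on_hopper_cuda123:
        return "fa3"
    return "triton"  # assumed fallback; not shown in the hunks above


print(select_attention_backend(False, True, True, True, True))  # fa3
```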
@@ -259,7 +265,6 @@
                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
            )
        elif self.use_mla_backend:
-            # TODO: add MLA optimization on CPU
            if server_args.device != "cpu":
                if server_args.attention_backend in [
                    "flashinfer",
@@ -275,7 +280,7 @@
                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
                    )
            else:
-                raise ValueError(
+                raise ValueError("MLA optimization not supported on CPU.")

        if (
            server_args.attention_backend == "fa3"
@@ -310,18 +315,6 @@
            )
            server_args.chunked_prefill_size = -1

-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ] or self.model_config.hf_config.architectures == [
-            "Qwen2_5_VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info("Automatically disable radix cache for qwen-vl series.")
-            server_args.disable_radix_cache = True
-
-        if server_args.enable_deepep_moe:
-            logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
-
        if not self.use_mla_backend:
            server_args.disable_chunked_prefix_cache = True
        elif self.page_size > 1:
@@ -964,12 +957,6 @@
            return

        if self.server_args.disable_cuda_graph:
-            logger.warning(
-                "\n\nCUDA Graph is DISABLED.\n"
-                "This will cause significant performance degradation.\n"
-                "CUDA Graph should almost never be disabled in most usage scenarios.\n"
-                "If you encounter OOM issues, please try setting --mem-fraction-static to a lower value (such as 0.8 or 0.7) instead of disabling CUDA Graph.\n"
-            )
            return

        tic = time.time()
sglang/srt/models/deepseek.py
CHANGED
@@ -382,8 +382,14 @@ class DeepseekModel(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
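The deepseek.py change threads an optional `input_embeds` tensor through `forward`, falling back to the embedding table only when it is None. The same pattern in isolation, on a hypothetical minimal module rather than the full `DeepseekModel`:

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids: torch.Tensor, input_embeds: torch.Tensor = None) -> torch.Tensor:
        # Use caller-provided embeddings when given; otherwise look them up.
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        return hidden_states


m = TinyModel()
ids = torch.tensor([1, 2, 3])
assert torch.equal(m(ids), m(ids, input_embeds=m.embed_tokens(ids)))
```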
@@ -416,14 +422,18 @@ class DeepseekForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -13,12 +13,14 @@
 # ==============================================================================

 """Inference-only DeepSeek NextN Speculative Decoding."""
+import logging
 from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
 from transformers import PretrainedConfig

+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -51,6 +53,9 @@ else:
     from vllm._custom_ops import awq_dequantize


+logger = logging.getLogger(__name__)
+
+
 class DeepseekModelNextN(nn.Module):
     def __init__(
         self,
@@ -134,7 +139,9 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")

         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -182,6 +189,48 @@
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion > 0:
+            logger.info(
+                f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
+            )
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                ]
+            else:
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale_inv",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale_inv",
+                    "up_proj.weight",
+                    "up_proj.weight_scale_inv",
+                ]
+            names_to_remove = []
+            for suffix in suffix_list:
+                shared_expert_weight_name = (
+                    f"model.layers.0.mlp.shared_experts.{suffix}"
+                )
+                for num_repeat in range(self.n_share_experts_fusion):
+                    weights_list.append(
+                        (
+                            f"model.layers.0."
+                            f"mlp.experts."
+                            f"{self.config.n_routed_experts + num_repeat}"
+                            f".{suffix}",
+                            weights_dict[shared_expert_weight_name],
+                        )
+                    )
+                names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
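The new block clones the layer-0 shared-expert weights into extra routed-expert slots (ids `n_routed_experts` through `n_routed_experts + n_share_experts_fusion - 1`) by appending renamed checkpoint entries and dropping the originals. A small sketch of that rename-and-append step on plain `(name, tensor)` pairs, with toy sizes:

```python
import torch

n_routed_experts = 4
n_share_experts_fusion = 2
suffixes = ["down_proj.weight", "gate_proj.weight", "up_proj.weight"]

# Toy checkpoint: shared-expert weights for layer 0.
weights = [(f"model.layers.0.mlp.shared_experts.{s}", torch.zeros(2, 2)) for s in suffixes]

weights_list = list(weights)
weights_dict = dict(weights_list)
names_to_remove = []
for suffix in suffixes:
    shared_name = f"model.layers.0.mlp.shared_experts.{suffix}"
    for num_repeat in range(n_share_experts_fusion):
        expert_id = n_routed_experts + num_repeat
        weights_list.append(
            (f"model.layers.0.mlp.experts.{expert_id}.{suffix}", weights_dict[shared_name])
        )
    names_to_remove.append(shared_name)

weights = [w for w in weights_list if w[0] not in names_to_remove]
print([name for name, _ in weights])
# ['model.layers.0.mlp.experts.4.down_proj.weight', 'model.layers.0.mlp.experts.5.down_proj.weight', ...]
```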
@@ -190,8 +239,14 @@
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
+        )
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
         )
+        cached_a_proj = {} if fuse_qkv_a_proj else None

         nextn_layer_prefix = "model.layers.0"
         nextn_spec_weight_names = [
@@ -264,11 +319,51 @@
                 if name.endswith(".bias") and name not in params_dict:
                     continue

-
-
-
-                )
-
+                # Handle fused_qkv_a_proj
+                if fuse_qkv_a_proj and (
+                    "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                ):
+                    cached_a_proj[name] = loaded_weight
+                    q_a_proj_name = (
+                        name
+                        if "q_a_proj" in name
+                        else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                    )
+                    kv_a_proj_name = (
+                        name
+                        if "kv_a_proj_with_mqa" in name
+                        else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                    )
+
+                    # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
+                    if (
+                        q_a_proj_name in cached_a_proj
+                        and kv_a_proj_name in cached_a_proj
+                    ):
+
+                        q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        fused_weight = torch.cat(
+                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                        )
+
+                        param_name = name.replace(
+                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                        )
+                        param = params_dict[param_name]
+
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, fused_weight)
+                        cached_a_proj.pop(q_a_proj_name)
+                        cached_a_proj.pop(kv_a_proj_name)
+                else:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)

         self_attn = self.model.decoder.self_attn
         if hasattr(self_attn.kv_b_proj, "qweight"):