sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -29,10 +29,14 @@ from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -56,13 +60,9 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import (
-    DeepEPMoE,
-    get_moe_impl_class,
-    should_use_flashinfer_trtllm_moe,
-)
-from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -96,7 +96,6 @@ from sglang.srt.two_batch_overlap import (
 )
 from sglang.srt.utils import (
     BumpAllocator,
-    DeepEPMode,
     LazyValue,
     add_prefix,
     bind_or_assign,
@@ -209,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce: bool = False):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+        )
         return x


@@ -305,19 +312,15 @@ class DeepseekV2MoE(nn.Module):
             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
         )

-        self.topk = (
-            TopK(
-                top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
-                renormalize=config.norm_topk_prob,
-                use_grouped_topk=True,
-                num_expert_group=config.n_group,
-                num_fused_shared_experts=self.num_fused_shared_experts,
-                topk_group=config.topk_group,
-                correction_bias=self.gate.e_score_correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-            )
-            if not should_use_flashinfer_trtllm_moe()
-            else None
+        self.topk = TopK(
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
+            renormalize=config.norm_topk_prob,
+            use_grouped_topk=True,
+            num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         self.experts = get_moe_impl_class()(
@@ -333,15 +336,14 @@ class DeepseekV2MoE(nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
-                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                if global_server_args_dict["enable_deepep_moe"]
+                dict(deepep_mode=global_server_args_dict["deepep_mode"])
+                if global_server_args_dict["moe_a2a_backend"].is_deepep()
                 else {}
             ),
             # Additional args for FusedMoE
             **(
                 dict(
                     enable_flashinfer_cutlass_moe=True,
-                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
                 if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
@@ -374,7 +376,7 @@ class DeepseekV2MoE(nn.Module):
                 prefix=add_prefix("shared_experts", prefix),
                 **(
                     dict(tp_rank=0, tp_size=1)
-                    if global_server_args_dict["enable_deepep_moe"]
+                    if global_server_args_dict["moe_a2a_backend"].is_deepep()
                     else {}
                 ),
             )
@@ -404,9 +406,9 @@ class DeepseekV2MoE(nn.Module):

         self.top_k = config.num_experts_per_tok

-        if global_server_args_dict["enable_deepep_moe"]:
+        if global_server_args_dict["moe_a2a_backend"].is_deepep():
             # TODO: we will support tp < ep in the future
-            self.ep_size = get_tensor_model_parallel_world_size()
+            self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
                 + global_server_args_dict["ep_num_redundant_experts"]
@@ -428,12 +430,12 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                deepep_mode=global_server_args_dict["deepep_mode"],
                 async_finish=True,
                 return_recv_hook=True,
             )

-        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+        self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()

     def get_moe_weights(self):
         return [
@@ -447,6 +449,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
         can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -456,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, can_fuse_mlp_allreduce
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
                 )
             else:
-                return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
+                return self.forward_normal(
+                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

     def forward_normal_dual_stream(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
@@ -475,21 +483,32 @@ class DeepseekV2MoE(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             kwargs = {"hidden_states": hidden_states}
-            if self.topk is not None:
-                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+            # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+            # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+            if should_use_flashinfer_trtllm_moe():
+                kwargs["topk_output"] = (self.topk, router_logits)
             else:
-                kwargs["router_logits"] = router_logits
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
             final_hidden_states = self.experts(**kwargs)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
-        final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            final_hidden_states_out = torch.empty_like(final_hidden_states)
+            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+            final_hidden_states = final_hidden_states_out
+            sm.tag(final_hidden_states)
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -500,17 +519,25 @@ class DeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         kwargs = {"hidden_states": hidden_states}
-        if self.topk is not None:
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
+        # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
+        # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
+        if should_use_flashinfer_trtllm_moe():
+            kwargs["topk_output"] = (self.topk, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+
         final_hidden_states = self.experts(**kwargs)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                final_hidden_states_out = torch.empty_like(final_hidden_states)
+                torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+                final_hidden_states = final_hidden_states_out
+                sm.tag(final_hidden_states)
+        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

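The two hunks above replace the old `self.topk is None` check with an explicit `should_use_flashinfer_trtllm_moe()` dispatch for building the expert kwargs. The sketch below is illustrative only (not part of the wheel); `topk`, `hidden_states`, and `router_logits` are stand-ins for the module attributes used above:

```python
# Sketch of the topk_output dispatch used by forward_normal / forward_normal_dual_stream.
# should_use_flashinfer_trtllm_moe comes from the new sglang.srt.layers.moe.utils module;
# the arguments here are placeholders for the module attributes in the diff.
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe


def build_expert_kwargs(topk, hidden_states, router_logits):
    kwargs = {"hidden_states": hidden_states}
    if should_use_flashinfer_trtllm_moe():
        # TRTLLM path: pass the TopK module through unevaluated together with
        # the raw router logits; FlashInferFP4MoE runs top-k internally.
        kwargs["topk_output"] = (topk, router_logits)
    else:
        # CUTLASS/Triton path: run top-k here and hand the experts a StandardTopKOutput.
        kwargs["topk_output"] = topk(hidden_states, router_logits)
    return kwargs
```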
@@ -1812,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
@@ -1874,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             and not self.is_nextn
         )

-        hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+        )

         if can_fuse_mlp_allreduce:
             hidden_states._sglang_needs_allreduce_fusion = True
@@ -2051,6 +2085,8 @@ class DeepseekV2Model(nn.Module):


 class DeepseekV2ForCausalLM(nn.Module):
+    # for quark model load
+    packed_modules_mapping = {}

     def __init__(
         self,
@@ -2059,6 +2095,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # for quark model load
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        self.fuse_qkv_a_proj = (
+            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
+        )
+        if self.fuse_qkv_a_proj:
+            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+                "q_a_proj",
+                "kv_a_proj_with_mqa",
+            ]
+
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
@@ -2104,11 +2152,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             or self.config.n_shared_experts != 1
         ):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif (
-            global_server_args_dict["enable_deepep_moe"]
-            or global_server_args_dict["enable_ep_moe"]
-        ):
-            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."

         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True