sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sglang/compile_deep_gemm.py +8 -1
  2. sglang/global_config.py +5 -1
  3. sglang/srt/conversation.py +0 -112
  4. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  5. sglang/srt/disaggregation/prefill.py +1 -0
  6. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  7. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  8. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  9. sglang/srt/distributed/parallel_state.py +11 -0
  10. sglang/srt/entrypoints/engine.py +4 -2
  11. sglang/srt/entrypoints/http_server.py +35 -15
  12. sglang/srt/eplb/expert_distribution.py +4 -2
  13. sglang/srt/hf_transformers_utils.py +25 -10
  14. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  15. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  16. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  17. sglang/srt/layers/attention/vision.py +27 -10
  18. sglang/srt/layers/communicator.py +14 -4
  19. sglang/srt/layers/linear.py +7 -1
  20. sglang/srt/layers/logits_processor.py +9 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +11 -35
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
  24. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  25. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  26. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  27. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  28. sglang/srt/layers/moe/utils.py +43 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  31. sglang/srt/layers/quantization/fp8.py +5 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  33. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  34. sglang/srt/lora/lora_registry.py +7 -0
  35. sglang/srt/managers/cache_controller.py +8 -4
  36. sglang/srt/managers/data_parallel_controller.py +52 -2
  37. sglang/srt/managers/io_struct.py +6 -1
  38. sglang/srt/managers/schedule_batch.py +3 -2
  39. sglang/srt/managers/schedule_policy.py +3 -1
  40. sglang/srt/managers/scheduler.py +144 -6
  41. sglang/srt/managers/template_manager.py +25 -22
  42. sglang/srt/managers/tokenizer_manager.py +114 -62
  43. sglang/srt/managers/utils.py +45 -1
  44. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  45. sglang/srt/mem_cache/hicache_storage.py +13 -21
  46. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  47. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  48. sglang/srt/model_executor/cuda_graph_runner.py +17 -3
  49. sglang/srt/model_executor/forward_batch_info.py +13 -3
  50. sglang/srt/model_executor/model_runner.py +5 -0
  51. sglang/srt/models/deepseek_v2.py +23 -17
  52. sglang/srt/models/glm4_moe.py +82 -19
  53. sglang/srt/models/grok.py +3 -3
  54. sglang/srt/models/llama4.py +13 -2
  55. sglang/srt/models/mixtral.py +3 -3
  56. sglang/srt/models/mllama4.py +428 -19
  57. sglang/srt/models/qwen2_moe.py +1 -4
  58. sglang/srt/models/qwen3_moe.py +7 -8
  59. sglang/srt/models/step3_vl.py +1 -1
  60. sglang/srt/multimodal/processors/base_processor.py +4 -3
  61. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  62. sglang/srt/operations_strategy.py +1 -1
  63. sglang/srt/server_args.py +80 -20
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  65. sglang/srt/two_batch_overlap.py +6 -4
  66. sglang/srt/utils.py +3 -24
  67. sglang/srt/weight_sync/utils.py +1 -1
  68. sglang/test/runners.py +2 -2
  69. sglang/test/test_utils.py +3 -3
  70. sglang/version.py +1 -1
  71. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  72. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
  73. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  74. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  75. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  76. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  77. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  78. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  79. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  80. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py CHANGED
@@ -29,10 +29,14 @@ from tqdm import tqdm
  from transformers import PretrainedConfig

  from sglang.srt.distributed import (
+ get_moe_expert_parallel_world_size,
  get_tensor_model_parallel_world_size,
  parallel_state,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+ use_symmetric_memory,
+ )
  from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
  from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
  from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -61,7 +65,6 @@ from sglang.srt.layers.moe.ep_moe.layer import (
  get_moe_impl_class,
  should_use_flashinfer_trtllm_moe,
  )
- from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
  from sglang.srt.layers.moe.topk import TopK
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -96,7 +99,6 @@ from sglang.srt.two_batch_overlap import (
  )
  from sglang.srt.utils import (
  BumpAllocator,
- DeepEPMode,
  LazyValue,
  add_prefix,
  bind_or_assign,
@@ -333,15 +335,14 @@ class DeepseekV2MoE(nn.Module):
  routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
  **(
- dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
- if global_server_args_dict["enable_deepep_moe"]
+ dict(deepep_mode=global_server_args_dict["deepep_mode"])
+ if global_server_args_dict["moe_a2a_backend"].is_deepep()
  else {}
  ),
  # Additional args for FusedMoE
  **(
  dict(
  enable_flashinfer_cutlass_moe=True,
- enable_ep_moe=global_server_args_dict["enable_ep_moe"],
  )
  if global_server_args_dict["enable_flashinfer_cutlass_moe"]
  else {}
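The hunk above swaps the old boolean switches (enable_deepep_moe, enable_ep_moe) for a single moe_a2a_backend server argument with an is_deepep() predicate, and passes deepep_mode through as an object instead of indexing it via DeepEPMode[...] (whose import from sglang.srt.utils is removed above). The new sglang/srt/layers/moe/utils.py in the file list is the likely home of these types; their definitions are not shown in this diff, so the sketch below only illustrates the pattern and its member names are assumptions.

from enum import Enum


class MoeA2ABackend(Enum):
    # Hypothetical members; the real type is expected in sglang/srt/layers/moe/utils.py.
    NONE = "none"
    DEEPEP = "deepep"

    def is_deepep(self) -> bool:
        # Call sites such as DeepseekV2MoE gate DeepEP-specific kwargs on this
        # predicate instead of a standalone boolean server flag.
        return self is MoeA2ABackend.DEEPEP


def moe_extra_kwargs(server_args: dict) -> dict:
    # Mirrors the diff: only pass deepep_mode when the DeepEP backend is active.
    if server_args["moe_a2a_backend"].is_deepep():
        return {"deepep_mode": server_args["deepep_mode"]}
    return {}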
@@ -374,7 +375,7 @@ class DeepseekV2MoE(nn.Module):
  prefix=add_prefix("shared_experts", prefix),
  **(
  dict(tp_rank=0, tp_size=1)
- if global_server_args_dict["enable_deepep_moe"]
+ if global_server_args_dict["moe_a2a_backend"].is_deepep()
  else {}
  ),
  )
@@ -404,9 +405,9 @@ class DeepseekV2MoE(nn.Module):

  self.top_k = config.num_experts_per_tok

- if global_server_args_dict["enable_deepep_moe"]:
+ if global_server_args_dict["moe_a2a_backend"].is_deepep():
  # TODO: we will support tp < ep in the future
- self.ep_size = get_tensor_model_parallel_world_size()
+ self.ep_size = get_moe_expert_parallel_world_size()
  self.num_experts = (
  config.n_routed_experts
  + global_server_args_dict["ep_num_redundant_experts"]
@@ -428,12 +429,12 @@ class DeepseekV2MoE(nn.Module):
  num_local_experts=config.n_routed_experts // self.tp_size,
  hidden_size=config.hidden_size,
  params_dtype=config.torch_dtype,
- deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+ deepep_mode=global_server_args_dict["deepep_mode"],
  async_finish=True,
  return_recv_hook=True,
  )

- self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+ self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()

  def get_moe_weights(self):
  return [
@@ -483,7 +484,11 @@ class DeepseekV2MoE(nn.Module):
  if not _is_cuda:
  final_hidden_states *= self.routed_scaling_factor
  current_stream.wait_stream(self.alt_stream)
- final_hidden_states += shared_output
+ with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+ final_hidden_states_out = torch.empty_like(final_hidden_states)
+ torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+ final_hidden_states = final_hidden_states_out
+ sm.tag(final_hidden_states)
  if self.tp_size > 1 and not can_fuse_mlp_allreduce:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states
@@ -509,7 +514,11 @@ class DeepseekV2MoE(nn.Module):
  # fused in biased_grouped_topk so we can skip here
  final_hidden_states *= self.routed_scaling_factor
  if shared_output is not None:
- final_hidden_states = final_hidden_states + shared_output
+ with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+ final_hidden_states_out = torch.empty_like(final_hidden_states)
+ torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+ final_hidden_states = final_hidden_states_out
+ sm.tag(final_hidden_states)
  if self.tp_size > 1 and not can_fuse_mlp_allreduce:
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
  return final_hidden_states
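Both forward paths above now route the shared-expert addition through the new use_symmetric_memory context manager from pynccl_allocator (added in this release), so the summed tensor is allocated from NCCL symmetric memory and tagged before the tensor-parallel all-reduce. A condensed sketch of the pattern, reusing the imports shown in the first hunk; the helper function is illustrative only, since the diff inlines this logic at both call sites:

import torch

from sglang.srt.distributed import parallel_state, tensor_model_parallel_all_reduce
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)


def add_shared_then_all_reduce(routed_out, shared_out, do_all_reduce: bool):
    # Allocate the sum inside the symmetric-memory scope so the buffer is
    # associated with the TP group's allocator, then tag it as the diff does.
    with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
        out = torch.empty_like(routed_out)
        torch.add(routed_out, shared_out, out=out)
        sm.tag(out)
    if do_all_reduce:
        out = tensor_model_parallel_all_reduce(out)
    return out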
@@ -2104,11 +2113,8 @@ class DeepseekV2ForCausalLM(nn.Module):
  or self.config.n_shared_experts != 1
  ):
  disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
- elif (
- global_server_args_dict["enable_deepep_moe"]
- or global_server_args_dict["enable_ep_moe"]
- ):
- disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+ elif get_moe_expert_parallel_world_size() > 1:
+ disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."

  if disable_reason is not None:
  global_server_args_dict["disable_shared_experts_fusion"] = True
sglang/srt/models/glm4_moe.py CHANGED
@@ -23,6 +23,7 @@ from torch import nn
  from transformers import PretrainedConfig

  from sglang.srt.distributed import (
+ get_moe_expert_parallel_world_size,
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  parallel_state,
@@ -50,7 +51,6 @@ from sglang.srt.layers.linear import (
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.ep_moe.layer import (
- DeepEPMoE,
  get_moe_impl_class,
  should_use_flashinfer_trtllm_moe,
  )
@@ -83,7 +83,6 @@ from sglang.srt.two_batch_overlap import (
  )
  from sglang.srt.utils import (
  BumpAllocator,
- DeepEPMode,
  LazyValue,
  add_prefix,
  bind_or_assign,
@@ -388,6 +387,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  ):
  nn.Module.__init__(self)
  self.tp_size = get_tensor_model_parallel_world_size()
+ self.ep_size = get_moe_expert_parallel_world_size()
  self.routed_scaling_factor = config.routed_scaling_factor
  self.n_shared_experts = config.n_shared_experts
  self.num_fused_shared_experts = (
@@ -443,15 +443,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  routed_scaling_factor=self.routed_scaling_factor,
  prefix=add_prefix("experts", prefix),
  **(
- dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
- if global_server_args_dict["enable_deepep_moe"]
+ dict(deepep_mode=global_server_args_dict["deepep_mode"])
+ if global_server_args_dict["moe_a2a_backend"].is_deepep()
  else {}
  ),
  # Additional args for FusedMoE
  **(
  dict(
  enable_flashinfer_cutlass_moe=True,
- enable_ep_moe=global_server_args_dict["enable_ep_moe"],
  )
  if global_server_args_dict["enable_flashinfer_cutlass_moe"]
  else {}
@@ -482,11 +481,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  quant_config=quant_config,
  reduce_results=False,
  prefix=add_prefix("shared_experts", prefix),
- **(
- dict(tp_rank=0, tp_size=1)
- if global_server_args_dict["enable_deepep_moe"]
- else {}
- ),
+ **(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
  )
  is_packed_weight = hasattr(
  self.shared_experts.gate_up_proj.quant_method, "quant_config"
@@ -502,9 +497,9 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):

  self.top_k = config.num_experts_per_tok

- if global_server_args_dict["enable_deepep_moe"]:
+ if global_server_args_dict["moe_a2a_backend"].is_deepep():
  # TODO: we will support tp < ep in the future
- self.ep_size = get_tensor_model_parallel_world_size()
+ self.ep_size = get_moe_expert_parallel_world_size()
  self.num_experts = (
  config.n_routed_experts
  + global_server_args_dict["ep_num_redundant_experts"]
@@ -526,12 +521,83 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
  num_local_experts=config.n_routed_experts // self.tp_size,
  hidden_size=config.hidden_size,
  params_dtype=config.torch_dtype,
- deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+ deepep_mode=global_server_args_dict["deepep_mode"],
  async_finish=True,
  return_recv_hook=True,
  )

- self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+ self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
+
+ def forward_normal_dual_stream(
+ self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+ ) -> torch.Tensor:
+
+ current_stream = torch.cuda.current_stream()
+ self.alt_stream.wait_stream(current_stream)
+ shared_output = self._forward_shared_experts(hidden_states)
+
+ with torch.cuda.stream(self.alt_stream):
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ kwargs = {"hidden_states": hidden_states}
+ if self.topk is not None:
+ kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+ else:
+ kwargs["router_logits"] = router_logits
+ final_hidden_states = self.experts(**kwargs)
+ if not _is_cuda:
+ final_hidden_states *= self.routed_scaling_factor
+ current_stream.wait_stream(self.alt_stream)
+
+ if self.ep_size > 1:
+ if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ final_hidden_states = tensor_model_parallel_all_reduce(
+ final_hidden_states
+ )
+ final_hidden_states += shared_output
+ else:
+ final_hidden_states += shared_output
+ if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ final_hidden_states = tensor_model_parallel_all_reduce(
+ final_hidden_states
+ )
+ return final_hidden_states
+
+ def forward_normal(
+ self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+ ) -> torch.Tensor:
+ if hasattr(self, "shared_experts") and use_intel_amx_backend(
+ self.shared_experts.gate_up_proj
+ ):
+ return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+
+ shared_output = self._forward_shared_experts(hidden_states)
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states)
+ kwargs = {"hidden_states": hidden_states}
+ if self.topk is not None:
+ kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+ else:
+ kwargs["router_logits"] = router_logits
+ final_hidden_states = self.experts(**kwargs)
+ if not _is_cuda and not _use_aiter:
+ # fused in biased_grouped_topk so we can skip here
+ final_hidden_states *= self.routed_scaling_factor
+ if self.ep_size > 1:
+ if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ final_hidden_states = tensor_model_parallel_all_reduce(
+ final_hidden_states
+ )
+ if shared_output is not None:
+ final_hidden_states += shared_output
+ else:
+ if shared_output is not None:
+ final_hidden_states += shared_output
+ if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+ final_hidden_states = tensor_model_parallel_all_reduce(
+ final_hidden_states
+ )
+ return final_hidden_states


  class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
@@ -737,11 +803,8 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
  or self.config.n_shared_experts != 1
  ):
  disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
- elif (
- global_server_args_dict["enable_deepep_moe"]
- or global_server_args_dict["enable_ep_moe"]
- ):
- disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+ elif get_moe_expert_parallel_world_size() > 1:
+ disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."

  if disable_reason is not None:
  global_server_args_dict["disable_shared_experts_fusion"] = True
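The branch order in the new forward_normal and forward_normal_dual_stream methods matters: when ep_size > 1 the shared experts run unsharded (tp_rank=0, tp_size=1 above), so their output is already complete on every rank and is added only after the routed output has been all-reduced across the TP group; adding it first would scale it by the TP size. A toy scalar illustration of the two orderings (plain Python, no sglang dependencies, made-up numbers):

TP_SIZE = 4

def all_reduce(partials):
    # Toy stand-in for tensor_model_parallel_all_reduce: sum over TP ranks.
    return sum(partials)

routed_partials = [1.0, 2.0, 3.0, 4.0]  # each TP rank holds a partial routed sum
shared_complete = 5.0                   # shared experts ran with tp_size=1

# Ordering used when ep_size > 1: all-reduce first, then add the shared output once.
correct = all_reduce(routed_partials) + shared_complete               # 15.0

# Reversed ordering would fold the complete shared output in on every rank.
wrong = all_reduce([p + shared_complete for p in routed_partials])    # 30.0

assert (correct, wrong) == (15.0, 30.0)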
sglang/srt/models/grok.py CHANGED
@@ -29,6 +29,7 @@ from torch import nn
  from transformers import PretrainedConfig

  from sglang.srt.distributed import (
+ get_moe_expert_parallel_world_size,
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  tensor_model_parallel_all_gather,
@@ -117,7 +118,7 @@ class Grok1MoE(nn.Module):
  )

  kwargs = {}
- if global_server_args_dict["enable_ep_moe"]:
+ if get_moe_expert_parallel_world_size() > 1:
  MoEImpl = EPMoE
  else:
  MoEImpl = FusedMoE
@@ -616,8 +617,7 @@ class Grok1ForCausalLM(nn.Module):

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
- MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
- expert_params_mapping = MoEImpl.make_expert_params_mapping(
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
  ckpt_gate_proj_name="w1",
  ckpt_down_proj_name="w2",
  ckpt_up_proj_name="w3",
sglang/srt/models/llama4.py CHANGED
@@ -241,13 +241,22 @@ class Llama4Attention(nn.Module):
  if self.use_qk_norm
  else None
  )
+
+ qkv_quant_config = quant_config
+ o_quant_config = quant_config
+ if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore:
+ if add_prefix("q_proj", prefix) in quant_config.ignore:
+ qkv_quant_config = None
+ if add_prefix("o_proj", prefix) in quant_config.ignore:
+ o_quant_config = None
+
  self.qkv_proj = QKVParallelLinear(
  hidden_size=hidden_size,
  head_size=self.head_dim,
  total_num_heads=self.total_num_heads,
  total_num_kv_heads=self.total_num_kv_heads,
  bias=bias,
- quant_config=quant_config,
+ quant_config=qkv_quant_config,
  prefix=add_prefix("qkv_proj", prefix),
  tp_rank=attn_tp_rank,
  tp_size=attn_tp_size,
@@ -257,7 +266,7 @@
  input_size=self.total_num_heads * self.head_dim,
  output_size=hidden_size,
  bias=bias_o_proj,
- quant_config=quant_config,
+ quant_config=o_quant_config,
  prefix=add_prefix("o_proj", prefix),
  tp_rank=attn_tp_rank,
  tp_size=attn_tp_size,
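The qkv_quant_config / o_quant_config handling above lets a checkpoint's quantization config exclude individual projections: if the fully prefixed module name appears in the config's ignore list, that projection is built without quantization. A rough, self-contained restatement of the check, assuming add_prefix joins the prefix and module name with a dot and that the ignore list holds fully qualified names; the config class and example module path below are stand-ins, not sglang APIs:

from dataclasses import dataclass, field
from typing import List


@dataclass
class DummyQuantConfig:
    # Stand-in for a real quantization config exposing an ignore list.
    ignore: List[str] = field(default_factory=list)


def resolve_proj_quant_config(quant_config, prefix: str, proj_name: str):
    # Mirrors the hunk above: drop quantization for a projection listed in ignore.
    full_name = f"{prefix}.{proj_name}" if prefix else proj_name
    if quant_config and getattr(quant_config, "ignore", None) and full_name in quant_config.ignore:
        return None
    return quant_config


cfg = DummyQuantConfig(ignore=["model.layers.0.self_attn.o_proj"])
assert resolve_proj_quant_config(cfg, "model.layers.0.self_attn", "o_proj") is None
assert resolve_proj_quant_config(cfg, "model.layers.0.self_attn", "q_proj") is cfg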
@@ -406,6 +415,8 @@ class Llama4DecoderLayer(nn.Module):
  )

  def _is_moe_layer(self, layer_id: int) -> bool:
+ if self.config.interleave_moe_layer_step == 0:
+ return self.config.num_local_experts > 0
  return (layer_id + 1) % self.config.interleave_moe_layer_step == 0

  def forward(
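The _is_moe_layer change handles configs where interleave_moe_layer_step is 0, which would otherwise trip a modulo-by-zero; such configs are treated as all-MoE whenever num_local_experts > 0. A standalone restatement of the predicate with two worked cases (the config objects are simple stand-ins for the HF config):

from types import SimpleNamespace


def is_moe_layer(config, layer_id: int) -> bool:
    # Same predicate as Llama4DecoderLayer._is_moe_layer after this change.
    if config.interleave_moe_layer_step == 0:
        return config.num_local_experts > 0
    return (layer_id + 1) % config.interleave_moe_layer_step == 0


interleaved = SimpleNamespace(interleave_moe_layer_step=2, num_local_experts=16)
all_moe = SimpleNamespace(interleave_moe_layer_step=0, num_local_experts=16)

assert [is_moe_layer(interleaved, i) for i in range(4)] == [False, True, False, True]
assert all(is_moe_layer(all_moe, i) for i in range(4))  # every layer is MoE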
sglang/srt/models/mixtral.py CHANGED
@@ -24,6 +24,7 @@ from torch import nn
  from transformers import MixtralConfig

  from sglang.srt.distributed import (
+ get_moe_expert_parallel_world_size,
  get_pp_group,
  get_tensor_model_parallel_world_size,
  tensor_model_parallel_all_reduce,
@@ -94,7 +95,7 @@ class MixtralMoE(nn.Module):
  renormalize=True,
  )

- MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+ MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
  self.experts = MoEImpl(
  num_experts=num_experts,
  top_k=top_k,
@@ -398,8 +399,7 @@ class MixtralForCausalLM(nn.Module):

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
- MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
- expert_params_mapping = MoEImpl.make_expert_params_mapping(
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
  ckpt_gate_proj_name="w1",
  ckpt_down_proj_name="w2",
  ckpt_up_proj_name="w3",
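As in grok.py, the Mixtral changes derive expert parallelism from the actual parallel topology rather than the removed enable_ep_moe server argument, and build the checkpoint weight mapping from FusedMoE unconditionally. A compact restatement of the selection; the import paths are assumptions based on the file list above, and the helper name is illustrative:

from sglang.srt.distributed import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE


def select_moe_impl():
    # EP kicks in when the MoE expert-parallel world size is larger than one,
    # instead of being toggled by the old enable_ep_moe flag.
    return EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE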