sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/eplb/expert_distribution.py

@@ -20,6 +20,7 @@ import time
 from abc import ABC
 from collections import deque
 from contextlib import contextmanager
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
 
 import einops
@@ -27,6 +28,7 @@ import torch
 import torch.distributed
 
 from sglang.srt.environ import envs
+from sglang.srt.metrics.collector import ExpertDispatchCollector
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import Withable, is_npu
@@ -415,10 +417,19 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
 
     def collect(self) -> Dict:
         num_tokens = len(self._metadata["input_ids"])
+
+        global_physical_count = _convert_per_token_to_global_physical_count(
+            num_tokens,
+            num_layers=self._expert_location_metadata.num_layers,
+            num_physical_experts=self._expert_location_metadata.num_physical_experts,
+            _topk_ids_of_layer=self._topk_ids_of_layer,
+        )
+
         return dict(
             **self._metadata,
             topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :].clone().cpu(),
             misc_objects=self._misc_objects,
+            global_physical_count=global_physical_count,
         )
 
 
@@ -547,6 +558,27 @@ class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
         self._data[layer_idx, :] += local_physical_count_of_layer
 
 
+def _convert_per_token_to_global_physical_count(
+    num_tokens: int,
+    num_layers: int,
+    num_physical_experts: int,
+    _topk_ids_of_layer: torch.Tensor,
+) -> torch.Tensor:
+    topk_ids_layer_major = _topk_ids_of_layer[:, :num_tokens, :].reshape(num_layers, -1)
+    mask = topk_ids_layer_major != -1
+
+    index = topk_ids_layer_major.masked_fill(~mask, 0).long()
+    src = mask.int()
+
+    ans = torch.zeros(
+        (num_layers, num_physical_experts),
+        dtype=_topk_ids_of_layer.dtype,
+        device=_topk_ids_of_layer.device,
+    )
+    ans.scatter_add_(dim=1, index=index, src=src)
+    return ans
+
+
 def _convert_local_to_global_physical_count(
     local_physical_count: torch.Tensor,
     rank: int,
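
The added helper counts, per layer, how many tokens were routed to each physical expert in a single scatter_add_ call, masking out the -1 padding slots. A self-contained sketch of the same trick on toy tensors (sizes and values here are illustrative, not sglang's):

import torch

# 2 layers, 3 tokens, top-2 routing over 4 physical experts; -1 marks padding.
topk_ids = torch.tensor(
    [[[0, 1], [1, 3], [1, -1]],    # layer 0: expert 1 is chosen three times
     [[2, 2], [0, -1], [-1, -1]]]  # layer 1: expert 2 is chosen twice
)
flat = topk_ids.reshape(2, -1)  # layer-major view of all top-k slots
mask = flat != -1               # drop the -1 padding entries
counts = torch.zeros(2, 4, dtype=torch.int32)
# Padded slots scatter into expert 0 but carry src=0, so they add nothing.
counts.scatter_add_(dim=1, index=flat.masked_fill(~mask, 0).long(), src=mask.int())
print(counts)  # tensor([[1, 3, 0, 1], [1, 0, 2, 0]], dtype=torch.int32)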
@@ -630,6 +662,10 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
         self.window_sizes = [10, 100, 1000]
         self._history = _DequeCollection(maxlens=self.window_sizes)
         self._rank = torch.distributed.get_rank()
+        self._expert_dispatch_collector = ExpertDispatchCollector(
+            self._expert_location_metadata.ep_size
+        )
+        self._collection_counter = 0
 
     def append(
         self,
@@ -661,6 +697,8 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
         )
 
         if self._rank == 0:
+            self._collect_metrics_if_needed(gpu_physical_count)
+
             utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
             utilization_rate = torch.mean(utilization_rate_tensor).item()
             self._history.append(utilization_rate)
@@ -676,6 +714,31 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
                 # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
             )
 
+    def _collect_metrics_if_needed(self, gpu_physical_count: torch.Tensor):
+        # sglang:eplb_gpu_physical_count metric is disabled if SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL <= 0
+        if (
+            envs.SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL > 0
+            and self._collection_counter % envs.SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL
+            == 0
+        ):
+            for layer_idx in range(self._expert_location_metadata.num_layers):
+                count_of_layer = (
+                    self._expert_dispatch_collector.eplb_gpu_physical_count.labels(
+                        layer=str(layer_idx)
+                    )
+                )
+                # Exclude the +Inf bucket.
+                assert (
+                    self._expert_location_metadata.ep_size
+                    == len(count_of_layer._buckets) - 1
+                ), f"{self._expert_location_metadata.ep_size=}, {len(count_of_layer._buckets)=}"
+                for gpu_rank in range(self._expert_location_metadata.ep_size):
+                    count = gpu_physical_count[layer_idx, gpu_rank]
+                    if count > 0:
+                        count_of_layer._sum.inc(count * gpu_rank)
+                        count_of_layer._buckets[gpu_rank].inc(count)
+        self._collection_counter += 1
+
 
 class _DequeCollection:
     def __init__(self, maxlens: List[int]):
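
The method above repurposes a Prometheus histogram as a (layer, gpu_rank) heatmap: bucket r of each per-layer histogram holds the token count dispatched to GPU rank r, and the code bulk-increments the private _sum/_buckets counters instead of calling observe() once per token. A minimal public-API sketch of the same layout, assuming prometheus_client and a hypothetical EP_SIZE (not sglang's ExpertDispatchCollector):

from prometheus_client import Histogram

EP_SIZE = 4  # hypothetical expert-parallel world size

# One finite bucket per GPU rank; prometheus_client appends +Inf itself,
# which is why the assert above compares against len(_buckets) - 1.
eplb_gpu_physical_count = Histogram(
    "eplb_gpu_physical_count",
    "Tokens dispatched to each GPU rank, per MoE layer",
    labelnames=["layer"],
    buckets=[float(rank) for rank in range(EP_SIZE)],
)

def record(layer_idx, gpu_physical_count_of_layer):
    hist = eplb_gpu_physical_count.labels(layer=str(layer_idx))
    for gpu_rank, count in enumerate(gpu_physical_count_of_layer):
        # observe(gpu_rank) lands one sample in bucket gpu_rank; the mixin gets
        # the same end state in O(1) per rank by adding count to _buckets[gpu_rank]
        # and count * gpu_rank to _sum, avoiding this per-token loop.
        for _ in range(int(count)):
            hist.observe(gpu_rank)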
@@ -838,7 +901,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
 
 
 def _dump_to_file(name, data):
-    save_dir = envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get()
+    save_dir = Path(envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get())
     path_output = save_dir / name
     logger.info(f"Write expert distribution to {path_output}")
     if not save_dir.exists():
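
The one-line fix above (paired with the new pathlib import) matters because the env accessor returns a plain string, and the / operator on the next line is only defined for Path objects; a quick illustration with a made-up directory:

from pathlib import Path

save_dir = "/tmp/expert_distribution"      # what .get() returns: a str
# save_dir / "rank0.pt"                    # TypeError on the old code path
path_output = Path(save_dir) / "rank0.pt"  # PosixPath('/tmp/expert_distribution/rank0.pt')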

sglang/srt/eplb/expert_location.py

@@ -85,7 +85,9 @@ class ExpertLocationMetadata:
     # -------------------------------- construction ------------------------------------
 
     @staticmethod
-    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
+    def init_trivial(
+        server_args: ServerArgs, model_config: ModelConfig, moe_ep_rank: int
+    ):
         """Trivial location - logical expert i corresponds to physical expert i"""
         common = ExpertLocationMetadata._init_common(server_args, model_config)
 
@@ -106,6 +108,7 @@
             server_args,
             model_config,
             physical_to_logical_map=physical_to_logical_map,
+            moe_ep_rank=moe_ep_rank,
         )
 
     @staticmethod
@@ -113,6 +116,7 @@
         server_args: ServerArgs,
         model_config: ModelConfig,
         physical_to_logical_map,
+        moe_ep_rank: int = None,
     ):
         if not isinstance(physical_to_logical_map, torch.Tensor):
             physical_to_logical_map = torch.tensor(physical_to_logical_map)
@@ -125,8 +129,11 @@
 
         model_config_for_expert_location = common["model_config_for_expert_location"]
         logical_to_all_physical_map = _compute_logical_to_all_physical_map(
-            physical_to_logical_map,
+            server_args=server_args,
+            physical_to_logical_map=physical_to_logical_map,
             num_logical_experts=model_config_for_expert_location.num_logical_experts,
+            ep_size=common["ep_size"],
+            moe_ep_rank=moe_ep_rank,
         )
 
         return ExpertLocationMetadata._init_raw(
@@ -233,7 +240,7 @@
             compute_logical_to_rank_dispatch_physical_map(
                 server_args=server_args,
                 logical_to_all_physical_map=logical_to_all_physical_map,
-                num_gpus=ep_size,
+                ep_size=ep_size,
                 num_physical_experts=num_physical_experts,
                 # TODO improve when we have real EP rank
                 ep_rank=torch.distributed.get_rank() % ep_size,
@@ -303,7 +310,11 @@ def set_global_expert_location_metadata(value):
 
 
 def _compute_logical_to_all_physical_map(
-    physical_to_logical_map: torch.Tensor, num_logical_experts: int
+    server_args: ServerArgs,
+    physical_to_logical_map: torch.Tensor,
+    num_logical_experts: int,
+    ep_size: int,
+    moe_ep_rank: int,
 ):
     # This is rarely called, so we use for loops for maximum clarity
 
@@ -312,6 +323,8 @@
     logical_to_all_physical_map = [
         [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
     ]
+
+    # Find out the candidate physical experts for each logical expert on each layer
     for layer_id in range(num_layers):
         for physical_expert_id in range(num_physical_experts):
             logical_expert_id = physical_to_logical_map[
@@ -321,6 +334,32 @@
                 physical_expert_id
             )
 
+    # Replace by the physical expert on local GPU or node if possible
+    if moe_ep_rank is not None:
+        num_gpus_per_node = server_args.ep_size // server_args.nnodes
+        num_local_gpu_physical_experts = num_physical_experts // ep_size
+        num_local_node_physical_experts = (
+            num_local_gpu_physical_experts * num_gpus_per_node
+        )
+        for layer_id in range(num_layers):
+            for logical_expert_id in range(num_logical_experts):
+                # Try to find the nearest physical expert
+                nearest_expert = _find_nearest_expert(
+                    candidate_physical_expert_ids=logical_to_all_physical_map[layer_id][
+                        logical_expert_id
+                    ],
+                    num_local_gpu_physical_experts=num_local_gpu_physical_experts,
+                    moe_ep_rank=moe_ep_rank,
+                    num_gpus_per_node=num_gpus_per_node,
+                    num_local_node_physical_experts=num_local_node_physical_experts,
+                )
+
+                # Replace by the nearest physical expert
+                if nearest_expert != -1:
+                    logical_to_all_physical_map[layer_id][logical_expert_id] = [
+                        nearest_expert
+                    ]
+
     logical_to_all_physical_map = _pad_nested_array(
         logical_to_all_physical_map, pad_value=-1
     )
@@ -343,21 +382,21 @@ def _pad_nested_array(arr, pad_value):
 def compute_logical_to_rank_dispatch_physical_map(
     server_args: ServerArgs,
     logical_to_all_physical_map: torch.Tensor,
-    num_gpus: int,
+    ep_size: int,
     num_physical_experts: int,
     ep_rank: int,
     seed: int = 42,
 ):
     r = random.Random(seed)
 
-    num_local_gpu_physical_experts = num_physical_experts // num_gpus
+    num_local_gpu_physical_experts = num_physical_experts // ep_size
     num_gpus_per_node = server_args.ep_size // server_args.nnodes
     num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
     num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
     dtype = logical_to_all_physical_map.dtype
 
     logical_to_rank_dispatch_physical_map = torch.full(
-        size=(num_gpus, num_layers, num_logical_experts),
+        size=(ep_size, num_layers, num_logical_experts),
         fill_value=-1,
         dtype=dtype,
     )
@@ -371,33 +410,17 @@
                 :, layer_id, logical_expert_id
             ]
 
-            for gpu_id in range(num_gpus):
-                same_gpu_physical_expert_ids = [
-                    physical_expert_id
-                    for physical_expert_id in candidate_physical_expert_ids
-                    if _compute_gpu_id_of_physical_expert(
-                        physical_expert_id, num_local_gpu_physical_experts
-                    )
-                    == gpu_id
-                ]
-                if len(same_gpu_physical_expert_ids) > 0:
-                    # 1. Prefer same-GPU experts
-                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
-                else:
-                    # 2. Otherwise, prefer same-node experts
-                    node_id = gpu_id // num_gpus_per_node
-                    same_node_physical_expert_ids = [
-                        physical_expert_id
-                        for physical_expert_id in candidate_physical_expert_ids
-                        if _compute_node_id_of_physical_expert(
-                            physical_expert_id, num_local_node_physical_experts
-                        )
-                        == node_id
-                    ]
-                    if len(same_node_physical_expert_ids) > 0:
-                        output_partial[gpu_id] = same_node_physical_expert_ids[0]
+            for moe_ep_rank in range(ep_size):
+                # Fill with the nearest physical expert
+                output_partial[moe_ep_rank] = _find_nearest_expert(
+                    candidate_physical_expert_ids=candidate_physical_expert_ids,
+                    num_local_gpu_physical_experts=num_local_gpu_physical_experts,
+                    moe_ep_rank=moe_ep_rank,
+                    num_gpus_per_node=num_gpus_per_node,
+                    num_local_node_physical_experts=num_local_node_physical_experts,
+                )
 
-            # 3. Fill remaining slots with fair random choices
+            # Fill remaining slots with fair random choices
             num_remain = torch.sum(output_partial == -1).item()
             output_partial[output_partial == -1] = torch.tensor(
                 _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
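
The -1 slots left by the nearest-expert pass are filled by _fair_choices, whose one-liner appears as context in the next hunk: every candidate is repeated floor(k / len(arr)) times, and the remainder is sampled without replacement. A quick re-implementation to show the behavior:

import random

def fair_choices(arr, k, r):
    quotient, remainder = divmod(k, len(arr))
    return arr * quotient + r.sample(arr, k=remainder)

r = random.Random(42)
picks = fair_choices([7, 12], k=5, r=r)
assert sorted(picks)[:4] == [7, 7, 12, 12]  # both candidates appear at least twice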
@@ -434,6 +457,46 @@ def _compute_node_id_of_physical_expert(
     return physical_expert_id // num_local_host_physical_experts
 
 
+def _find_nearest_expert(
+    candidate_physical_expert_ids: List[int],
+    num_local_gpu_physical_experts: int,
+    moe_ep_rank: int,
+    num_gpus_per_node: int,
+    num_local_node_physical_experts: int,
+) -> int:
+    # 1. If only one candidate, return it directly
+    if len(candidate_physical_expert_ids) == 1:
+        return candidate_physical_expert_ids[0]
+
+    # 2. Prefer same-GPU experts
+    same_gpu_physical_expert_ids = [
+        physical_expert_id
+        for physical_expert_id in candidate_physical_expert_ids
+        if _compute_gpu_id_of_physical_expert(
+            physical_expert_id, num_local_gpu_physical_experts
+        )
+        == moe_ep_rank
+    ]
+    if len(same_gpu_physical_expert_ids) > 0:
+        return same_gpu_physical_expert_ids[0]
+
+    # 3. Otherwise, prefer same-node experts
+    node_rank = moe_ep_rank // num_gpus_per_node
+    same_node_physical_expert_ids = [
+        physical_expert_id
+        for physical_expert_id in candidate_physical_expert_ids
+        if _compute_node_id_of_physical_expert(
+            physical_expert_id, num_local_node_physical_experts
+        )
+        == node_rank
+    ]
+    if len(same_node_physical_expert_ids) > 0:
+        return same_node_physical_expert_ids[0]
+
+    # 4. At last, leave it as -1 to indicate not found.
+    return -1
+
+
 def _fair_choices(arr: List, k: int, r: random.Random) -> List:
     quotient, remainder = divmod(k, len(arr))
     ans = arr * quotient + r.sample(arr, k=remainder)
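
A standalone walkthrough of the tiered lookup in _find_nearest_expert, with hypothetical cluster sizes (2 nodes x 4 GPUs, 2 physical experts per GPU); integer division maps a physical expert id to its GPU and node exactly as _compute_gpu_id_of_physical_expert and _compute_node_id_of_physical_expert do:

NUM_LOCAL_GPU_EXPERTS = 2  # physical experts per GPU (hypothetical)
NUM_GPUS_PER_NODE = 4
NUM_LOCAL_NODE_EXPERTS = NUM_LOCAL_GPU_EXPERTS * NUM_GPUS_PER_NODE  # 8 per node

def nearest(candidates, moe_ep_rank):
    if len(candidates) == 1:
        return candidates[0]
    same_gpu = [e for e in candidates if e // NUM_LOCAL_GPU_EXPERTS == moe_ep_rank]
    if same_gpu:
        return same_gpu[0]                      # tier 2: same GPU
    node_rank = moe_ep_rank // NUM_GPUS_PER_NODE
    same_node = [e for e in candidates if e // NUM_LOCAL_NODE_EXPERTS == node_rank]
    if same_node:
        return same_node[0]                     # tier 3: same node
    return -1                                   # tier 4: defer to fair random fill

candidates = [3, 9]                  # two replicas of one logical expert
assert nearest(candidates, 1) == 3   # expert 3 lives on GPU 1 (3 // 2 == 1)
assert nearest(candidates, 2) == 3   # GPU 2 is on node 0, same node as expert 3
assert nearest(candidates, 6) == 9   # GPU 6 is on node 1, same node as expert 9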
@@ -459,11 +522,15 @@ class ModelConfigForExpertLocation:
 
 
 def compute_initial_expert_location_metadata(
-    server_args: ServerArgs, model_config: ModelConfig
+    server_args: ServerArgs,
+    model_config: ModelConfig,
+    moe_ep_rank: int,
 ) -> Optional[ExpertLocationMetadata]:
     data = server_args.init_expert_location
     if data == "trivial":
-        return ExpertLocationMetadata.init_trivial(server_args, model_config)
+        return ExpertLocationMetadata.init_trivial(
+            server_args, model_config, moe_ep_rank
+        )
 
     # TODO unify with the utils function
     if data.endswith(".pt"):
@@ -478,7 +545,10 @@
             "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
         )
         return ExpertLocationMetadata.init_by_mapping(
-            server_args, model_config, **data_dict
+            server_args,
+            model_config,
+            **data_dict,
+            moe_ep_rank=moe_ep_rank,
         )
     elif "logical_count" in data_dict:
         logger.info(

sglang/srt/function_call/function_call_parser.py

@@ -16,6 +16,7 @@ from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector
 from sglang.srt.function_call.gpt_oss_detector import GptOssDetector
 from sglang.srt.function_call.kimik2_detector import KimiK2Detector
 from sglang.srt.function_call.llama32_detector import Llama32Detector
+from sglang.srt.function_call.minimax_m2 import MinimaxM2Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
 from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
@@ -49,6 +50,7 @@ class FunctionCallParser:
         "qwen25": Qwen25Detector,
         "qwen3_coder": Qwen3CoderDetector,
         "step3": Step3Detector,
+        "minimax-m2": MinimaxM2Detector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
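
With the registry entry in place, the new detector is selected by name through the constructor shown in the context above; a hypothetical call, where tools stands in for real Tool definitions:

# Select the MiniMax M2 tool-call detector added in this release (sketch only;
# in a running server the name would typically arrive via the tool-call-parser
# server argument rather than direct construction).
parser = FunctionCallParser(tools=tools, tool_call_parser="minimax-m2")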