sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_chat.py CHANGED

@@ -535,6 +535,17 @@ class OpenAIServingChat(OpenAIServingBase):
                     choices=[choice_data],
                     model=request.model,
                 )
+
+                # Add usage stats if continuous_usage_stats is enabled
+                if (
+                    request.stream_options
+                    and request.stream_options.continuous_usage_stats
+                ):
+                    chunk.usage = UsageProcessor.calculate_token_usage(
+                        prompt_tokens=prompt_tokens.get(index, 0),
+                        completion_tokens=completion_tokens.get(index, 0),
+                    )
+
                 yield f"data: {chunk.model_dump_json()}\n\n"

                 # Handle tool calls
@@ -579,6 +590,17 @@ class OpenAIServingChat(OpenAIServingBase):
                     choices=[choice_data],
                     model=request.model,
                 )
+
+                # Add usage stats if continuous_usage_stats is enabled
+                if (
+                    request.stream_options
+                    and request.stream_options.continuous_usage_stats
+                ):
+                    chunk.usage = UsageProcessor.calculate_token_usage(
+                        prompt_tokens=prompt_tokens.get(index, 0),
+                        completion_tokens=completion_tokens.get(index, 0),
+                    )
+
                 yield f"data: {chunk.model_dump_json()}\n\n"

                 # Send finish_reason chunks for each index that completed
@@ -1056,6 +1078,16 @@ class OpenAIServingChat(OpenAIServingBase):
             choices=[choice_data],
             model=request.model,
         )
+
+        # Add usage stats if continuous_usage_stats is enabled
+        if request.stream_options and request.stream_options.continuous_usage_stats:
+            prompt_tokens = content["meta_info"].get("prompt_tokens", 0)
+            completion_tokens = content["meta_info"].get("completion_tokens", 0)
+            chunk.usage = UsageProcessor.calculate_token_usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+            )
+
         yield f"data: {chunk.model_dump_json()}\n\n"

         # Yield tool calls
@@ -1096,6 +1128,16 @@ class OpenAIServingChat(OpenAIServingBase):
             choices=[choice_data],
             model=request.model,
         )
+
+        # Add usage stats if continuous_usage_stats is enabled
+        if request.stream_options and request.stream_options.continuous_usage_stats:
+            prompt_tokens = content["meta_info"].get("prompt_tokens", 0)
+            completion_tokens = content["meta_info"].get("completion_tokens", 0)
+            chunk.usage = UsageProcessor.calculate_token_usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+            )
+
         yield f"data: {chunk.model_dump_json()}\n\n"

     def _check_for_unstreamed_tool_args(
sglang/srt/entrypoints/openai/serving_completions.py CHANGED

@@ -272,6 +272,16 @@ class OpenAIServingCompletion(OpenAIServingBase):
                 model=request.model,
             )

+            # Add usage stats if continuous_usage_stats is enabled
+            if (
+                request.stream_options
+                and request.stream_options.continuous_usage_stats
+            ):
+                chunk.usage = UsageProcessor.calculate_token_usage(
+                    prompt_tokens=prompt_tokens.get(index, 0),
+                    completion_tokens=completion_tokens.get(index, 0),
+                )
+
             yield f"data: {chunk.model_dump_json()}\n\n"

             if request.return_hidden_states and hidden_states:
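
Both streaming paths above (chat and completions) now attach a usage object to every SSE chunk when the client opts in through stream_options. A minimal client-side sketch of what that looks like; the endpoint, port, and model name are placeholders, and the exact stream_options field names follow the protocol change shipped alongside these handlers:

import json
import requests

# Placeholder endpoint/model for a locally running sglang server.
url = "http://localhost:30000/v1/chat/completions"
payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Say hi"}],
    "stream": True,
    # continuous_usage_stats is the new flag read by the handlers above;
    # include_usage is the standard OpenAI stream_options field.
    "stream_options": {"include_usage": True, "continuous_usage_stats": True},
}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: ") or line == b"data: [DONE]":
            continue
        chunk = json.loads(line[len(b"data: "):])
        # With the flag set, a usage object should now arrive on every chunk,
        # not only at the end of the stream.
        print(chunk.get("usage"))

When the flag is absent the handlers take the same path as before, so existing streaming clients are unaffected.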
sglang/srt/entrypoints/openai/serving_embedding.py CHANGED

@@ -126,6 +126,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
             **prompt_kwargs,
             rid=request.rid,
             priority=request.priority,
+            dimensions=request.dimensions,
         )

         return adapted_request, request
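
The embedding change is a single forwarded field: the OpenAI-style dimensions parameter is now copied onto the internal request instead of being dropped. An illustrative request follows (placeholder endpoint and model; whether the value is honored still depends on the served embedding model):

import requests

# Placeholder endpoint/model; "dimensions" follows the OpenAI embeddings API shape
# and is now forwarded as request.dimensions per the hunk above.
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "default", "input": "hello world", "dimensions": 256},
)
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # expect 256 if the model supports truncated embeddings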
sglang/srt/environ.py CHANGED
@@ -129,10 +129,13 @@ class Envs:
     SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
     SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
     SGLANG_TORCH_PROFILER_DIR = EnvStr("/tmp")
+    SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS = EnvInt(500)
+    SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE = EnvInt(64)

     # Scheduler: memory leak test
     SGLANG_TEST_RETRACT = EnvBool(False)
     SGLANG_TEST_RETRACT_INTERVAL = EnvInt(3)
+    SGLANG_TEST_RETRACT_NO_PREFILL_BS = EnvInt(2 ** 31)
     SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK = EnvBool(False)

     # Scheduler: new token ratio hyperparameters
@@ -180,6 +183,7 @@ class Envs:

     # Triton
     SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS = EnvBool(False)
+    SGLANG_USE_CUSTOM_TRITON_KERNEL_CACHE = EnvBool(False)

     # Torch Compile
     SGLANG_ENABLE_TORCH_COMPILE = EnvBool(False)
@@ -238,6 +242,9 @@ class Envs:
     SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
     SGLANG_RESIZE_RESAMPLE = EnvStr("")

+    # Release & Resume Memory
+    SGLANG_MEMORY_SAVER_CUDA_GRAPH = EnvBool(False)
+
     # Ktransformers
     SGLANG_KT_MOE_NUM_GPU_EXPERTS = EnvInt(None)
     SGLANG_KT_MOE_CPUINFER = EnvInt(None)
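
The new environment knobs can be set in the environment that launches the server. A hedged sketch (values are illustrative, and the accepted boolean spellings are whatever EnvBool parses):

import os

# Illustrative values only; set these before the sglang process reads its Envs,
# e.g. before `python -m sglang.launch_server ...`.

# Tracing: tuning for the OTLP span exporter (defaults above: 500 ms delay, batch size 64).
os.environ["SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS"] = "250"
os.environ["SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE"] = "128"

# Release & resume memory: opt in to the CUDA-graph-aware memory saver path.
# ("true" as the enabling spelling is an assumption; EnvBool decides what counts as true.)
os.environ["SGLANG_MEMORY_SAVER_CUDA_GRAPH"] = "true"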
sglang/srt/eplb/expert_distribution.py CHANGED

@@ -20,6 +20,7 @@ import time
 from abc import ABC
 from collections import deque
 from contextlib import contextmanager
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type

 import einops
@@ -27,6 +28,7 @@ import torch
 import torch.distributed

 from sglang.srt.environ import envs
+from sglang.srt.metrics.collector import ExpertDispatchCollector
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import Withable, is_npu
@@ -660,6 +662,10 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
         self.window_sizes = [10, 100, 1000]
         self._history = _DequeCollection(maxlens=self.window_sizes)
         self._rank = torch.distributed.get_rank()
+        self._expert_dispatch_collector = ExpertDispatchCollector(
+            self._expert_location_metadata.ep_size
+        )
+        self._collection_counter = 0

     def append(
         self,
@@ -691,6 +697,8 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
         )

         if self._rank == 0:
+            self._collect_metrics_if_needed(gpu_physical_count)
+
             utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
             utilization_rate = torch.mean(utilization_rate_tensor).item()
             self._history.append(utilization_rate)
@@ -706,6 +714,31 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
                 # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
             )

+    def _collect_metrics_if_needed(self, gpu_physical_count: torch.Tensor):
+        # sglang:eplb_gpu_physical_count metric is disabled if SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL <= 0
+        if (
+            envs.SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL > 0
+            and self._collection_counter % envs.SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL
+            == 0
+        ):
+            for layer_idx in range(self._expert_location_metadata.num_layers):
+                count_of_layer = (
+                    self._expert_dispatch_collector.eplb_gpu_physical_count.labels(
+                        layer=str(layer_idx)
+                    )
+                )
+                # Exclude the +Inf bucket.
+                assert (
+                    self._expert_location_metadata.ep_size
+                    == len(count_of_layer._buckets) - 1
+                ), f"{self._expert_location_metadata.ep_size=}, {len(count_of_layer._buckets)=}"
+                for gpu_rank in range(self._expert_location_metadata.ep_size):
+                    count = gpu_physical_count[layer_idx, gpu_rank]
+                    if count > 0:
+                        count_of_layer._sum.inc(count * gpu_rank)
+                        count_of_layer._buckets[gpu_rank].inc(count)
+        self._collection_counter += 1
+

 class _DequeCollection:
     def __init__(self, maxlens: List[int]):
@@ -868,7 +901,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):


 def _dump_to_file(name, data):
-    save_dir = envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get()
+    save_dir = Path(envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get())
     path_output = save_dir / name
     logger.info(f"Write expert distribution to {path_output}")
     if not save_dir.exists():
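
The _collect_metrics_if_needed method above repurposes a per-layer histogram as a vector of per-GPU counters: bucket i accumulates the dispatch count for GPU rank i, and the assert requires exactly ep_size finite buckets. A standalone sketch of that trick, assuming prometheus_client is the metrics backend (the real ExpertDispatchCollector lives in sglang.srt.metrics.collector and may be configured differently):

# Standalone sketch of the bucket-per-rank trick, assuming prometheus_client.
# Names and bucket layout are illustrative, not the actual ExpertDispatchCollector.
from prometheus_client import Histogram, generate_latest

EP_SIZE = 4
heatmap = Histogram(
    "eplb_gpu_physical_count",
    "Expert dispatch heatmap: bucket i accumulates the count routed to GPU rank i.",
    labelnames=["layer"],
    buckets=list(range(EP_SIZE)),  # EP_SIZE finite buckets; +Inf is appended automatically
)

def record_layer(layer_idx, per_rank_counts):
    child = heatmap.labels(layer=str(layer_idx))
    assert EP_SIZE == len(child._buckets) - 1  # exclude the +Inf bucket, as in the diff
    for gpu_rank, count in enumerate(per_rank_counts):
        if count > 0:
            child._sum.inc(count * gpu_rank)     # histogram "sum" becomes sum(rank * count)
            child._buckets[gpu_rank].inc(count)  # bucket index doubles as the GPU rank

record_layer(0, [5, 0, 3, 2])
# The exposed *_bucket series are cumulative (Prometheus `le` semantics), so a dashboard
# recovers the per-rank heatmap by differencing adjacent buckets.
print(generate_latest().decode())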
sglang/srt/eplb/expert_location.py CHANGED

@@ -85,7 +85,9 @@ class ExpertLocationMetadata:
     # -------------------------------- construction ------------------------------------

     @staticmethod
-    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
+    def init_trivial(
+        server_args: ServerArgs, model_config: ModelConfig, moe_ep_rank: int
+    ):
         """Trivial location - logical expert i corresponds to physical expert i"""
         common = ExpertLocationMetadata._init_common(server_args, model_config)

@@ -106,6 +108,7 @@ class ExpertLocationMetadata:
             server_args,
             model_config,
             physical_to_logical_map=physical_to_logical_map,
+            moe_ep_rank=moe_ep_rank,
         )

     @staticmethod
@@ -113,6 +116,7 @@ class ExpertLocationMetadata:
         server_args: ServerArgs,
         model_config: ModelConfig,
         physical_to_logical_map,
+        moe_ep_rank: int = None,
     ):
         if not isinstance(physical_to_logical_map, torch.Tensor):
             physical_to_logical_map = torch.tensor(physical_to_logical_map)
@@ -125,8 +129,11 @@ class ExpertLocationMetadata:

         model_config_for_expert_location = common["model_config_for_expert_location"]
         logical_to_all_physical_map = _compute_logical_to_all_physical_map(
-            physical_to_logical_map,
+            server_args=server_args,
+            physical_to_logical_map=physical_to_logical_map,
             num_logical_experts=model_config_for_expert_location.num_logical_experts,
+            ep_size=common["ep_size"],
+            moe_ep_rank=moe_ep_rank,
         )

         return ExpertLocationMetadata._init_raw(
@@ -233,7 +240,7 @@ class ExpertLocationMetadata:
             compute_logical_to_rank_dispatch_physical_map(
                 server_args=server_args,
                 logical_to_all_physical_map=logical_to_all_physical_map,
-                num_gpus=ep_size,
+                ep_size=ep_size,
                 num_physical_experts=num_physical_experts,
                 # TODO improve when we have real EP rank
                 ep_rank=torch.distributed.get_rank() % ep_size,
@@ -303,7 +310,11 @@ def set_global_expert_location_metadata(value):


 def _compute_logical_to_all_physical_map(
-    physical_to_logical_map: torch.Tensor, num_logical_experts: int
+    server_args: ServerArgs,
+    physical_to_logical_map: torch.Tensor,
+    num_logical_experts: int,
+    ep_size: int,
+    moe_ep_rank: int,
 ):
     # This is rarely called, so we use for loops for maximum clarity

@@ -312,6 +323,8 @@ def _compute_logical_to_all_physical_map(
     logical_to_all_physical_map = [
         [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
     ]
+
+    # Find out the candidate physical experts for each logical expert on each layer
     for layer_id in range(num_layers):
         for physical_expert_id in range(num_physical_experts):
             logical_expert_id = physical_to_logical_map[
@@ -321,6 +334,32 @@ def _compute_logical_to_all_physical_map(
                 physical_expert_id
             )

+    # Replace by the physical expert on local GPU or node if possible
+    if moe_ep_rank is not None:
+        num_gpus_per_node = server_args.ep_size // server_args.nnodes
+        num_local_gpu_physical_experts = num_physical_experts // ep_size
+        num_local_node_physical_experts = (
+            num_local_gpu_physical_experts * num_gpus_per_node
+        )
+        for layer_id in range(num_layers):
+            for logical_expert_id in range(num_logical_experts):
+                # Try to find the nearest physical expert
+                nearest_expert = _find_nearest_expert(
+                    candidate_physical_expert_ids=logical_to_all_physical_map[layer_id][
+                        logical_expert_id
+                    ],
+                    num_local_gpu_physical_experts=num_local_gpu_physical_experts,
+                    moe_ep_rank=moe_ep_rank,
+                    num_gpus_per_node=num_gpus_per_node,
+                    num_local_node_physical_experts=num_local_node_physical_experts,
+                )
+
+                # Replace by the nearest physical expert
+                if nearest_expert != -1:
+                    logical_to_all_physical_map[layer_id][logical_expert_id] = [
+                        nearest_expert
+                    ]
+
     logical_to_all_physical_map = _pad_nested_array(
         logical_to_all_physical_map, pad_value=-1
     )
@@ -343,21 +382,21 @@ def _pad_nested_array(arr, pad_value):
 def compute_logical_to_rank_dispatch_physical_map(
     server_args: ServerArgs,
     logical_to_all_physical_map: torch.Tensor,
-    num_gpus: int,
+    ep_size: int,
     num_physical_experts: int,
     ep_rank: int,
     seed: int = 42,
 ):
     r = random.Random(seed)

-    num_local_gpu_physical_experts = num_physical_experts // num_gpus
+    num_local_gpu_physical_experts = num_physical_experts // ep_size
     num_gpus_per_node = server_args.ep_size // server_args.nnodes
     num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
     num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
     dtype = logical_to_all_physical_map.dtype

     logical_to_rank_dispatch_physical_map = torch.full(
-        size=(num_gpus, num_layers, num_logical_experts),
+        size=(ep_size, num_layers, num_logical_experts),
         fill_value=-1,
         dtype=dtype,
     )
@@ -371,33 +410,17 @@ def compute_logical_to_rank_dispatch_physical_map(
                 :, layer_id, logical_expert_id
             ]

-            for gpu_id in range(num_gpus):
-                same_gpu_physical_expert_ids = [
-                    physical_expert_id
-                    for physical_expert_id in candidate_physical_expert_ids
-                    if _compute_gpu_id_of_physical_expert(
-                        physical_expert_id, num_local_gpu_physical_experts
-                    )
-                    == gpu_id
-                ]
-                if len(same_gpu_physical_expert_ids) > 0:
-                    # 1. Prefer same-GPU experts
-                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
-                else:
-                    # 2. Otherwise, prefer same-node experts
-                    node_id = gpu_id // num_gpus_per_node
-                    same_node_physical_expert_ids = [
-                        physical_expert_id
-                        for physical_expert_id in candidate_physical_expert_ids
-                        if _compute_node_id_of_physical_expert(
-                            physical_expert_id, num_local_node_physical_experts
-                        )
-                        == node_id
-                    ]
-                    if len(same_node_physical_expert_ids) > 0:
-                        output_partial[gpu_id] = same_node_physical_expert_ids[0]
+            for moe_ep_rank in range(ep_size):
+                # Fill with the nearest physical expert
+                output_partial[moe_ep_rank] = _find_nearest_expert(
+                    candidate_physical_expert_ids=candidate_physical_expert_ids,
+                    num_local_gpu_physical_experts=num_local_gpu_physical_experts,
+                    moe_ep_rank=moe_ep_rank,
+                    num_gpus_per_node=num_gpus_per_node,
+                    num_local_node_physical_experts=num_local_node_physical_experts,
+                )

-            # 3. Fill remaining slots with fair random choices
+            # Fill remaining slots with fair random choices
             num_remain = torch.sum(output_partial == -1).item()
             output_partial[output_partial == -1] = torch.tensor(
                 _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
@@ -434,6 +457,46 @@ def _compute_node_id_of_physical_expert(
     return physical_expert_id // num_local_host_physical_experts


+def _find_nearest_expert(
+    candidate_physical_expert_ids: List[int],
+    num_local_gpu_physical_experts: int,
+    moe_ep_rank: int,
+    num_gpus_per_node: int,
+    num_local_node_physical_experts: int,
+) -> int:
+    # 1. If only one candidate, return it directly
+    if len(candidate_physical_expert_ids) == 1:
+        return candidate_physical_expert_ids[0]
+
+    # 2. Prefer same-GPU experts
+    same_gpu_physical_expert_ids = [
+        physical_expert_id
+        for physical_expert_id in candidate_physical_expert_ids
+        if _compute_gpu_id_of_physical_expert(
+            physical_expert_id, num_local_gpu_physical_experts
+        )
+        == moe_ep_rank
+    ]
+    if len(same_gpu_physical_expert_ids) > 0:
+        return same_gpu_physical_expert_ids[0]
+
+    # 3. Otherwise, prefer same-node experts
+    node_rank = moe_ep_rank // num_gpus_per_node
+    same_node_physical_expert_ids = [
+        physical_expert_id
+        for physical_expert_id in candidate_physical_expert_ids
+        if _compute_node_id_of_physical_expert(
+            physical_expert_id, num_local_node_physical_experts
+        )
+        == node_rank
+    ]
+    if len(same_node_physical_expert_ids) > 0:
+        return same_node_physical_expert_ids[0]
+
+    # 4. At last, leave it as -1 to indicate not found.
+    return -1
+
+
 def _fair_choices(arr: List, k: int, r: random.Random) -> List:
     quotient, remainder = divmod(k, len(arr))
     ans = arr * quotient + r.sample(arr, k=remainder)
@@ -459,11 +522,15 @@ class ModelConfigForExpertLocation:


 def compute_initial_expert_location_metadata(
-    server_args: ServerArgs, model_config: ModelConfig
+    server_args: ServerArgs,
+    model_config: ModelConfig,
+    moe_ep_rank: int,
 ) -> Optional[ExpertLocationMetadata]:
     data = server_args.init_expert_location
     if data == "trivial":
-        return ExpertLocationMetadata.init_trivial(server_args, model_config)
+        return ExpertLocationMetadata.init_trivial(
+            server_args, model_config, moe_ep_rank
+        )

     # TODO unify with the utils function
     if data.endswith(".pt"):
@@ -478,7 +545,10 @@ def compute_initial_expert_location_metadata(
             "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
         )
         return ExpertLocationMetadata.init_by_mapping(
-            server_args, model_config, **data_dict
+            server_args,
+            model_config,
+            **data_dict,
+            moe_ep_rank=moe_ep_rank,
         )
     elif "logical_count" in data_dict:
         logger.info(
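
To make the same-GPU-then-same-node preference in _find_nearest_expert concrete, here is a self-contained toy version with made-up sizes (it mirrors the logic above but does not import sglang):

# Toy version of the locality preference with made-up sizes:
# 16 physical experts, ep_size=8 GPUs, 2 nodes -> 2 experts/GPU, 4 GPUs/node, 8 experts/node.
from typing import List

NUM_LOCAL_GPU_EXPERTS = 2
NUM_GPUS_PER_NODE = 4
NUM_LOCAL_NODE_EXPERTS = NUM_LOCAL_GPU_EXPERTS * NUM_GPUS_PER_NODE

def gpu_of(expert_id: int) -> int:
    return expert_id // NUM_LOCAL_GPU_EXPERTS

def node_of(expert_id: int) -> int:
    return expert_id // NUM_LOCAL_NODE_EXPERTS

def find_nearest_expert(candidates: List[int], moe_ep_rank: int) -> int:
    if len(candidates) == 1:
        return candidates[0]
    same_gpu = [e for e in candidates if gpu_of(e) == moe_ep_rank]
    if same_gpu:
        return same_gpu[0]
    node_rank = moe_ep_rank // NUM_GPUS_PER_NODE
    same_node = [e for e in candidates if node_of(e) == node_rank]
    if same_node:
        return same_node[0]
    return -1  # no local replica; the caller falls back to fair random choices

# A logical expert replicated on physical experts 3 (GPU 1, node 0) and 11 (GPU 5, node 1):
print(find_nearest_expert([3, 11], moe_ep_rank=5))  # 11 -> same-GPU replica wins
print(find_nearest_expert([3, 11], moe_ep_rank=6))  # 11 -> same-node replica wins
print(find_nearest_expert([3], moe_ep_rank=0))      # 3  -> single candidate returned as-is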
sglang/srt/grpc/compile_proto.py CHANGED

@@ -18,6 +18,9 @@ Options:
 ### Install Dependencies
 pip install "grpcio==1.75.1" "grpcio-tools==1.75.1"

+Please make sure to use the same version of grpcio and grpcio-tools specified in pyproject.toml
+otherwise update the versions specified in pyproject.toml
+
 ### Run Script
 cd python/sglang/srt/grpc
 python compile_proto.py
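
An illustrative way to check the locally installed versions against the pyproject.toml pins before running the script:

# Illustrative check: confirm the installed grpcio / grpcio-tools match the pins
# in pyproject.toml before running compile_proto.py.
import importlib.metadata as md

for pkg in ("grpcio", "grpcio-tools"):
    print(pkg, md.version(pkg))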