sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sglang/compile_deep_gemm.py +8 -1
  2. sglang/global_config.py +5 -1
  3. sglang/srt/conversation.py +0 -112
  4. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  5. sglang/srt/disaggregation/prefill.py +1 -0
  6. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  7. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  8. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  9. sglang/srt/distributed/parallel_state.py +11 -0
  10. sglang/srt/entrypoints/engine.py +4 -2
  11. sglang/srt/entrypoints/http_server.py +35 -15
  12. sglang/srt/eplb/expert_distribution.py +4 -2
  13. sglang/srt/hf_transformers_utils.py +25 -10
  14. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  15. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  16. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  17. sglang/srt/layers/attention/vision.py +27 -10
  18. sglang/srt/layers/communicator.py +14 -4
  19. sglang/srt/layers/linear.py +7 -1
  20. sglang/srt/layers/logits_processor.py +9 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +11 -35
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
  24. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  25. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  26. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  27. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  28. sglang/srt/layers/moe/utils.py +43 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  31. sglang/srt/layers/quantization/fp8.py +5 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  33. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  34. sglang/srt/lora/lora_registry.py +7 -0
  35. sglang/srt/managers/cache_controller.py +8 -4
  36. sglang/srt/managers/data_parallel_controller.py +52 -2
  37. sglang/srt/managers/io_struct.py +6 -1
  38. sglang/srt/managers/schedule_batch.py +3 -2
  39. sglang/srt/managers/schedule_policy.py +3 -1
  40. sglang/srt/managers/scheduler.py +144 -6
  41. sglang/srt/managers/template_manager.py +25 -22
  42. sglang/srt/managers/tokenizer_manager.py +114 -62
  43. sglang/srt/managers/utils.py +45 -1
  44. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  45. sglang/srt/mem_cache/hicache_storage.py +13 -21
  46. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  47. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  48. sglang/srt/model_executor/cuda_graph_runner.py +17 -3
  49. sglang/srt/model_executor/forward_batch_info.py +13 -3
  50. sglang/srt/model_executor/model_runner.py +5 -0
  51. sglang/srt/models/deepseek_v2.py +23 -17
  52. sglang/srt/models/glm4_moe.py +82 -19
  53. sglang/srt/models/grok.py +3 -3
  54. sglang/srt/models/llama4.py +13 -2
  55. sglang/srt/models/mixtral.py +3 -3
  56. sglang/srt/models/mllama4.py +428 -19
  57. sglang/srt/models/qwen2_moe.py +1 -4
  58. sglang/srt/models/qwen3_moe.py +7 -8
  59. sglang/srt/models/step3_vl.py +1 -1
  60. sglang/srt/multimodal/processors/base_processor.py +4 -3
  61. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  62. sglang/srt/operations_strategy.py +1 -1
  63. sglang/srt/server_args.py +80 -20
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  65. sglang/srt/two_batch_overlap.py +6 -4
  66. sglang/srt/utils.py +3 -24
  67. sglang/srt/weight_sync/utils.py +1 -1
  68. sglang/test/runners.py +2 -2
  69. sglang/test/test_utils.py +3 -3
  70. sglang/version.py +1 -1
  71. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  72. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
  73. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  74. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  75. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  76. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  77. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  78. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  79. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  80. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0

--- a/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
+++ b/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
@@ -146,34 +146,3 @@ def triton_kernel_fused_experts(
     )
 
     return intermediate_cache3
-
-
-def triton_kernel_moe_forward_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    use_fp8_w8a8: bool = False,
-    per_channel_quant: bool = False,
-    global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[list[int]] = None,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="forward_cuda_triton",
-    op_func=triton_kernel_moe_forward,
-    mutates_args=[],
-    fake_impl=triton_kernel_moe_forward_fake,
-)

--- /dev/null
+++ b/sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -0,0 +1,23 @@
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+    DeepEPConfig,
+    DeepEPDispatcher,
+    DeepEPLLOutput,
+    DeepEPNormalOutput,
+)
+
+__all__ = [
+    "BaseDispatcher",
+    "BaseDispatcherConfig",
+    "DispatchOutput",
+    "DispatchOutputFormat",
+    "DeepEPConfig",
+    "DeepEPDispatcher",
+    "DeepEPNormalOutput",
+    "DeepEPLLOutput",
+]
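
With this new package __init__, the dispatcher classes are importable from the package root; a minimal sketch of the import path after this change (illustrative only, not part of the diff):

    # Package-level import enabled by the new __init__.py above;
    # previously these names lived in sglang.srt.layers.moe.ep_moe.token_dispatcher.
    from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig, DeepEPDispatcher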

--- a/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
+++ b/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
@@ -2,11 +2,22 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from enum import Enum, auto
-from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 import torch
 
 
+class MoEA2ABackend(Enum):
+    none = "none"
+    deepep = "deepep"
+
+    def is_none(self):
+        return self == MoEA2ABackend.none
+
+    def is_deepep(self):
+        return self == MoEA2ABackend.deepep
+
+
 class DispatchOutputFormat(Enum):
     standard = auto()
     deepep_normal = auto()

--- a/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -1,5 +1,3 @@
-# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
-
 from __future__ import annotations
 
 import logging
@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import (
-    DeepEPMode,
-    get_bool_env_var,
-    get_int_env_var,
-    is_hip,
-    load_json_config,
-)
+from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
 
 try:
     from deep_ep import Buffer, Config
@@ -150,9 +143,9 @@ class DeepEPBuffer:
             num_rdma_bytes,
         )
 
-        if deepep_mode == DeepEPMode.normal:
+        if deepep_mode == DeepEPMode.NORMAL:
            num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
-        elif deepep_mode in [DeepEPMode.low_latency, DeepEPMode.auto]:
+        elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
            num_qps_per_rank = num_experts // group.size()
         else:
            raise NotImplementedError
@@ -161,7 +154,7 @@ class DeepEPBuffer:
             device="cuda"
         ).multi_processor_count
         if (
-            (deepep_mode != DeepEPMode.low_latency)
+            (deepep_mode != DeepEPMode.LOW_LATENCY)
             and not global_server_args_dict["enable_two_batch_overlap"]
             and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
         ):
@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
         async_finish: bool = False,
         return_recv_hook: bool = False,
     ):
@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
         resolved_deepep_mode = self.deepep_mode.resolve(
             forward_batch.is_extend_in_batch
         )
-        if resolved_deepep_mode == DeepEPMode.normal:
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

--- /dev/null
+++ b/sglang/srt/layers/moe/utils.py
@@ -0,0 +1,43 @@
+from enum import Enum
+
+
+class MoeA2ABackend(Enum):
+
+    STANDARD = ("standard", "none")
+    DEEPEP = "deepep"
+
+    @classmethod
+    def _missing_(cls, value):
+        if value is None:
+            return cls.STANDARD
+        for member in cls:
+            if value in member.value:
+                return member
+        raise ValueError(f"No {cls.__name__} member for value {value}")
+
+    def is_deepep(self):
+        return self == MoeA2ABackend.DEEPEP
+
+    def is_standard(self):
+        return self == MoeA2ABackend.STANDARD
+
+
+class DeepEPMode(Enum):
+    NORMAL = "normal"
+    LOW_LATENCY = "low_latency"
+    AUTO = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
+
+    def resolve(self, is_extend_in_batch: bool):
+        if self != DeepEPMode.AUTO:
+            return self
+
+        if is_extend_in_batch:
+            return DeepEPMode.NORMAL
+        else:
+            return DeepEPMode.LOW_LATENCY
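
A short illustration (not part of the diff) of how the two new enums behave, based only on the code above:

    from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

    # "none" and None are not exact member values, so Enum falls through to
    # _missing_, which maps both of them (and "standard") to STANDARD.
    assert MoeA2ABackend("none") is MoeA2ABackend.STANDARD
    assert MoeA2ABackend(None) is MoeA2ABackend.STANDARD
    assert MoeA2ABackend("deepep").is_deepep()

    # AUTO resolves per batch: extend/prefill batches use NORMAL,
    # decode batches use LOW_LATENCY.
    assert DeepEPMode.AUTO.resolve(is_extend_in_batch=True) is DeepEPMode.NORMAL
    assert DeepEPMode.AUTO.resolve(is_extend_in_batch=False) is DeepEPMode.LOW_LATENCY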

--- a/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
@@ -189,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+    def process_weights_after_loading(self, layer: FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -246,7 +247,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.local_num_experts):
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(

--- a/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
+++ b/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
@@ -148,7 +148,7 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
         "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
         "N": n,
         "K": k,
-        "NUM_GROUPS": 1,
+        "NUM_GROUPS": num_groups,
         "BLOCK_M": block_m,
         "BLOCK_N": block_n,
         "BLOCK_K": block_k,

--- a/sglang/srt/layers/quantization/fp8.py
+++ b/sglang/srt/layers/quantization/fp8.py
@@ -1039,7 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
            topk_weights, topk_ids, _ = topk_output
-           return cutlass_fused_experts_fp8(
+           output = cutlass_fused_experts_fp8(
                x,
                layer.w13_weight.transpose(1, 2),
                layer.w2_weight.transpose(1, 2),
@@ -1062,6 +1062,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                self.problem_sizes2,
                use_fp8_blockscale=True,
            )
+           # TODO: Fuse into select_experts
+           if routed_scaling_factor is not None:
+               output *= routed_scaling_factor
+           return output
        # Expert fusion with FP8 quantization
        return fused_experts(
            x,

--- a/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/sglang/srt/layers/quantization/fp8_kernel.py
@@ -354,10 +354,6 @@ def sglang_per_token_group_quant_fp8(
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"
 
-    if scale_ue8m0:
-        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
-        assert x.shape[-1] % (group_size * 4) == 0
-
    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
    x_s = create_per_token_group_quant_fp8_output_scale(
        x_shape=x.shape,

--- a/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/sglang/srt/layers/vocab_parallel_embedding.py
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
@@ -464,7 +468,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+        sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)

--- a/sglang/srt/lora/lora_registry.py
+++ b/sglang/srt/lora/lora_registry.py
@@ -186,3 +186,10 @@ class LoRARegistry:
         self._registry[lora_ref.lora_name] = lora_ref
         self._counters[lora_ref.lora_id] = ConcurrentCounter()
         return lora_ref
+
+    @property
+    def num_registered_loras(self) -> int:
+        """
+        Returns the total number of LoRA adapters currently registered.
+        """
+        return len(self._registry)

--- a/sglang/srt/managers/cache_controller.py
+++ b/sglang/srt/managers/cache_controller.py
@@ -236,6 +236,7 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
@@ -573,6 +574,9 @@ class HiCacheController:
         self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
         operation.increment(len(operation.hash_value) * self.page_size)
 
+    def is_mooncake_backend(self):
+        return self.storage_backend_type == "mooncake"
+
     def prefetch_io_aux_func(self):
         """
         Auxiliary function conducting IO operations for prefetching.
@@ -580,7 +584,7 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if isinstance(self.storage_backend, MooncakeStore):
+                if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
                 else:
                     self.generic_page_transfer(operation)
@@ -615,14 +619,14 @@ class HiCacheController:
                 )
 
                 # todo, more unified interface
-                if not isinstance(self.storage_backend, MooncakeStore):
+                if not self.is_mooncake_backend():
                     if not self.storage_backend.exists(last_hash):
                         break
                 hash_value.append(last_hash)
                 storage_hit_count += self.page_size
                 remaining_tokens -= self.page_size
 
-        if isinstance(self.storage_backend, MooncakeStore):
+        if self.is_mooncake_backend():
             # deferring to batch exists for mooncake store
             exist_result = self.storage_backend.exists(hash_value)
             storage_hit_count = (
@@ -744,7 +748,7 @@ class HiCacheController:
                 remaining_tokens -= self.page_size
         operation.hash_value = hash_value
 
-        if isinstance(self.storage_backend, MooncakeStore):
+        if self.is_mooncake_backend():
             self.mooncake_page_backup(operation)
         else:
             self.generic_page_backup(operation)

--- a/sglang/srt/managers/data_parallel_controller.py
+++ b/sglang/srt/managers/data_parallel_controller.py
@@ -16,9 +16,13 @@
 import logging
 import multiprocessing as mp
 import signal
+import struct
+import sys
 import threading
 import time
 from enum import Enum, auto
+from multiprocessing import shared_memory
+from typing import Dict, List
 
 import psutil
 import setproctitle
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):
 
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
+    MINIMUM_TOKENS = auto()
 
     @classmethod
     def from_str(cls, method: str):
@@ -58,7 +64,16 @@ class DataParallelController:
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""
 
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        dp_balance_meta: DPBalanceMeta,
+    ) -> None:
+        # for dp balance
+        self.global_balance_id = 0
+        self.balance_meta = dp_balance_meta
+
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
@@ -79,6 +94,7 @@ class DataParallelController:
         dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
@@ -234,6 +250,7 @@ class DataParallelController:
                     pp_rank,
                     dp_rank,
                     writer,
+                    self.balance_meta,
                 ),
             )
             with memory_saver_adapter.configure_subprocess():
@@ -269,6 +286,33 @@ class DataParallelController:
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()
 
+    def minimum_tokens_scheduler(self, req):
+        # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
+        # We use it to to control the number of onfly tokens (requests dispatched to workers but not yet received).
+        def get_next_global_balance_id() -> int:
+            INT32_MAX = 2147483647
+            current_id = self.global_balance_id
+            self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
+            return current_id
+
+        req.dp_balance_id = get_next_global_balance_id()
+        with self.balance_meta.mutex:
+            # 1. local_tokens represents the tokens currently inferring on the worker,
+            # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
+            onfly_info = self.balance_meta.get_shared_onfly()
+            local_tokens = self.balance_meta.get_shared_local_tokens()
+            total_tokens = [
+                local_token + sum(onfly_dict.values())
+                for local_token, onfly_dict in zip(local_tokens, onfly_info)
+            ]
+            target_worker = total_tokens.index(min(total_tokens))
+            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
+            # 2. write the new onfly info to the shm
+            self.balance_meta.set_shared_onfly_info(onfly_info)
+
+        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        self.workers[target_worker].send_pyobj(req)
+
     def event_loop(self):
         while True:
             while True:
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
     setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
+    balance_meta = DPBalanceMeta(server_args.dp_size)
 
     try:
-        controller = DataParallelController(server_args, port_args)
+        controller = DataParallelController(
+            server_args, port_args, dp_balance_meta=balance_meta
+        )
         pipe_writer.send(
            {
                "status": "ready",
@@ -323,3 +370,6 @@
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
+    finally:
+        # we need to destruct mp.Manager() in balance_meta
+        balance_meta.destructor()
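
The new MINIMUM_TOKENS policy picks the worker whose resident tokens plus not-yet-received ("onfly") tokens sum to the smallest total. A standalone sketch of that selection with made-up numbers (not part of the diff):

    # Hypothetical state for three DP workers: tokens already being processed,
    # plus tokens of requests dispatched but not yet received by each scheduler.
    local_tokens = [1200, 800, 950]
    onfly_info = [{101: 300}, {102: 50, 103: 400}, {}]

    total_tokens = [
        local + sum(onfly.values())
        for local, onfly in zip(local_tokens, onfly_info)
    ]
    # total_tokens == [1500, 1250, 950], so worker 2 receives the next request.
    target_worker = total_tokens.index(min(total_tokens))
    assert target_worker == 2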

--- a/sglang/srt/managers/io_struct.py
+++ b/sglang/srt/managers/io_struct.py
@@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
+    # For dp balance
+    dp_balance_id: int = -1
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For dp balance
+    dp_balance_id: int = -1
 
 
 @dataclass
@@ -1097,7 +1102,7 @@ class UnloadLoRAAdapterReqInput:
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None
 
 
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult

--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -85,9 +86,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_dp_attention",
     "enable_two_batch_overlap",
     "enable_dp_lm_head",
-    "enable_deepep_moe",
+    "moe_a2a_backend",
     "deepep_mode",
-    "enable_ep_moe",
     "enable_flashinfer_cutlass_moe",
     "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
@@ -108,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
     "enable_multimodal",
+    "enable_symm_mem",
 ]
 
 # Put some global args for easy access

--- a/sglang/srt/managers/schedule_policy.py
+++ b/sglang/srt/managers/schedule_policy.py
@@ -455,7 +455,9 @@ class PrefillAdder:
         if not self.is_hybrid:
             # Skip this logic for swa. The SWA has different memory management, and
             # this mechanism is underestimating the memory usage.
-            cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+            cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens(
+                req.extend_input_len
+            )
             tokens_freed = 0
             for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
                 # tokens_left gives a reservative calculation as the last token is not stored
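
The schedule_policy.py change above reserves whole pages for the incoming extend tokens rather than subtracting the raw origin input length. A hedged illustration of that rounding (the page size is a made-up value, and ceil_paged_tokens is assumed to round a token count up to a multiple of the page size):

    # Assumed behavior of ceil_paged_tokens: round up to whole KV-cache pages.
    page_size = 16
    extend_input_len = 50
    reserved = -(-extend_input_len // page_size) * page_size  # ceil(50 / 16) * 16
    assert reserved == 64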