sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sglang/compile_deep_gemm.py +8 -1
  2. sglang/global_config.py +5 -1
  3. sglang/srt/conversation.py +0 -112
  4. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  5. sglang/srt/disaggregation/prefill.py +1 -0
  6. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  7. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  8. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  9. sglang/srt/distributed/parallel_state.py +11 -0
  10. sglang/srt/entrypoints/engine.py +4 -2
  11. sglang/srt/entrypoints/http_server.py +35 -15
  12. sglang/srt/eplb/expert_distribution.py +4 -2
  13. sglang/srt/hf_transformers_utils.py +25 -10
  14. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  15. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  16. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  17. sglang/srt/layers/attention/vision.py +27 -10
  18. sglang/srt/layers/communicator.py +14 -4
  19. sglang/srt/layers/linear.py +7 -1
  20. sglang/srt/layers/logits_processor.py +9 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +11 -35
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
  24. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  25. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  26. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  27. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  28. sglang/srt/layers/moe/utils.py +43 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  31. sglang/srt/layers/quantization/fp8.py +5 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  33. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  34. sglang/srt/lora/lora_registry.py +7 -0
  35. sglang/srt/managers/cache_controller.py +8 -4
  36. sglang/srt/managers/data_parallel_controller.py +52 -2
  37. sglang/srt/managers/io_struct.py +6 -1
  38. sglang/srt/managers/schedule_batch.py +3 -2
  39. sglang/srt/managers/schedule_policy.py +3 -1
  40. sglang/srt/managers/scheduler.py +144 -6
  41. sglang/srt/managers/template_manager.py +25 -22
  42. sglang/srt/managers/tokenizer_manager.py +114 -62
  43. sglang/srt/managers/utils.py +45 -1
  44. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  45. sglang/srt/mem_cache/hicache_storage.py +13 -21
  46. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  47. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  48. sglang/srt/model_executor/cuda_graph_runner.py +17 -3
  49. sglang/srt/model_executor/forward_batch_info.py +13 -3
  50. sglang/srt/model_executor/model_runner.py +5 -0
  51. sglang/srt/models/deepseek_v2.py +23 -17
  52. sglang/srt/models/glm4_moe.py +82 -19
  53. sglang/srt/models/grok.py +3 -3
  54. sglang/srt/models/llama4.py +13 -2
  55. sglang/srt/models/mixtral.py +3 -3
  56. sglang/srt/models/mllama4.py +428 -19
  57. sglang/srt/models/qwen2_moe.py +1 -4
  58. sglang/srt/models/qwen3_moe.py +7 -8
  59. sglang/srt/models/step3_vl.py +1 -1
  60. sglang/srt/multimodal/processors/base_processor.py +4 -3
  61. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  62. sglang/srt/operations_strategy.py +1 -1
  63. sglang/srt/server_args.py +80 -20
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  65. sglang/srt/two_batch_overlap.py +6 -4
  66. sglang/srt/utils.py +3 -24
  67. sglang/srt/weight_sync/utils.py +1 -1
  68. sglang/test/runners.py +2 -2
  69. sglang/test/test_utils.py +3 -3
  70. sglang/version.py +1 -1
  71. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  72. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
  73. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  74. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  75. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  76. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  77. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  78. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  79. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  80. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py CHANGED
@@ -24,6 +24,7 @@ import torch
 from torch import nn

 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -51,7 +52,6 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
-from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -72,7 +72,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty

 Qwen3MoeConfig = None

@@ -113,15 +113,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
             **(
-                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
-                if global_server_args_dict["enable_deepep_moe"]
+                dict(deepep_mode=global_server_args_dict["deepep_mode"])
+                if global_server_args_dict["moe_a2a_backend"].is_deepep()
                 else {}
             ),
             # Additional args for FusedMoE
             **(
                 dict(
                     enable_flashinfer_cutlass_moe=True,
-                    enable_ep_moe=global_server_args_dict["enable_ep_moe"],
                 )
                 if global_server_args_dict["enable_flashinfer_cutlass_moe"]
                 else {}
@@ -136,9 +135,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-        if global_server_args_dict["enable_deepep_moe"]:
+        if global_server_args_dict["moe_a2a_backend"].is_deepep():
             # TODO: we will support tp < ep in the future
-            self.ep_size = get_tensor_model_parallel_world_size()
+            self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
             )
@@ -148,7 +147,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:

-        if not global_server_args_dict["enable_deepep_moe"]:
+        if not global_server_args_dict["moe_a2a_backend"].is_deepep():
             return self.forward_normal(hidden_states)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
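
Note on the pattern above: the old boolean `global_server_args_dict["enable_deepep_moe"]` checks become `global_server_args_dict["moe_a2a_backend"].is_deepep()`, so the dict entry is now an enum-like object rather than the raw `--moe-a2a-backend` string. A minimal sketch of such a wrapper, assuming it lives in the new sglang/srt/layers/moe/utils.py; the class name `MoeA2ABackend` and its members are inferred from usage here, not confirmed by this diff:

from enum import Enum
from typing import Optional


class MoeA2ABackend(Enum):
    # Hypothetical members; only the "deepep" value is visible in this diff.
    NONE = "none"
    DEEPEP = "deepep"

    @classmethod
    def from_server_arg(cls, value: Optional[str]) -> "MoeA2ABackend":
        # Maps ServerArgs.moe_a2a_backend (None or "deepep") to the enum.
        return cls.DEEPEP if value == "deepep" else cls.NONE

    def is_deepep(self) -> bool:
        return self is MoeA2ABackend.DEEPEP
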
sglang/srt/models/step3_vl.py CHANGED
@@ -146,7 +146,7 @@ class Step3TextMoEMLP(nn.Module):
             prefix=add_prefix("gate", prefix),
         )

-        if global_server_args_dict["enable_deepep_moe"]:
+        if global_server_args_dict["moe_a2a_backend"].is_deepep():
             raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
sglang/srt/multimodal/processors/base_processor.py CHANGED
@@ -12,7 +12,6 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast

-from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger

@@ -218,8 +217,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audio"] = audios

         processor = self._processor
-        if hasattr(processor, "image_processor") and isinstance(
-            processor.image_processor, BaseImageProcessorFast
+        if (
+            hasattr(processor, "image_processor")
+            and isinstance(processor.image_processor, BaseImageProcessorFast)
+            and not self.server_args.disable_fast_image_processor
         ):
             kwargs["device"] = "cuda"
         result = processor.__call__(
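
The fast-image-processor branch gains a third condition: the new `disable_fast_image_processor` server argument can force the slow CPU path even when a `BaseImageProcessorFast` is available. A standalone sketch of the decision, reusing the names from the diff (`server_args` is any object exposing that attribute):

from transformers import BaseImageProcessorFast


def use_fast_image_processor(processor, server_args) -> bool:
    # The GPU ("cuda") kwargs are set only when all three gates pass.
    return (
        hasattr(processor, "image_processor")
        and isinstance(processor.image_processor, BaseImageProcessorFast)
        and not server_args.disable_fast_image_processor
    )
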
sglang/srt/multimodal/processors/gemma3n.py CHANGED
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================

-import re
 from typing import Dict, List, Optional, Union

 from sglang.srt.managers.multimodal_processor import (
@@ -38,14 +37,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<image_soft_token>",
             image_token_id=hf_config.image_token_id,
-            image_token_regex=re.compile(
-                r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
-            ),
             audio_token="<audio_soft_token>",
             audio_token_id=hf_config.audio_token_id,
-            audio_token_regex=re.compile(
-                r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
-            ),
         ).build(_processor)

     async def process_mm_data_async(
sglang/srt/operations_strategy.py CHANGED
@@ -4,7 +4,7 @@ from typing import List, Optional
 import torch

 from sglang.srt import operations
-from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig
+from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.operations import Operation

sglang/srt/server_args.py CHANGED
@@ -149,6 +149,7 @@ class ServerArgs:
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[Union[set[str], List[str]]] = None
     lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
+    max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"

@@ -172,12 +173,11 @@

     # Expert parallelism
     ep_size: int = 1
-    enable_ep_moe: bool = False
-    enable_deepep_moe: bool = False
+    moe_a2a_backend: Optional[Literal["deepep"]] = None
     enable_flashinfer_cutlass_moe: bool = False
     enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
-    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
     init_expert_location: str = "trivial"
@@ -219,6 +219,7 @@
     enable_profile_cuda_graph: bool = False
     enable_cudagraph_gc: bool = False
     enable_nccl_nvls: bool = False
+    enable_symm_mem: bool = False
     enable_tokenizer_batch_encode: bool = False
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
@@ -272,7 +273,27 @@
     enable_pdmux: bool = False
     sm_group_num: int = 3

+    # Deprecated arguments
+    enable_ep_moe: bool = False
+    enable_deepep_moe: bool = False
+
     def __post_init__(self):
+
+        # Check deprecated arguments
+        def print_deprecated_warning(message: str):
+            logger.warning(f"\033[33m{message}\033[0m")
+
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            print_deprecated_warning(
+                "NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
+            )
+        if self.enable_deepep_moe:
+            self.moe_a2a_backend = "deepep"
+            print_deprecated_warning(
+                "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
+            )
+
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
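
The shim above keeps the removed flags working: `--enable-ep-moe` now simply forces `ep_size = tp_size`, and `--enable-deepep-moe` maps to `--moe-a2a-backend deepep`. A standalone sketch of that translation (a simplified stand-in, not the real ServerArgs class):

import logging
from dataclasses import dataclass
from typing import Literal, Optional

logger = logging.getLogger(__name__)


@dataclass
class MoeArgsSketch:
    tp_size: int = 1
    ep_size: int = 1
    moe_a2a_backend: Optional[Literal["deepep"]] = None
    enable_ep_moe: bool = False        # deprecated
    enable_deepep_moe: bool = False    # deprecated

    def __post_init__(self):
        # Mirrors the deprecation shim in the hunk above.
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.warning("--enable-ep-moe is deprecated; set --ep-size equal to --tp-size.")
        if self.enable_deepep_moe:
            self.moe_a2a_backend = "deepep"
            logger.warning("--enable-deepep-moe is deprecated; use --moe-a2a-backend deepep.")


args = MoeArgsSketch(tp_size=8, enable_ep_moe=True, enable_deepep_moe=True)
assert args.ep_size == 8 and args.moe_a2a_backend == "deepep"
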
@@ -455,14 +476,13 @@
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"
-            if self.enable_ep_moe:
-                self.ep_size = self.tp_size
-                logger.warning(
-                    f"Flashinfer cutlass MoE and EP MoE are enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-                )
+            assert self.ep_size in [
+                1,
+                self.tp_size,
+            ], "The expert parallel size must be 1 or the same as the tensor parallel size"

         # DeepEP MoE
-        if self.enable_deepep_moe:
+        if self.moe_a2a_backend == "deepep":
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
@@ -486,7 +506,7 @@
             )

         if self.enable_eplb:
-            assert self.enable_ep_moe or self.enable_deepep_moe
+            assert self.ep_size > 1 or self.moe_a2a_backend is not None

         if self.enable_expert_distribution_metrics and (
             self.expert_distribution_recorder_mode is None
@@ -1151,6 +1171,7 @@
             choices=[
                 "round_robin",
                 "shortest_queue",
+                "minimum_tokens",
             ],
         )

@@ -1218,6 +1239,12 @@
             default=8,
             help="Maximum number of adapters for a running batch, include base-only request.",
         )
+        parser.add_argument(
+            "--max-loaded-loras",
+            type=int,
+            default=ServerArgs.max_loaded_loras,
+            help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
+        )
         parser.add_argument(
             "--lora-backend",
             type=str,
@@ -1354,30 +1381,27 @@
             help="The expert parallelism size.",
         )
         parser.add_argument(
-            "--enable-ep-moe",
-            action="store_true",
-            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+            "--moe-a2a-backend",
+            type=str,
+            choices=["deepep"],
+            default=ServerArgs.moe_a2a_backend,
+            help="Choose the backend for MoE A2A.",
         )
         parser.add_argument(
             "--enable-flashinfer-cutlass-moe",
             action="store_true",
-            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
+            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
         )
         parser.add_argument(
             "--enable-flashinfer-trtllm-moe",
             action="store_true",
-            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
         )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
             help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
         )
-        parser.add_argument(
-            "--enable-deepep-moe",
-            action="store_true",
-            help="Enabling DeepEP MoE implementation for EP MoE.",
-        )
         parser.add_argument(
             "--deepep-mode",
             type=str,
@@ -1584,6 +1608,11 @@
             action="store_true",
             help="Enable NCCL NVLS for prefill heavy requests when available.",
         )
+        parser.add_argument(
+            "--enable-symm-mem",
+            action="store_true",
+            help="Enable NCCL symmetric memory for fast collectives.",
+        )
         parser.add_argument(
             "--enable-tokenizer-batch-encode",
             action="store_true",
@@ -1839,6 +1868,18 @@
             help="Disable mmap while loading weight using safetensors.",
         )

+        # Deprecated arguments
+        parser.add_argument(
+            "--enable-ep-moe",
+            action="store_true",
+            help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+        )
+        parser.add_argument(
+            "--enable-deepep-moe",
+            action="store_true",
+            help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
@@ -1895,6 +1936,12 @@
         if "Llama4" in model_arch:
             assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"

+        if "Gemma2ForCausalLM" in model_arch:
+            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
+            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
+            logger.warning("Disable hybrid SWA memory for Gemma2ForCausalLM.")
+            self.disable_hybrid_swa_memory = True
+
         # Check LoRA
         self.check_lora_server_args()

@@ -1969,6 +2016,19 @@
                 self.max_lora_rank and self.lora_target_modules
             ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

+        # Validate max_loaded_loras
+        if self.max_loaded_loras is not None:
+            assert self.max_loaded_loras >= self.max_loras_per_batch, (
+                "max_loaded_loras should be greater than or equal to max_loras_per_batch. "
+                f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
+            )
+            assert (
+                not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
+            ), (
+                "The number of LoRA paths should not exceed max_loaded_loras. "
+                f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
+            )
+
     def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
         larger_tp = max(decode_tp, prefill_tp)
         smaller_tp = min(decode_tp, prefill_tp)
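
Taken together with the argparse entry above, the new asserts pin down a simple invariant: `--max-loras-per-batch` (adapters active in one batch) must never exceed `--max-loaded-loras` (adapters resident in CPU memory), and any initial `--lora-paths` must also fit. A standalone restatement of just that check, not the sglang API:

def check_lora_limits(max_loaded_loras, max_loras_per_batch, lora_paths):
    # Mirrors the two asserts added in the hunk above.
    if max_loaded_loras is not None:
        assert max_loaded_loras >= max_loras_per_batch
        assert not lora_paths or len(lora_paths) <= max_loaded_loras


check_lora_limits(max_loaded_loras=16, max_loras_per_batch=8, lora_paths=["a", "b"])  # ok
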
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py CHANGED
@@ -142,6 +142,22 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.global_num_tokens_for_logprob_gpu = None
         self.gathered_buffer = None

+        if hasattr(
+            self.model_runner.model_config.hf_config, "draft_vocab_size"
+        ):  # llama_eagle
+            vocab_size = self.model_runner.model_config.hf_config.draft_vocab_size
+        elif hasattr(
+            self.model_runner.model_config.hf_config, "hot_vocab_size"
+        ):  # llama_eagle3
+            vocab_size = self.model_runner.model_config.hf_config.hot_vocab_size
+        else:
+            vocab_size = self.model_runner.model_config.vocab_size
+
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_bs, vocab_size),
+            dtype=torch.float,
+        )
+
         # Capture
         try:
             with model_capture_mode():
@@ -189,6 +205,7 @@
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
+        next_token_logits_buffer = self.next_token_logits_buffer[:bs]

         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -238,6 +255,7 @@
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
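
The new `next_token_logits_buffer` is sized by a vocab fallback chain: a reduced `draft_vocab_size` for llama_eagle, `hot_vocab_size` for llama_eagle3, otherwise the full vocabulary. A standalone sketch of just that selection (`hf_config` is any config object):

def logits_buffer_vocab_size(hf_config, full_vocab_size: int) -> int:
    if hasattr(hf_config, "draft_vocab_size"):  # llama_eagle
        return hf_config.draft_vocab_size
    if hasattr(hf_config, "hot_vocab_size"):  # llama_eagle3
        return hf_config.hot_vocab_size
    return full_vocab_size
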
sglang/srt/two_batch_overlap.py CHANGED
@@ -13,17 +13,18 @@ from sglang.srt.layers.communicator import (
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
-from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
+from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
+from sglang.srt.utils import BumpAllocator, get_bool_env_var

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
+    from sglang.srt.layers.moe.token_dispatcher import DispatchOutput

 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")

@@ -310,7 +311,7 @@ class TboDPAttentionPreparer:
                     and not local_batch.forward_mode.is_target_verify()
                 )
                 and enable_deepep_moe
-                and (resolved_deepep_mode == DeepEPMode.low_latency)
+                and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
             )
         else:
             self.local_tbo_split_seq_index = 0
@@ -563,6 +564,7 @@
                 mm_inputs=None,
                 top_logprobs_nums=None,
                 token_ids_logprobs=None,
+                next_token_logits_buffer=None,
             )
         )

sglang/srt/utils.py CHANGED
@@ -44,7 +44,6 @@ import traceback
 import warnings
 from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
-from enum import Enum
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
@@ -93,6 +92,7 @@ logger = logging.getLogger(__name__)
 show_time_cost = False
 time_infos = {}

+
 HIP_FP8_E4M3_FNUZ_MAX = 224.0


@@ -2205,27 +2205,6 @@ def flatten_nested_list(nested_list):
     return [nested_list]


-class DeepEPMode(Enum):
-    normal = "normal"
-    low_latency = "low_latency"
-    auto = "auto"
-
-    def enable_normal(self):
-        return self in [DeepEPMode.normal, DeepEPMode.auto]
-
-    def enable_low_latency(self):
-        return self in [DeepEPMode.low_latency, DeepEPMode.auto]
-
-    def resolve(self, is_extend_in_batch: bool):
-        if self != DeepEPMode.auto:
-            return self
-
-        if is_extend_in_batch:
-            return DeepEPMode.normal
-        else:
-            return DeepEPMode.low_latency
-
-
 def is_non_idle_and_non_empty(forward_mode, hidden_states):
     return (
         (forward_mode is not None)
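
`DeepEPMode` is relocated rather than dropped: the two_batch_overlap.py hunk above now imports it from sglang.srt.layers.moe.utils and compares against `DeepEPMode.LOW_LATENCY`. A hedged sketch of the relocated enum; only the LOW_LATENCY spelling is confirmed by this diff, the other members and the method are carried over from the removed code:

from enum import Enum


class DeepEPMode(Enum):
    NORMAL = "normal"  # spelling inferred
    LOW_LATENCY = "low_latency"  # confirmed by two_batch_overlap.py
    AUTO = "auto"  # spelling inferred

    def resolve(self, is_extend_in_batch: bool) -> "DeepEPMode":
        # Same resolution rule as the removed implementation above.
        if self != DeepEPMode.AUTO:
            return self
        return DeepEPMode.NORMAL if is_extend_in_batch else DeepEPMode.LOW_LATENCY
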
@@ -2414,7 +2393,7 @@ def require_mlp_tp_gather(server_args):
         return True
     elif not server_args.enable_dp_lm_head:
         return True
-    elif not server_args.enable_deepep_moe:
+    elif server_args.moe_a2a_backend is None:
         return True
     else:
         return (
@@ -2430,7 +2409,7 @@
     Check if the input of attention is scattered.
     """
     assert server_args.moe_dense_tp_size in [1, None]
-    if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
+    if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
         if server_args.enable_dp_attention:
             return server_args.dp_size < server_args.tp_size
         else:
sglang/srt/weight_sync/utils.py CHANGED
@@ -45,7 +45,7 @@ async def update_weights(
             (
                 name,
                 MultiprocessingSerializer.serialize(
-                    _preprocess_tensor_for_update_weights(tensor)
+                    _preprocess_tensor_for_update_weights(tensor.detach())
                 ),
             )
             for name, tensor in params_batch
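
Why detach before serializing (a minimal illustration, not sglang code): a tensor produced by an autograd op drags its graph along via `grad_fn`; `detach()` shares the same storage but drops the graph, so only the raw weights cross the process boundary:

import torch

t = torch.ones(2, requires_grad=True) * 2.0  # non-leaf tensor with a grad_fn
assert t.grad_fn is not None
d = t.detach()
assert d.grad_fn is None and d.data_ptr() == t.data_ptr()  # same storage, no graph
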
sglang/test/runners.py CHANGED
@@ -499,7 +499,6 @@ class SRTRunner:
         chunked_prefill_size: Optional[int] = None,
         dp_size: int = 1,
         tokenizer_path: Optional[str] = None,
-        enable_ep_moe: bool = False,
         mem_fraction_static: float = 0.65,
         trust_remote_code: bool = False,
         speculative_draft_model_path: Optional[str] = None,
@@ -515,6 +514,7 @@
         max_lora_rank: Optional[int] = None,
         lora_target_modules: Optional[List[str]] = None,
         enable_lora: Optional[bool] = None,
+        max_loaded_loras: Optional[int] = None,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -550,7 +550,6 @@
             enable_dp_attention=enable_dp_attention,
             dp_size=dp_size,
             tokenizer_path=tokenizer_path,
-            enable_ep_moe=enable_ep_moe,
             disable_overlap_schedule=disable_overlap_schedule,
             cuda_graph_max_bs=cuda_graph_max_bs,
             disable_custom_all_reduce=disable_custom_all_reduce,
@@ -558,6 +557,7 @@
             max_lora_rank=max_lora_rank,
             lora_target_modules=lora_target_modules,
             enable_lora=enable_lora,
+            max_loaded_loras=max_loaded_loras,
             **spec_kwargs,
         )

sglang/test/test_utils.py CHANGED
@@ -27,9 +27,6 @@ import torch.nn.functional as F

 from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
-from sglang.lang.backend.openai import OpenAI
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.lang.interpreter import ProgramState
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device,
@@ -358,6 +355,9 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):


 def select_sglang_backend(args: argparse.Namespace):
+    from sglang.lang.backend.openai import OpenAI
+    from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+
     if args.backend.startswith("srt"):
         if args.backend == "srt-no-parallel":
             global_config.enable_parallel_encoding = False
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.10.post1"
+__version__ = "0.4.10.post2"
{sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sglang
-Version: 0.4.10.post1
+Version: 0.4.10.post2
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
         Version 2.0, January 2004
@@ -250,7 +250,7 @@ Requires-Dist: transformers==4.54.1; extra == "runtime-common"
 Requires-Dist: timm==1.0.16; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
-Requires-Dist: xgrammar==0.1.21; extra == "runtime-common"
+Requires-Dist: xgrammar==0.1.22; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
 Requires-Dist: sgl-kernel==0.2.8; extra == "srt"
@@ -301,6 +301,7 @@ Requires-Dist: matplotlib; extra == "test"
 Requires-Dist: pandas; extra == "test"
 Requires-Dist: peft; extra == "test"
 Requires-Dist: sentence_transformers; extra == "test"
+Requires-Dist: pytest; extra == "test"
 Provides-Extra: all
 Requires-Dist: sglang[srt]; extra == "all"
 Requires-Dist: sglang[openai]; extra == "all"