sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/platforms/interface.py ADDED
@@ -0,0 +1,371 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/v0.8.2/vllm/platforms/interface.py
+
+ import enum
+ import platform
+ import random
+ from platform import uname
+ from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple, Union
+
+ if TYPE_CHECKING:
+     from sglang.srt.server_args import ServerArgs
+     from sglang.srt.configs.model_config import ModelConfig
+
+ import logging
+
+ import numpy as np
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ def in_wsl() -> bool:
+     # Reference: https://github.com/microsoft/WSL/issues/4071
+     return "microsoft" in " ".join(uname()).lower()
+
+
+ class PlatformEnum(enum.Enum):
+     CUDA = enum.auto()
+     ROCM = enum.auto()
+     HPU = enum.auto()
+     XPU = enum.auto()
+     CPU = enum.auto()
+     OOT = enum.auto()
+     UNSPECIFIED = enum.auto()
+
+
+ class CpuArchEnum(enum.Enum):
+     X86 = enum.auto()
+     ARM = enum.auto()
+     POWERPC = enum.auto()
+     OTHER = enum.auto()
+     UNKNOWN = enum.auto()
+
+
+ class DeviceCapability(NamedTuple):
+     major: int
+     minor: int
+
+     def as_version_str(self) -> str:
+         return f"{self.major}.{self.minor}"
+
+     def to_int(self) -> int:
+         """
+         Express device capability as an integer ``<major><minor>``.
+
+         It is assumed that the minor version is always a single digit.
+         """
+         assert 0 <= self.minor < 10
+         return self.major * 10 + self.minor
+
+
+ class Platform:
+     _enum: PlatformEnum
+
+     # Real device name of current platform.
+     device_name: str
+
+     # For specifying torch device for cuda alike platform's capability.
+     device_type: str
+
+     # The torch.distributed backend on current platform
+     torch_distributed_backend: str
+
+     # The torch.compile backend for compiling simple and
+     # standalone functions. The default value is "inductor" to keep
+     # the same behavior as PyTorch.
+     torch_compile_backend: str = "inductor"
+
+     supported_quantization: list[str] = []
+
+     supported_speculative_algorithm: list[str] = []
+
+     # Use first element as default dtype
+     supported_dtype: list[str] = []
+
+     # Use first element as default backend
+     supported_attntion_backend: list[str] = []
+
+     # Use first element as default backend
+     supported_sampling_backend: list[str] = []
+
+     # Use first element as default backend
+     supported_lora_backend: list[str] = []
+
+     def is_cuda(self) -> bool:
+         return self._enum == PlatformEnum.CUDA
+
+     def is_rocm(self) -> bool:
+         return self._enum == PlatformEnum.ROCM
+
+     def is_hpu(self) -> bool:
+         return self._enum == PlatformEnum.HPU
+
+     def is_xpu(self) -> bool:
+         return self._enum == PlatformEnum.XPU
+
+     def is_cpu(self) -> bool:
+         return self._enum == PlatformEnum.CPU
+
+     def is_out_of_tree(self) -> bool:
+         return self._enum == PlatformEnum.OOT
+
+     def is_cuda_alike(self) -> bool:
+         """Stateless version of :func:`torch.cuda.is_available`."""
+         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+
+     @classmethod
+     def get_device_capability(
+         cls,
+         device_id: int = 0,
+     ) -> Optional[DeviceCapability]:
+         """Stateless version of :func:`torch.cuda.get_device_capability`."""
+         return None
+
+     @classmethod
+     def has_device_capability(
+         cls,
+         capability: Union[Tuple[int, int], int],
+         device_id: int = 0,
+     ) -> bool:
+         """
+         Test whether this platform is compatible with a device capability.
+
+         The ``capability`` argument can either be:
+
+         - A tuple ``(major, minor)``.
+         - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
+         """
+         current_capability = cls.get_device_capability(device_id=device_id)
+         if current_capability is None:
+             return False
+
+         if isinstance(capability, tuple):
+             return current_capability >= capability
+
+         return current_capability.to_int() >= capability
+
+     @classmethod
+     def get_device_module(cls) -> Any:
+         """Get `torch.device_module` like `torch.cuda` of current platform."""
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_sku(cls, device_id: int = 0) -> str:
+         """Get the SKU name of a device."""
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_uuid(cls, device_id: int = 0) -> str:
+         """Get the uuid of a device, e.g. the PCI bus ID."""
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_core_count(cls, device_id: int = 0) -> str:
+         """Get the core count of a device, e.g. SMs of CUDA, CUs of ROCM."""
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_count(cls) -> int:
+         """Get device count on current platform"""
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_total_memory(cls, device_id: int = 0, distributed=False) -> float:
+         """
+         Get total memory for device_type:device_id device in gigabytes.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_available_memory(
+         cls, device_id: int = 0, distributed=False, empty_cache=True
+     ) -> float:
+         """
+         Get available memory for device_type:device_id device in gigabytes.
+         When distributed is True, the available memory is the minimum available memory of all GPUs.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def supports_overlap_scheduler(cls) -> bool:
+         """
+         Check if the current platform supports overlap scheduler
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def seed_everything(cls, seed: Optional[int] = None) -> None:
+         """
+         Set the seed of each random module.
+         `torch.manual_seed` will set seed on all devices.
+
+         Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+         """
+         if seed is not None:
+             random.seed(seed)
+             np.random.seed(seed)
+             torch.manual_seed(seed)
+
+     @classmethod
+     def check_and_update_server_args(cls, server_args: ServerArgs) -> None:
+         """
+         Check and update the server arguments for the current platform.
+
+         It can raise an exception if the configuration is not compatible with
+         the current platform, or it can update the configuration to make it
+         compatible with the current platform.
+
+         The config is passed by reference, so it can be modified in place.
+         """
+         pass
+
+     @classmethod
+     def check_and_update_model_dtype(cls, model_config: ModelConfig, dtype: str) -> str:
+         """
+         Check and update the model's dtype for the current platform.
+         """
+         if cls.supported_dtype and dtype not in cls.supported_dtype:
+             logger.warning(
+                 f"dtype {dtype} is currently not supported in "
+                 f"{cls.device_name}. use {cls.supported_dtype[0]} instead"
+             )
+             return cls.supported_dtype[0]
+         return dtype
+
+     @classmethod
+     def check_and_update_attntion_backend(
+         cls, model_config: ModelConfig, backend: str
+     ) -> str:
+         """
+         Check and update the attntion backend for the current platform.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def check_and_update_sampling_backend(cls, backend: str) -> str:
+         """
+         Check and update the sampling backend for the current platform.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def check_and_update_lora_backend(cls, backend: str) -> str:
+         """
+         Check and update the lora backend for the current platform.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def verify_model_arch(cls, model_arch: str) -> None:
+         """
+         Verify whether the current platform supports the specified model
+         architecture.
+
+         - This will raise an Error or Warning based on the model support on
+           the current platform.
+         - By default all models are considered supported.
+         """
+         pass
+
+     @classmethod
+     def verify_quantization(cls, quant: str) -> None:
+         """
+         Verify whether the quantization is supported by the current platform.
+         """
+         if cls.supported_quantization and quant not in cls.supported_quantization:
+             raise ValueError(
+                 f"{quant} quantization is currently not supported in "
+                 f"{cls.device_name}."
+             )
+
+     @classmethod
+     def verify_speculative_algorithm(cls, algo: str) -> None:
+         """
+         Verify whether the speculative algorithm is supported by the current platform.
+         """
+         if (
+             cls.supported_speculative_algorithm
+             and algo not in cls.supported_speculative_algorithm
+         ):
+             raise ValueError(
+                 f"speculative algorithm {algo} is currently not supported in "
+                 f"{cls.device_name}."
+             )
+
+     @classmethod
+     def get_cpu_architecture(cls) -> CpuArchEnum:
+         """
+         Determine the CPU architecture of the current system.
+         Returns CpuArchEnum indicating the architecture type.
+         """
+         machine = platform.machine().lower()
+
+         if machine in ("x86_64", "amd64", "i386", "i686"):
+             return CpuArchEnum.X86
+         elif machine.startswith("arm") or machine.startswith("aarch"):
+             return CpuArchEnum.ARM
+         elif machine.startswith("ppc"):
+             return CpuArchEnum.POWERPC
+
+         return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN
+
+     @classmethod
+     def is_pin_memory_available(cls) -> bool:
+         """Checks whether pin memory is available on the current platform."""
+         if in_wsl():
+             # Pinning memory in WSL is not supported.
+             # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
+             logger.warning(
+                 "Using 'pin_memory=False' as WSL is detected. "
+                 "This may slow down the performance."
+             )
+             return False
+         return True
+
+     @classmethod
+     def get_device_communicator_cls(cls) -> str:
+         """
+         Get device specific communicator class for distributed communication.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def supports_fp8(cls) -> bool:
+         return False
+
+     @classmethod
+     def fp8_dtype(cls) -> torch.dtype:
+         """
+         Returns the preferred FP8 type on the current platform.
+         """
+         return torch.float8_e4m3fn
+
+     @classmethod
+     def fp8_min_max(cls) -> Tuple[float, float]:
+         """
+         Returns the preferred FP8 max value on the current platform.
+         """
+         fp8_max = torch.finfo(cls.fp8_dtype()).max
+         return (-fp8_max, fp8_max)
+
+     @classmethod
+     def is_triton_avaliable(cls) -> bool:
+         raise NotImplementedError
+
+     @classmethod
+     def init_environments(cls) -> None:
+         """
+         Init environments on current platform.
+
+         - Init platform specific env vars.
+         - Init platform specific patches.
+         """
+         pass
+
+
+ class UnspecifiedPlatform(Platform):
+     _enum = PlatformEnum.UNSPECIFIED
+     device_type = ""
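
Note (editor's illustration): the new sglang/srt/platforms/interface.py above defines an abstract Platform interface adapted from vLLM. Concrete platforms are expected to subclass Platform, set _enum plus the supported_* lists, and override the stateless classmethods. The sketch below is not code shipped in this wheel; the CudaPlatform name and its method bodies are assumptions about how a backend could plug into the interface.

from typing import Optional

import torch

from sglang.srt.platforms.interface import DeviceCapability, Platform, PlatformEnum


class CudaPlatform(Platform):  # hypothetical subclass, for illustration only
    _enum = PlatformEnum.CUDA
    device_name = "NVIDIA GPU"
    device_type = "cuda"
    torch_distributed_backend = "nccl"

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]:
        # Delegate to torch; e.g. DeviceCapability(major=9, minor=0) on an H100.
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def supports_fp8(cls) -> bool:
        # FP8 kernels generally require SM 8.9 (Ada) or newer.
        return cls.has_device_capability(89)


# has_device_capability() accepts either a (major, minor) tuple or an <major><minor> int.
if torch.cuda.is_available() and CudaPlatform.has_device_capability((8, 0)):
    print("Ampere or newer:", CudaPlatform.get_device_capability().as_version_str())
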
sglang/srt/server_args.py CHANGED
@@ -15,15 +15,17 @@

  import argparse
  import dataclasses
+ import json
  import logging
  import os
  import random
  import tempfile
- from typing import List, Optional
+ from typing import List, Literal, Optional

  from sglang.srt.hf_transformers_utils import check_gguf_file
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.utils import (
+     configure_ipv6,
      get_amdgpu_memory_capacity,
      get_device,
      get_hpu_memory_capacity,
@@ -52,7 +54,7 @@ class ServerArgs:
      dtype: str = "auto"
      kv_cache_dtype: str = "auto"
      quantization: Optional[str] = None
-     quantization_param_path: nullable_str = None
+     quantization_param_path: Optional[str] = None
      context_length: Optional[int] = None
      device: Optional[str] = None
      served_model_name: Optional[str] = None
@@ -126,21 +128,21 @@ class ServerArgs:
      # Kernel backend
      attention_backend: Optional[str] = None
      sampling_backend: Optional[str] = None
-     grammar_backend: Optional[str] = "xgrammar"
+     grammar_backend: Optional[str] = None

      # Speculative decoding
      speculative_algorithm: Optional[str] = None
      speculative_draft_model_path: Optional[str] = None
-     speculative_num_steps: int = 5
-     speculative_eagle_topk: int = 4
-     speculative_num_draft_tokens: int = 8
+     speculative_num_steps: Optional[int] = None
+     speculative_eagle_topk: Optional[int] = None
+     speculative_num_draft_tokens: Optional[int] = None
      speculative_accept_threshold_single: float = 1.0
      speculative_accept_threshold_acc: float = 1.0
      speculative_token_map: Optional[str] = None

      # Double Sparsity
      enable_double_sparsity: bool = False
-     ds_channel_config_path: str = None
+     ds_channel_config_path: Optional[str] = None
      ds_heavy_channel_num: int = 32
      ds_heavy_token_num: int = 256
      ds_heavy_channel_type: str = "qk"
@@ -159,6 +161,7 @@ class ServerArgs:
      enable_dp_attention: bool = False
      enable_ep_moe: bool = False
      enable_deepep_moe: bool = False
+     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
      enable_torch_compile: bool = False
      torch_compile_max_bs: int = 32
      cuda_graph_max_bs: Optional[int] = None
@@ -173,13 +176,15 @@ class ServerArgs:
      enable_memory_saver: bool = False
      allow_auto_truncate: bool = False
      enable_custom_logit_processor: bool = False
-     tool_call_parser: str = None
+     tool_call_parser: Optional[str] = None
      enable_hierarchical_cache: bool = False
      hicache_ratio: float = 2.0
-     enable_flashinfer_mla: bool = False
+     enable_flashinfer_mla: bool = False  # TODO: remove this argument
      enable_flashmla: bool = False
      flashinfer_mla_disable_ragged: bool = False
      warmups: Optional[str] = None
+     n_share_experts_fusion: int = 0
+     disable_shared_experts_fusion: bool = False

      # Debug tensor dumps
      debug_tensor_dump_output_folder: Optional[str] = None
@@ -191,6 +196,13 @@ class ServerArgs:
      disaggregation_bootstrap_port: int = 8998

      def __post_init__(self):
+         # Expert parallelism
+         if self.enable_ep_moe:
+             self.ep_size = self.tp_size
+             logger.info(
+                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+             )
+
          # Set missing default values
          if self.tokenizer_path is None:
              self.tokenizer_path = self.model_path
@@ -214,6 +226,9 @@ class ServerArgs:
              # GPU memory is not known yet or no GPU is available.
              gpu_mem = None

+         if is_hip():
+             self.disable_shared_experts_fusion = True
+
          # Set mem fraction static, which depends on the tensor parallelism size
          if self.mem_fraction_static is None:
              if self.tp_size >= 16:
@@ -252,15 +267,11 @@ class ServerArgs:
              else:
                  self.cuda_graph_max_bs = 160

-         # Choose kernel backends
+         # Set kernel backends for hpu device
          if self.device == "hpu":
              self.attention_backend = "torch_native"
              self.sampling_backend = "pytorch"

-         if self.attention_backend is None:
-             self.attention_backend = (
-                 "flashinfer" if is_flashinfer_available() else "triton"
-             )
          if self.sampling_backend is None:
              self.sampling_backend = (
                  "flashinfer" if is_flashinfer_available() else "pytorch"
@@ -272,6 +283,10 @@ class ServerArgs:
              )
              self.disable_cuda_graph = True

+         # Choose grammar backend
+         if self.grammar_backend is None:
+             self.grammar_backend = "xgrammar"
+
          # Expert parallelism
          if self.enable_ep_moe:
              self.ep_size = self.tp_size
@@ -290,12 +305,21 @@ class ServerArgs:
              logger.warning(
                  f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
              )
-         # DeepEP MoE
-         if self.enable_deepep_moe:
-             self.ep_size = self.dp_size
-             logger.info(
-                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size[{self.dp_size}]."
-             )
+
+         self.enable_sp_layernorm = False
+         # DeepEP MoE
+         if self.enable_deepep_moe:
+             if self.deepep_mode == "auto":
+                 assert (
+                     not self.enable_dp_attention
+                 ), "DeepEP MoE `auto` mode is not supported with DP Attention."
+             self.ep_size = self.tp_size
+             self.enable_sp_layernorm = (
+                 self.dp_size < self.tp_size if self.enable_dp_attention else True
+             )
+             logger.info(
+                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+             )

          # Speculative Decoding
          if self.speculative_algorithm == "NEXTN":
@@ -307,12 +331,29 @@ class ServerArgs:
              or self.speculative_algorithm == "EAGLE3"
          ):
              if self.max_running_requests is None:
-                 self.max_running_requests = 32
+                 self.max_running_requests = 48
              self.disable_overlap_schedule = True
              logger.info(
                  "Overlap scheduler is disabled because of using "
                  "eagle speculative decoding."
              )
+
+             # Auto choose parameters
+             if self.speculative_num_steps is None:
+                 assert (
+                     self.speculative_eagle_topk is None
+                     and self.speculative_num_draft_tokens is None
+                 )
+                 (
+                     self.speculative_num_steps,
+                     self.speculative_eagle_topk,
+                     self.speculative_num_draft_tokens,
+                 ) = auto_choose_speculative_params(self)
+
+             if self.page_size > 1 and self.speculative_eagle_topk > 1:
+                 self.speculative_eagle_topk = 1
+                 logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
+
          # The token generated from the verify step is counted.
          # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
          # assert self.speculative_num_steps < self.speculative_num_draft_tokens
@@ -456,6 +497,7 @@ class ServerArgs:
                  "modelopt",
                  "w8a8_int8",
                  "w8a8_fp8",
+                 "moe_wna16",
              ],
              help="The quantization method.",
          )
@@ -789,14 +831,14 @@ class ServerArgs:
          parser.add_argument(
              "--grammar-backend",
              type=str,
-             choices=["xgrammar", "outlines", "llguidance"],
+             choices=["xgrammar", "outlines", "llguidance", "none"],
              default=ServerArgs.grammar_backend,
              help="Choose the backend for grammar-guided decoding.",
          )
          parser.add_argument(
              "--enable-flashinfer-mla",
              action="store_true",
-             help="Enable FlashInfer MLA optimization",
+             help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
          )
          parser.add_argument(
              "--enable-flashmla",
@@ -1054,6 +1096,25 @@ class ServerArgs:
              action="store_true",
              help="Enabling DeepEP MoE implementation for EP MoE.",
          )
+         parser.add_argument(
+             "--deepep-mode",
+             type=str,
+             choices=["normal", "low_latency", "auto"],
+             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
+         )
+
+         parser.add_argument(
+             "--n-share-experts-fusion",
+             type=int,
+             default=0,
+             help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
+             "we use tp_size by default.",
+         )
+         parser.add_argument(
+             "--disable-shared-experts-fusion",
+             action="store_true",
+             help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+         )

          # Server warmups
          parser.add_argument(
@@ -1200,8 +1261,12 @@ class PortArgs:
              # DP attention. Use TCP + port to handle both single-node and multi-node.
              if server_args.nnodes == 1 and server_args.dist_init_addr is None:
                  dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
+             elif server_args.dist_init_addr.startswith("["):  # ipv6 address
+                 port_num, host = configure_ipv6(server_args.dist_init_addr)
+                 dist_init_addr = (host, str(port_num))
              else:
                  dist_init_addr = server_args.dist_init_addr.split(":")
+
              assert (
                  len(dist_init_addr) == 2
              ), "please provide --dist-init-addr as host:port of head node"
@@ -1210,10 +1275,10 @@ class PortArgs:
              port_base = int(dist_init_port) + 1
              if dp_rank is None:
                  scheduler_input_port = (
-                     port_base + 2
+                     port_base + 3
                  )  # TokenizerManager to DataParallelController
              else:
-                 scheduler_input_port = port_base + 2 + 1 + dp_rank
+                 scheduler_input_port = port_base + 3 + 1 + dp_rank

              return PortArgs(
                  tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
@@ -1243,3 +1308,33 @@ class DeprecatedAction(argparse.Action):

      def __call__(self, parser, namespace, values, option_string=None):
          raise ValueError(self.help)
+
+
+ def auto_choose_speculative_params(self: ServerArgs):
+     """
+     Automatically choose the parameters for speculative decoding.
+
+     You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
+     """
+     if self.decrypted_config_file:
+         config_path = self.decrypted_config_file
+     else:
+         config_path = os.path.join(self.model_path, "config.json")
+     if not os.path.exists(config_path):
+         raise ValueError(f"{config_path} is not found.")
+
+     config = json.load(open(config_path))
+
+     arch = config.get("architectures", ["Unknown"])[0]
+
+     if arch in ["LlamaForCausalLM"]:
+         # The default value for llama
+         return (5, 4, 8)
+     elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+         # The default value for deepseek
+         return (5, 4, 8)
+     elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
+         return (5, 4, 8)
+     else:
+         # The default value for all other models
+         return (5, 4, 8)
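
Note (editor's illustration): with the changes above, the three EAGLE tuning knobs default to None and ServerArgs.__post_init__ resolves them from the target model's config.json via auto_choose_speculative_params() (currently (5, 4, 8) for every recognized architecture), while max_running_requests now defaults to 48 for EAGLE. A hedged launch sketch follows; the model paths are placeholders and not values taken from this diff.

import subprocess

# The speculative parameters may now be omitted: __post_init__ fills in
# (speculative_num_steps, speculative_eagle_topk, speculative_num_draft_tokens)
# from the model's local config.json. With --page-size > 1, speculative_eagle_topk
# is forced back to 1.
subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "/models/Llama-3.1-8B-Instruct",  # placeholder local path
        "--speculative-algorithm", "EAGLE",
        "--speculative-draft-model-path", "/models/EAGLE-LLaMA3-8B",  # placeholder
        # No --speculative-num-steps / --speculative-eagle-topk /
        # --speculative-num-draft-tokens: the (5, 4, 8) defaults are auto-chosen.
    ],
    check=True,
)
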
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -214,10 +214,10 @@ class EAGLEDraftCudaGraphRunner:
          forward_batch.positions = self.positions[:num_tokens]

          # Special handle for seq_len_cpu used when flashinfer mla is used
-         if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+         if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
              self.seq_lens_cpu.fill_(1)
-             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
-             forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

          self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
              forward_batch, bs
@@ -233,7 +233,7 @@ class EAGLEDraftCudaGraphRunner:
          forward_batch.positions = self.positions[:raw_num_token]
          forward_batch.seq_lens = self.seq_lens[:raw_bs]
          forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
-         if forward_batch.decode_seq_lens_cpu is not None:
-             forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+         if forward_batch.seq_lens_cpu is not None:
+             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]

          return out