sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/protocol.py CHANGED
@@ -16,7 +16,7 @@
  import time
  from typing import Dict, List, Optional, Union

- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, Field, root_validator
  from typing_extensions import Literal


@@ -28,6 +28,7 @@ class ModelCard(BaseModel):
  created: int = Field(default_factory=lambda: int(time.time()))
  owned_by: str = "sglang"
  root: Optional[str] = None
+ max_model_len: Optional[int] = None


  class ModelList(BaseModel):
@@ -187,7 +188,7 @@ class CompletionResponseChoice(BaseModel):
  index: int
  text: str
  logprobs: Optional[LogProbs] = None
- finish_reason: Optional[str] = None
+ finish_reason: Literal["stop", "length", "content_filter"]
  matched_stop: Union[None, int, str] = None


@@ -204,7 +205,7 @@ class CompletionResponseStreamChoice(BaseModel):
  index: int
  text: str
  logprobs: Optional[LogProbs] = None
- finish_reason: Optional[str] = None
+ finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
  matched_stop: Union[None, int, str] = None


@@ -227,14 +228,25 @@ class ChatCompletionMessageContentImageURL(BaseModel):
  detail: Optional[Literal["auto", "low", "high"]] = "auto"


+ class ChatCompletionMessageContentAudioURL(BaseModel):
+ url: str
+
+
  class ChatCompletionMessageContentImagePart(BaseModel):
  type: Literal["image_url"]
  image_url: ChatCompletionMessageContentImageURL
  modalities: Optional[Literal["image", "multi-images", "video"]] = "image"


+ class ChatCompletionMessageContentAudioPart(BaseModel):
+ type: Literal["audio_url"]
+ audio_url: ChatCompletionMessageContentAudioURL
+
+
  ChatCompletionMessageContentPart = Union[
- ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
+ ChatCompletionMessageContentTextPart,
+ ChatCompletionMessageContentImagePart,
+ ChatCompletionMessageContentAudioPart,
  ]


@@ -276,6 +288,7 @@ class Function(BaseModel):
  description: Optional[str] = Field(default=None, examples=[None])
  name: Optional[str] = None
  parameters: Optional[object] = None
+ strict: bool = False


  class Tool(BaseModel):
@@ -310,7 +323,7 @@ class ChatCompletionRequest(BaseModel):
  max_tokens: Optional[int] = None
  n: int = 1
  presence_penalty: float = 0.0
- response_format: Union[ResponseFormat, StructuralTagResponseFormat] = None
+ response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
  seed: Optional[int] = None
  stop: Optional[Union[str, List[str]]] = None
  stream: bool = False
@@ -323,6 +336,15 @@ class ChatCompletionRequest(BaseModel):
  default="auto", examples=["none"]
  ) # noqa

+ @root_validator(pre=True)
+ def set_tool_choice_default(cls, values):
+ if values.get("tool_choice") is None:
+ if values.get("tools") is None:
+ values["tool_choice"] = "none"
+ else:
+ values["tool_choice"] = "auto"
+ return values
+
  # Extra parameters for SRT backend only and will be ignored by OpenAI models.
  top_k: int = -1
  min_p: float = 0.0
@@ -366,7 +388,9 @@ class ChatCompletionResponseChoice(BaseModel):
  index: int
  message: ChatMessage
  logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
- finish_reason: str
+ finish_reason: Literal[
+ "stop", "length", "tool_calls", "content_filter", "function_call"
+ ]
  matched_stop: Union[None, int, str] = None


@@ -390,7 +414,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
  index: int
  delta: DeltaMessage
  logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
- finish_reason: Optional[str] = None
+ finish_reason: Optional[
+ Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
+ ] = None
  matched_stop: Union[None, int, str] = None
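
The protocol changes above can be exercised with a request like the following. This is a minimal sketch against the OpenAI-compatible chat endpoint, assuming a locally running sglang server on its default port; the model name and audio URL are placeholders, and the field names come from the models in this diff.

    import requests

    payload = {
        "model": "placeholder-model",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is said in this clip?"},
                    # New in this release: audio parts alongside text/image parts.
                    {"type": "audio_url", "audio_url": {"url": "https://example.com/clip.wav"}},
                ],
            }
        ],
        # "tool_choice" is omitted: the new root_validator fills in "none" when no
        # "tools" are given, and "auto" when they are.
    }
    resp = requests.post("http://127.0.0.1:30000/v1/chat/completions", json=payload)
    # finish_reason is now a Literal such as "stop", "length", or "tool_calls".
    print(resp.json()["choices"][0]["finish_reason"])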
sglang/srt/patch_torch.py ADDED
@@ -0,0 +1,71 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import Callable, Union
+
+ import torch
+ from torch.multiprocessing import reductions
+
+
+ def monkey_patch_torch_reductions():
+ """Monkey patching before Torch https://github.com/pytorch/pytorch/pull/149248 is fixed"""
+
+ if hasattr(reductions, "_reduce_tensor_original"):
+ return
+
+ reductions._reduce_tensor_original = reductions.reduce_tensor
+ reductions._rebuild_cuda_tensor_original = reductions.rebuild_cuda_tensor
+
+ reductions.reduce_tensor = _reduce_tensor_modified
+ reductions.rebuild_cuda_tensor = _rebuild_cuda_tensor_modified
+
+ reductions.init_reductions()
+
+
+ # The signature has not been changed for years, and we will not need this when the next version is released,
+ # so it looks safe to use a constant.
+ _REDUCE_TENSOR_ARG_DEVICE_INDEX = 6
+
+
+ def _reduce_tensor_modified(*args, **kwargs):
+ output_fn, output_args = reductions._reduce_tensor_original(*args, **kwargs)
+ output_args = _modify_tuple(
+ output_args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_to_uuid
+ )
+ return output_fn, output_args
+
+
+ def _rebuild_cuda_tensor_modified(*args):
+ args = _modify_tuple(args, _REDUCE_TENSOR_ARG_DEVICE_INDEX, _device_from_maybe_uuid)
+ return reductions._rebuild_cuda_tensor_original(*args)
+
+
+ def _device_to_uuid(device: int) -> str:
+ return str(torch.cuda.get_device_properties(device).uuid)
+
+
+ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
+ if isinstance(device_maybe_uuid, int):
+ return device_maybe_uuid
+
+ if isinstance(device_maybe_uuid, str):
+ for device in range(torch.cuda.device_count()):
+ if str(torch.cuda.get_device_properties(device).uuid) == device_maybe_uuid:
+ return device
+ raise Exception("Invalid device_uuid=" + device_maybe_uuid)
+
+ raise Exception(f"Unknown type: {device_maybe_uuid=}")
+
+
+ def _modify_tuple(t, index: int, modifier: Callable):
+ return *t[:index], modifier(t[index]), *t[index + 1 :]
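
A minimal usage sketch for the new patch_torch module, assuming both sides of a torch.multiprocessing exchange apply the patch so that a shared CUDA tensor's device is pickled as a GPU UUID and mapped back to a local index on the receiving side; the worker function and tensor values are illustrative.

    import torch
    import torch.multiprocessing as mp
    from sglang.srt.patch_torch import monkey_patch_torch_reductions

    def consumer(queue):  # hypothetical worker
        monkey_patch_torch_reductions()  # patch the rebuild side as well
        t = queue.get()                  # rebuilt via the UUID-aware path
        print(t.device, t.sum().item())

    if __name__ == "__main__":
        monkey_patch_torch_reductions()  # idempotent: guarded by _reduce_tensor_original
        ctx = mp.get_context("spawn")
        queue = ctx.Queue()
        proc = ctx.Process(target=consumer, args=(queue,))
        proc.start()
        shared = torch.ones(4, device="cuda:0")  # keep a reference while the consumer fetches it
        queue.put(shared)
        proc.join()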
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -306,7 +306,7 @@ class SamplingBatchInfo:
  ]:
  self_val = getattr(self, item, None)
  other_val = getattr(other, item, None)
- setattr(self, item, torch.concat([self_val, other_val]))
+ setattr(self, item, torch.cat([self_val, other_val]))

  self.is_all_greedy |= other.is_all_greedy
  self.need_min_p_sampling |= other.need_min_p_sampling
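
For reference, the merge above concatenates the per-request sampling tensors of two batches along the batch dimension; torch.cat is the canonical spelling (torch.concat is only an alias). A toy illustration with made-up values:

    import torch

    temperatures_a = torch.tensor([[0.7], [1.0]])  # batch of 2 requests
    temperatures_b = torch.tensor([[0.2]])         # batch of 1 request
    merged = torch.cat([temperatures_a, temperatures_b])
    print(merged.shape)  # torch.Size([3, 1])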
sglang/srt/sampling/sampling_params.py CHANGED
@@ -77,7 +77,7 @@ class SamplingParams:
  self.custom_params = custom_params

  # Process some special cases
- if self.temperature < _SAMPLING_EPS:
+ if 0 <= self.temperature < _SAMPLING_EPS:
  # top_k = 1 means greedy sampling
  self.temperature = 1.0
  self.top_k = 1
@@ -93,9 +93,9 @@ class SamplingParams:
  raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
  if not 0.0 <= self.min_p <= 1.0:
  raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
- if self.top_k < -1 or self.top_k == 0:
+ if self.top_k < 1 or self.top_k == -1:
  raise ValueError(
- f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
+ f"top_k must be -1 (disable) or at least 1, got {self.top_k}."
  )
  if not -2.0 <= self.frequency_penalty <= 2.0:
  raise ValueError(
@@ -108,12 +108,12 @@ class SamplingParams:
  )
  if not 0.0 <= self.repetition_penalty <= 2.0:
  raise ValueError(
- "repetition_penalty must be in (0, 2], got "
+ "repetition_penalty must be in [0, 2], got "
  f"{self.repetition_penalty}."
  )
  if not 0 <= self.min_new_tokens:
  raise ValueError(
- f"min_new_tokens must be in (0, max_new_tokens], got "
+ f"min_new_tokens must be in [0, max_new_tokens], got "
  f"{self.min_new_tokens}."
  )
  if self.max_new_tokens is not None:
@@ -123,7 +123,7 @@ class SamplingParams:
  )
  if not self.min_new_tokens <= self.max_new_tokens:
  raise ValueError(
- f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
+ f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got "
  f"{self.min_new_tokens}."
  )
  grammars = [
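
A standalone sketch of the special-case handling above (not the real SamplingParams class): a non-negative temperature below the sampling epsilon is rewritten to greedy decoding. The epsilon value here is an assumption for illustration.

    _SAMPLING_EPS = 1e-6  # assumed value; the real constant is defined in sampling_params.py

    def normalize(temperature: float, top_k: int):
        if 0 <= temperature < _SAMPLING_EPS:
            # Greedy: neutralize temperature and pin top_k to 1.
            temperature, top_k = 1.0, 1
        return temperature, top_k

    print(normalize(0.0, -1))   # (1.0, 1)
    print(normalize(0.8, 50))   # (0.8, 50)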
sglang/srt/server_args.py CHANGED
@@ -16,6 +16,7 @@
  import argparse
  import dataclasses
  import logging
+ import os
  import random
  import tempfile
  from typing import List, Optional
@@ -23,13 +24,16 @@ from typing import List, Optional
  from sglang.srt.hf_transformers_utils import check_gguf_file
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.utils import (
+ configure_ipv6,
  get_amdgpu_memory_capacity,
+ get_device,
  get_hpu_memory_capacity,
  get_nvgpu_memory_capacity,
  is_cuda,
  is_flashinfer_available,
  is_hip,
  is_port_available,
+ is_remote_url,
  is_valid_ipv6_address,
  nullable_str,
  )
@@ -49,11 +53,12 @@ class ServerArgs:
  dtype: str = "auto"
  kv_cache_dtype: str = "auto"
  quantization: Optional[str] = None
- quantization_param_path: nullable_str = None
+ quantization_param_path: Optional[str] = None
  context_length: Optional[int] = None
- device: str = "cuda"
+ device: Optional[str] = None
  served_model_name: Optional[str] = None
  chat_template: Optional[str] = None
+ completion_template: Optional[str] = None
  is_embedding: bool = False
  revision: Optional[str] = None

@@ -122,7 +127,7 @@ class ServerArgs:
  # Kernel backend
  attention_backend: Optional[str] = None
  sampling_backend: Optional[str] = None
- grammar_backend: Optional[str] = "outlines"
+ grammar_backend: Optional[str] = "xgrammar"

  # Speculative decoding
  speculative_algorithm: Optional[str] = None
@@ -136,7 +141,7 @@ class ServerArgs:

  # Double Sparsity
  enable_double_sparsity: bool = False
- ds_channel_config_path: str = None
+ ds_channel_config_path: Optional[str] = None
  ds_heavy_channel_num: int = 32
  ds_heavy_token_num: int = 256
  ds_heavy_channel_type: str = "qk"
@@ -154,6 +159,7 @@ class ServerArgs:
  enable_mixed_chunk: bool = False
  enable_dp_attention: bool = False
  enable_ep_moe: bool = False
+ enable_deepep_moe: bool = False
  enable_torch_compile: bool = False
  torch_compile_max_bs: int = 32
  cuda_graph_max_bs: Optional[int] = None
@@ -168,9 +174,11 @@ class ServerArgs:
  enable_memory_saver: bool = False
  allow_auto_truncate: bool = False
  enable_custom_logit_processor: bool = False
- tool_call_parser: str = None
+ tool_call_parser: Optional[str] = None
  enable_hierarchical_cache: bool = False
+ hicache_ratio: float = 2.0
  enable_flashinfer_mla: bool = False
+ enable_flashmla: bool = False
  flashinfer_mla_disable_ragged: bool = False
  warmups: Optional[str] = None

@@ -179,11 +187,18 @@ class ServerArgs:
  debug_tensor_dump_input_file: Optional[str] = None
  debug_tensor_dump_inject: bool = False

+ # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
+ disaggregation_mode: str = "null"
+ disaggregation_bootstrap_port: int = 8998
+
  def __post_init__(self):
  # Set missing default values
  if self.tokenizer_path is None:
  self.tokenizer_path = self.model_path

+ if self.device is None:
+ self.device = get_device()
+
  if self.served_model_name is None:
  self.served_model_name = self.model_path

@@ -222,6 +237,11 @@ class ServerArgs:

  assert self.chunked_prefill_size % self.page_size == 0

+ if self.enable_flashmla is True:
+ logger.warning(
+ "FlashMLA only supports a page_size of 64, change page_size to 64."
+ )
+ self.page_size = 64
  # Set cuda graph max batch size
  if self.cuda_graph_max_bs is None:
  # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -272,15 +292,28 @@ class ServerArgs:
  f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
  )

+ self.enable_sp_layernorm = False
+ # DeepEP MoE
+ if self.enable_deepep_moe:
+ self.ep_size = self.tp_size
+ self.enable_sp_layernorm = (
+ self.dp_size < self.tp_size if self.enable_dp_attention else True
+ )
+ logger.info(
+ f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+ )
+
  # Speculative Decoding
  if self.speculative_algorithm == "NEXTN":
  # NEXTN shares the same implementation of EAGLE
  self.speculative_algorithm = "EAGLE"

- if self.speculative_algorithm == "EAGLE":
+ if (
+ self.speculative_algorithm == "EAGLE"
+ or self.speculative_algorithm == "EAGLE3"
+ ):
  if self.max_running_requests is None:
  self.max_running_requests = 32
- self.disable_cuda_graph_padding = True
  self.disable_overlap_schedule = True
  logger.info(
  "Overlap scheduler is disabled because of using "
@@ -296,10 +329,29 @@ class ServerArgs:
  ) and check_gguf_file(self.model_path):
  self.quantization = self.load_format = "gguf"

+ if is_remote_url(self.model_path):
+ self.load_format = "remote"
+
  # AMD-specific Triton attention KV splits default number
  if is_hip():
  self.triton_attention_num_kv_splits = 16

+ # PD disaggregation
+ if self.disaggregation_mode == "prefill":
+ self.disable_cuda_graph = True
+ logger.warning("KV cache is forced as chunk cache for decode server")
+ self.disable_overlap_schedule = True
+ logger.warning("Overlap scheduler is disabled for prefill server")
+ elif self.disaggregation_mode == "decode":
+ self.disable_radix_cache = True
+ logger.warning("Cuda graph is disabled for prefill server")
+ self.disable_overlap_schedule = True
+ logger.warning("Overlap scheduler is disabled for decode server")
+
+ os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
+ "1" if self.enable_torch_compile else "0"
+ )
+
  @staticmethod
  def add_cli_args(parser: argparse.ArgumentParser):
  # Model and port args
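
The __post_init__ branches above can be observed directly on the dataclass. A hedged sketch, assuming a machine where ServerArgs can be constructed (its __post_init__ also probes the local device and memory); the model path is a placeholder, and the printed values follow from the branches shown in this diff.

    from sglang.srt.server_args import ServerArgs

    # Prefill-only server: cuda graph and the overlap scheduler are turned off.
    args = ServerArgs(model_path="placeholder/model", disaggregation_mode="prefill")
    print(args.disable_cuda_graph, args.disable_overlap_schedule)  # True True

    # FlashMLA: the page size is forced to 64.
    args = ServerArgs(model_path="placeholder/model", enable_flashmla=True)
    print(args.page_size)  # 64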
@@ -345,9 +397,11 @@ class ServerArgs:
  "safetensors",
  "npcache",
  "dummy",
+ "sharded_state",
  "gguf",
  "bitsandbytes",
  "layered",
+ "remote",
  ],
  help="The format of the model weights to load. "
  '"auto" will try to load the weights in the safetensors format '
@@ -429,9 +483,8 @@ class ServerArgs:
  parser.add_argument(
  "--device",
  type=str,
- default="cuda",
- choices=["cuda", "xpu", "hpu", "cpu"],
- help="The device type.",
+ default=ServerArgs.device,
+ help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
  )
  parser.add_argument(
  "--served-model-name",
@@ -445,6 +498,12 @@ class ServerArgs:
  default=ServerArgs.chat_template,
  help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
  )
+ parser.add_argument(
+ "--completion-template",
+ type=str,
+ default=ServerArgs.completion_template,
+ help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
+ )
  parser.add_argument(
  "--is-embedding",
  action="store_true",
@@ -722,7 +781,7 @@ class ServerArgs:
  parser.add_argument(
  "--attention-backend",
  type=str,
- choices=["flashinfer", "triton", "torch_native"],
+ choices=["flashinfer", "triton", "torch_native", "fa3"],
  default=ServerArgs.attention_backend,
  help="Choose the kernels for attention layers.",
  )
@@ -745,6 +804,11 @@ class ServerArgs:
  action="store_true",
  help="Enable FlashInfer MLA optimization",
  )
+ parser.add_argument(
+ "--enable-flashmla",
+ action="store_true",
+ help="Enable FlashMLA decode optimization",
+ )
  parser.add_argument(
  "--flashinfer-mla-disable-ragged",
  action="store_true",
@@ -755,7 +819,7 @@ class ServerArgs:
  parser.add_argument(
  "--speculative-algorithm",
  type=str,
- choices=["EAGLE", "NEXTN"],
+ choices=["EAGLE", "EAGLE3", "NEXTN"],
  help="Speculative algorithm.",
  )
  parser.add_argument(
@@ -984,6 +1048,18 @@ class ServerArgs:
  action="store_true",
  help="Enable hierarchical cache",
  )
+ parser.add_argument(
+ "--hicache-ratio",
+ type=float,
+ required=False,
+ default=ServerArgs.hicache_ratio,
+ help="The ratio of the size of host KV cache memory pool to the size of device pool.",
+ )
+ parser.add_argument(
+ "--enable-deepep-moe",
+ action="store_true",
+ help="Enabling DeepEP MoE implementation for EP MoE.",
+ )

  # Server warmups
  parser.add_argument(
@@ -1014,6 +1090,21 @@ class ServerArgs:
  help="Inject the outputs from jax as the input of every layer.",
  )

+ # Disaggregation
+ parser.add_argument(
+ "--disaggregation-mode",
+ type=str,
+ default="null",
+ choices=["null", "prefill", "decode"],
+ help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
+ )
+ parser.add_argument(
+ "--disaggregation-bootstrap-port",
+ type=int,
+ default=ServerArgs.disaggregation_bootstrap_port,
+ help="Bootstrap server port on the prefill server. Default is 8998.",
+ )
+
  @classmethod
  def from_cli_args(cls, args: argparse.Namespace):
  args.tp_size = args.tensor_parallel_size
@@ -1088,6 +1179,9 @@ class PortArgs:
  # The port for nccl initialization (torch.dist)
  nccl_port: int

+ # The ipc filename for rpc call between Engine and Scheduler
+ rpc_ipc_name: str
+
  @staticmethod
  def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
  port = server_args.port + random.randint(100, 1000)
@@ -1106,13 +1200,18 @@ class PortArgs:
  scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
  detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
  nccl_port=port,
+ rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
  )
  else:
  # DP attention. Use TCP + port to handle both single-node and multi-node.
  if server_args.nnodes == 1 and server_args.dist_init_addr is None:
  dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
+ elif server_args.dist_init_addr.startswith("["): # ipv6 address
+ port_num, host = configure_ipv6(server_args.dist_init_addr)
+ dist_init_addr = (host, str(port_num))
  else:
  dist_init_addr = server_args.dist_init_addr.split(":")
+
  assert (
  len(dist_init_addr) == 2
  ), "please provide --dist-init-addr as host:port of head node"
@@ -1121,16 +1220,17 @@ class PortArgs:
  port_base = int(dist_init_port) + 1
  if dp_rank is None:
  scheduler_input_port = (
- port_base + 2
+ port_base + 3
  ) # TokenizerManager to DataParallelController
  else:
- scheduler_input_port = port_base + 2 + 1 + dp_rank
+ scheduler_input_port = port_base + 3 + 1 + dp_rank

  return PortArgs(
  tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
  scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
  detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
  nccl_port=port,
+ rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
  )
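
The net effect of the PortArgs changes, in the TCP (DP attention) branch, is one extra slot in the port layout for the new Engine-to-Scheduler rpc channel. A small sketch with an illustrative dist-init port:

    dist_init_port = 5000                       # placeholder --dist-init-addr port
    port_base = dist_init_port + 1
    layout = {
        "tokenizer_ipc_name": port_base,            # 5001
        "detokenizer_ipc_name": port_base + 1,      # 5002
        "rpc_ipc_name": port_base + 2,              # 5003, new in this release
        "scheduler_input_ipc_name": port_base + 3,  # 5004, shifted from port_base + 2
    }
    print(layout)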