sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/xiaomi_mimo.py ADDED
@@ -0,0 +1,171 @@
+# Adapted from qwen2.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+MiMoConfig = None
+
+
+class MiMoModel(Qwen2Model):
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen2DecoderLayer,
+        )
+
+
+class MiMoForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MiMoModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if (
+                "rotary_emb.inv_freq" in name
+                or "projector" in name
+                or "mtp_layers" in name
+            ):
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = MiMoForCausalLM
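The load_weights loop above relies on Python's for/else: the else branch runs only when the for loop finishes without hitting break, i.e. when a weight name matched none of the stacked-parameter mappings. A minimal illustration of that control flow (hypothetical names, not sglang code):

    stacked = [("qkv_proj", "q_proj"), ("gate_up_proj", "gate_proj")]
    name = "model.layers.0.mlp.down_proj.weight"
    for fused, shard in stacked:
        if shard in name:
            print("fused load as", name.replace(shard, fused))
            break
    else:
        # no mapping matched, so the weight is loaded directly
        print("plain load of", name)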
sglang/srt/openai_api/adapter.py CHANGED
@@ -14,6 +14,7 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
 
 import asyncio
+import base64
 import json
 import logging
 import os
@@ -174,6 +175,32 @@ def guess_chat_template_name_from_model_path(model_path):
         )
 
 
+def _validate_prompt(prompt: str):
+    """Validate that the prompt is not empty or whitespace only."""
+    is_invalid = False
+
+    # Check for empty/whitespace string
+    if isinstance(prompt, str):
+        is_invalid = not prompt.strip()
+    # Check for various invalid list cases: [], [""], [" "], [[]]
+    elif isinstance(prompt, list):
+        is_invalid = not prompt or (
+            len(prompt) == 1
+            and (
+                (isinstance(prompt[0], str) and not prompt[0].strip())
+                or (isinstance(prompt[0], list) and not prompt[0])
+            )
+        )
+
+    if is_invalid:
+        raise HTTPException(
+            status_code=400,
+            detail="Input cannot be empty or contain only whitespace.",
+        )
+
+    return prompt
+
+
 async def v1_files_create(
     file: UploadFile, purpose: str, file_storage_path: str = None
 ):
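A quick sketch of how the new _validate_prompt helper behaves; HTTPException here is FastAPI's, and the helper itself is assumed to be in scope (it is defined in adapter.py above):

    from fastapi import HTTPException

    # assumes _validate_prompt is in scope (defined in adapter.py above)
    _validate_prompt("Hello")      # returned unchanged
    _validate_prompt(["a", "b"])   # returned unchanged
    for bad in ["", "   ", [], [""], ["  "], [[]]]:
        try:
            _validate_prompt(bad)
        except HTTPException:
            print("rejected with 400:", bad)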
@@ -589,7 +616,7 @@ def v1_generate_response(
     echo = False
 
     if (not isinstance(request, list)) and request.echo:
-        # TODO: handle the case propmt is token ids
+        # TODO: handle the case prompt is token ids
        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
            # for the case of multiple str prompts
            prompts = request.prompt
@@ -645,7 +672,7 @@ def v1_generate_response(
        finish_reason = ret_item["meta_info"]["finish_reason"]
 
        if to_file:
-            # to make the choise data json serializable
+            # to make the choice data json serializable
            choice_data = {
                "index": 0,
                "text": text,
@@ -966,47 +993,23 @@ def v1_chat_generate_request(
 
        if chat_template_name is None:
            openai_compatible_messages = []
-            if (
-                tools
-                and tokenizer_manager.server_args.tool_call_parser == "deepseekv3"
-            ):
-                # add function call prompt to deepseekv3
-                openai_compatible_messages.append(
-                    {
-                        "role": "system",
-                        "content": """You are a helpful Assistant.
-## Tools
-### Function
-You have the following functions available:
-"""
-                        + "".join(
-                            [
-                                f"""
-- `{tool['name']}`:
-```json
-{json.dumps(tool)}
-```
-"""
-                                for tool in tools
-                            ]
-                        ),
-                    }
-                )
 
            for message in request.messages:
                if message.content is None:
                    message.content = ""
-                if isinstance(message.content, str):
-                    openai_compatible_messages.append(
-                        {"role": message.role, "content": message.content}
-                    )
+                msg_dict = message.dict()
+                if isinstance(msg_dict.get("content"), list):
+                    for chunk in msg_dict["content"]:
+                        if isinstance(chunk, dict) and chunk.get("type") == "text":
+                            new_msg = msg_dict.copy()
+                            new_msg["content"] = chunk["text"]
+                            new_msg = {
+                                k: v for k, v in new_msg.items() if v is not None
+                            }
+                            openai_compatible_messages.append(new_msg)
                else:
-                    content_list = message.dict()["content"]
-                    for content in content_list:
-                        if content["type"] == "text":
-                            openai_compatible_messages.append(
-                                {"role": message.role, "content": content["text"]}
-                            )
+                    msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
+                    openai_compatible_messages.append(msg_dict)
            if (
                openai_compatible_messages
                and openai_compatible_messages[-1]["role"] == "assistant"
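Roughly, the new branch above splits a list-style message into one OpenAI-compatible message per text chunk and drops None-valued fields. A standalone sketch of that transformation on a hypothetical input:

    msg_dict = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "http://example.com/cat.png"}},
        ],
        "name": None,
    }
    openai_compatible_messages = []
    if isinstance(msg_dict.get("content"), list):
        for chunk in msg_dict["content"]:
            if isinstance(chunk, dict) and chunk.get("type") == "text":
                new_msg = {**msg_dict, "content": chunk["text"]}
                new_msg = {k: v for k, v in new_msg.items() if v is not None}
                openai_compatible_messages.append(new_msg)
    print(openai_compatible_messages)
    # [{'role': 'user', 'content': 'Describe this image.'}]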
@@ -1316,7 +1319,8 @@ def v1_chat_generate_response(
            text, call_info_list = parser.parse_non_stream(text)
            tool_calls = [
                ToolCall(
-                    id=str(call_info.tool_index),
+                    id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+                    index=call_info.tool_index,
                    function=FunctionResponse(
                        name=call_info.name, arguments=call_info.parameters
                    ),
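The generated identifier now follows the OpenAI "call_..." convention instead of reusing the tool index: the expression above encodes 16 random UUID bytes as URL-safe base64 and strips the two padding characters, leaving 22 characters after the prefix. For example:

    import base64
    import uuid

    call_id = f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
    print(call_id)       # e.g. call_x8Zf3kQ1T0qkqT0v3n2V7g
    print(len(call_id))  # 27 == len("call_") + 22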
@@ -1432,6 +1436,7 @@ async def v1_chat_completions(
    reasoning_parser_dict = {}
 
    async def generate_stream_resp():
+        tool_call_first = True
        is_firsts = {}
        stream_buffers = {}
        n_prev_tokens = {}
@@ -1598,7 +1603,6 @@ async def v1_chat_completions(
                    # 2) if we found calls, we output them as separate chunk(s)
                    for call_item in calls:
                        # transform call_item -> FunctionResponse + ToolCall
-
                        if finish_reason_type == "stop":
                            latest_delta_len = 0
                            if isinstance(call_item.parameters, str):
@@ -1621,15 +1625,19 @@ async def v1_chat_completions(
                            call_item.parameters = remaining_call
 
                            finish_reason_type = "tool_calls"
-
                        tool_call = ToolCall(
-                            id=str(call_item.tool_index),
+                            id=(
+                                f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
+                                if tool_call_first
+                                else None
+                            ),
                            index=call_item.tool_index,
                            function=FunctionResponse(
                                name=call_item.name,
                                arguments=call_item.parameters,
                            ),
                        )
+                        tool_call_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(tool_calls=[tool_call]),
@@ -1771,6 +1779,8 @@ def v1_embedding_request(all_requests, tokenizer_manager):
 
    for request in all_requests:
        prompt = request.input
+        # Check for empty/whitespace string
+        prompt = _validate_prompt(request.input)
        assert (
            type(prompt) is first_prompt_type
        ), "All prompts must be of the same type in file input settings"
sglang/srt/openai_api/protocol.py CHANGED
@@ -250,9 +250,29 @@ ChatCompletionMessageContentPart = Union[
 ]
 
 
+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: Optional[str] = None
+    index: Optional[int] = None
+    type: Literal["function"] = "function"
+    function: FunctionResponse
+
+
 class ChatCompletionMessageGenericParam(BaseModel):
     role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+    tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+    reasoning_content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
 
 class ChatCompletionMessageUserParam(BaseModel):
@@ -378,22 +398,6 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_room: Optional[int] = None
 
 
-class FunctionResponse(BaseModel):
-    """Function response."""
-
-    name: Optional[str] = None
-    arguments: Optional[str] = None
-
-
-class ToolCall(BaseModel):
-    """Tool call response."""
-
-    id: str
-    index: Optional[int] = None
-    type: Literal["function"] = "function"
-    function: FunctionResponse
-
-
 class ChatMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
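With FunctionResponse and ToolCall now defined before ChatCompletionMessageGenericParam, and tool_calls accepted on generic messages, an assistant tool-call turn can be expressed directly in the request schema. A hedged sketch, assuming these names are importable from sglang.srt.openai_api.protocol:

    from sglang.srt.openai_api.protocol import (
        ChatCompletionMessageGenericParam,
        FunctionResponse,
        ToolCall,
    )

    msg = ChatCompletionMessageGenericParam(
        role="assistant",
        content=None,
        tool_calls=[
            ToolCall(
                id="call_example123",
                index=0,
                function=FunctionResponse(name="get_weather", arguments='{"city": "Paris"}'),
            )
        ],
    )
    print(msg.tool_calls[0].function.name)  # get_weather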
sglang/srt/reasoning_parser.py CHANGED
@@ -147,7 +147,7 @@ class ReasoningParser:
 
    Args:
        model_type (str): Type of model to parse reasoning from
-        stream_reasoning (bool): If Flase, accumulates reasoning content until complete.
+        stream_reasoning (bool): If False, accumulates reasoning content until complete.
            If True, streams reasoning content as it arrives.
    """
 
sglang/srt/sampling/custom_logit_processor.py CHANGED
@@ -28,11 +28,26 @@ class CustomLogitProcessor(ABC):
        """Define the callable behavior."""
        raise NotImplementedError
 
-    def to_str(self) -> str:
+    @classmethod
+    def to_str(cls) -> str:
        """Serialize the callable function to a JSON-compatible string."""
-        return json.dumps({"callable": dill.dumps(self).hex()})
+        return json.dumps({"callable": dill.dumps(cls).hex()})
 
    @classmethod
    def from_str(cls, json_str: str):
        """Deserialize a callable function from a JSON string."""
-        return _cache_from_str(json_str)
+        return _cache_from_str(json_str)()
+
+
+class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        disallowed_token_ids = custom_param_list[0]["token_ids"]
+        assert all(
+            disallowed_token_ids == c["token_ids"] for c in custom_param_list
+        ), f"{custom_param_list=}"
+        logits[..., disallowed_token_ids] = -float("inf")
+        return logits
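to_str is now a classmethod that pickles the class itself, and from_str instantiates whatever it unpickles, so a processor can round-trip through its string form. A hedged sketch, assuming the classes above are importable from sglang.srt.sampling.custom_logit_processor:

    import torch
    from sglang.srt.sampling.custom_logit_processor import (
        CustomLogitProcessor,
        DisallowedTokensLogitsProcessor,
    )

    serialized = DisallowedTokensLogitsProcessor.to_str()  # classmethod: pickles the class
    processor = CustomLogitProcessor.from_str(serialized)  # returns an instance

    logits = torch.zeros(2, 8)
    out = processor(logits, custom_param_list=[{"token_ids": [3, 5]}] * 2)
    print(out[:, [3, 5]])  # both disallowed columns are -inf; other columns stay zero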
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -294,7 +294,7 @@ class SamplingBatchInfo:
            # Set the flag to True if any of the two has custom logit processor
            self.has_custom_logit_processor = True
 
-        # Note: becasue the __len()__ operator is defined on the temperatures tensor,
+        # Note: because the __len()__ operator is defined on the temperatures tensor,
        # please make sure any merge operation with len(self) or len(other) is done before
        # the merge operation of the temperatures tensor below.
        for item in [
@@ -307,5 +307,5 @@ class SamplingBatchInfo:
            other_val = getattr(other, item, None)
            setattr(self, item, torch.cat([self_val, other_val]))
 
-        self.is_all_greedy |= other.is_all_greedy
+        self.is_all_greedy &= other.is_all_greedy
        self.need_min_p_sampling |= other.need_min_p_sampling
sglang/srt/sampling/sampling_params.py CHANGED
@@ -50,6 +50,7 @@ class SamplingParams:
        spaces_between_special_tokens: bool = True,
        no_stop_trim: bool = False,
        custom_params: Optional[Dict[str, Any]] = None,
+        stream_interval: Optional[int] = None,
    ) -> None:
        self.max_new_tokens = max_new_tokens
        self.stop_strs = stop
@@ -75,6 +76,7 @@ class SamplingParams:
        self.spaces_between_special_tokens = spaces_between_special_tokens
        self.no_stop_trim = no_stop_trim
        self.custom_params = custom_params
+        self.stream_interval = stream_interval
 
        # Process some special cases
        if 0 <= self.temperature < _SAMPLING_EPS:
sglang/srt/server_args.py CHANGED
@@ -98,6 +98,7 @@ class ServerArgs:
    show_time_cost: bool = False
    enable_metrics: bool = False
    decode_log_interval: int = 40
+    enable_request_time_stats_logging: bool = False
 
    # API related
    api_key: Optional[str] = None
@@ -159,6 +160,7 @@ class ServerArgs:
    disable_overlap_schedule: bool = False
    enable_mixed_chunk: bool = False
    enable_dp_attention: bool = False
+    enable_dp_lm_head: bool = False
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -187,6 +189,7 @@ class ServerArgs:
    n_share_experts_fusion: int = 0
    disable_chunked_prefix_cache: bool = False
    disable_fast_image_processor: bool = False
+    mm_attention_backend: Optional[str] = None
 
    # Debug tensor dumps
    debug_tensor_dump_output_folder: Optional[str] = None
@@ -198,6 +201,7 @@ class ServerArgs:
    disaggregation_bootstrap_port: int = 8998
    disaggregation_transfer_backend: str = "mooncake"
    disaggregation_ib_device: Optional[str] = None
+    pdlb_url: Optional[str] = None
 
    def __post_init__(self):
        # Expert parallelism
@@ -303,6 +307,12 @@ class ServerArgs:
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"
 
+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Overlap scheduler is disabled because of using pipeline parallelism."
+            )
+
        # Data parallelism attention
        if self.enable_dp_attention:
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -315,6 +325,11 @@ class ServerArgs:
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
            )
 
+        if self.enable_dp_lm_head:
+            assert (
+                self.enable_dp_attention
+            ), "Please enable dp attention when setting enable_dp_attention. "
+
        # DeepEP MoE
        self.enable_sp_layernorm = False
        if self.enable_deepep_moe:
@@ -322,6 +337,9 @@ class ServerArgs:
                assert (
                    not self.enable_dp_attention
                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
+            if self.deepep_mode == "normal":
+                logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
+                self.disable_cuda_graph = True
            self.ep_size = self.tp_size
            self.enable_sp_layernorm = (
                self.dp_size < self.tp_size if self.enable_dp_attention else True
@@ -330,6 +348,12 @@ class ServerArgs:
                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
            )
 
+        if self.pp_size > 1:
+            self.disable_overlap_schedule = True
+            logger.warning(
+                "Pipeline parallelism is incompatible with overlap schedule."
+            )
+
        # Speculative Decoding
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
@@ -347,10 +371,13 @@ class ServerArgs:
            model_arch = get_model_arch(self)
 
            # Auto set draft_model_path DeepSeek-V3/R1
-            if self.speculative_draft_model_path is None and model_arch in [
-                "DeepseekV3ForCausalLM"
-            ]:
-                self.speculative_draft_model_path = self.model_path
+            if model_arch == "DeepseekV3ForCausalLM":
+                if self.speculative_draft_model_path is None:
+                    self.speculative_draft_model_path = self.model_path
+                else:
+                    logger.warning(
+                        "DeepSeek MTP does not require setting speculative_draft_model_path."
+                    )
 
            # Auto choose parameters
            if self.speculative_num_steps is None:
@@ -551,7 +578,7 @@ class ServerArgs:
            "--device",
            type=str,
            default=ServerArgs.device,
-            help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
+            help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
        )
        parser.add_argument(
            "--served-model-name",
@@ -759,6 +786,12 @@ class ServerArgs:
            default=ServerArgs.decode_log_interval,
            help="The log interval of decode batch.",
        )
+        parser.add_argument(
+            "--enable-request-time-stats-logging",
+            action="store_true",
+            default=ServerArgs.enable_request_time_stats_logging,
+            help="Enable per request time stats logging",
+        )
 
        # API related
        parser.add_argument(
@@ -817,7 +850,7 @@ class ServerArgs:
        # Multi-node distributed serving
        parser.add_argument(
            "--dist-init-addr",
-            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
+            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
            type=str,
            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
        )
@@ -1041,6 +1074,11 @@ class ServerArgs:
            action="store_true",
            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
        )
+        parser.add_argument(
+            "--enable-dp-lm-head",
+            action="store_true",
+            help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
+        )
        parser.add_argument(
            "--enable-ep-moe",
            action="store_true",
@@ -1061,7 +1099,7 @@ class ServerArgs:
            "--cuda-graph-max-bs",
            type=int,
            default=ServerArgs.cuda_graph_max_bs,
-            help="Set the maximum batch size for cuda graph.",
+            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
        )
        parser.add_argument(
            "--cuda-graph-bs",
@@ -1088,7 +1126,7 @@ class ServerArgs:
        parser.add_argument(
            "--triton-attention-reduce-in-fp32",
            action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
+            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
@@ -1180,7 +1218,7 @@ class ServerArgs:
            type=int,
            default=0,
            help="The number of shared_experts need to be replicated to fuse with normal experts in deepseek v3/r1, "
-            "set it to tp_size can get best optimized performace.",
+            "set it to tp_size can get best optimized performance. Note that for architectures with SM==90, we have enabled the shared experts fusion optimization by default for DeepSeek V3/R1, with n_share_experts_fusion automatically set to the TP size.",
        )
        parser.add_argument(
            "--disable-chunked-prefix-cache",
@@ -1247,7 +1285,23 @@ class ServerArgs:
            "--disaggregation-ib-device",
            type=str,
            default=ServerArgs.disaggregation_ib_device,
-            help="The ib device for disaggregation transfer. Default is None, it will be detected automatically if using the mooncake backend.",
+            help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
+            "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
+            "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
+        )
+        parser.add_argument(
+            "--pdlb-url",
+            type=str,
+            default=None,
+            help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
+        )
+
+        parser.add_argument(
+            "--mm-attention-backend",
+            type=str,
+            choices=["sdpa", "fa3", "triton_attn"],
+            default=ServerArgs.mm_attention_backend,
+            help="Set multimodal attention backend.",
        )
 
    @classmethod
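Taken together, the new server options can be exercised from the command line; a hedged sketch of a launch invocation (the model path is a placeholder, and --enable-dp-lm-head requires --enable-dp-attention per the __post_init__ check above):

    # Sketch only: builds the argv for `python -m sglang.launch_server`.
    launch_cmd = [
        "python", "-m", "sglang.launch_server",
        "--model-path", "/path/to/model",   # placeholder
        "--tp-size", "8", "--dp-size", "8",
        "--enable-dp-attention",
        "--enable-dp-lm-head",
        "--mm-attention-backend", "fa3",    # one of: sdpa, fa3, triton_attn
        "--enable-request-time-stats-logging",
    ]
    print(" ".join(launch_cmd))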