sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195) hide show
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,367 @@
1
+ import ast
2
+ import html
3
+ import json
4
+ import logging
5
+ import re
6
+ from typing import Any, Dict, List, Tuple
7
+
8
+ from sglang.srt.entrypoints.openai.protocol import Tool
9
+ from sglang.srt.function_call.base_format_detector import BaseFormatDetector
10
+ from sglang.srt.function_call.core_types import (
11
+ StreamingParseResult,
12
+ ToolCallItem,
13
+ _GetInfoFunc,
14
+ )
15
+ from sglang.srt.function_call.ebnf_composer import EBNFComposer
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def _safe_val(raw: str) -> Any:
21
+ raw = html.unescape(raw.strip())
22
+ try:
23
+ return json.loads(raw)
24
+ except Exception:
25
+ try:
26
+ return ast.literal_eval(raw)
27
+ except Exception:
28
+ return raw
29
+
30
+
31
class MinimaxM2Detector(BaseFormatDetector):
    """
    Detector for MiniMax M2 models.
    Assumes function call format:
      <minimax:tool_call>
      <invoke name="func1">
      <parameter name="param1">value1</parameter>
      <parameter name="param2">value2</parameter>
      </invoke>
      </minimax:tool_call>

    ``detect_and_parse`` handles complete text. ``parse_streaming_increment``
    handles chunked text. For each ``<invoke>`` it emits the function name
    first, then JSON argument fragments as complete ``<parameter>`` elements
    arrive, and finally the closing of the JSON object on ``</invoke>``.
    """

    # Complete ``<invoke name="...">`` opening tag (streaming path).
    _INVOKE_START_RE = re.compile(r"<invoke name=\"([^>]+)\">")
    # Complete ``<parameter name="...">value</parameter>`` element (streaming path).
    _COMPLETE_PARAM_RE = re.compile(
        r"<parameter name=\"([^>]+)\">(.*?)</parameter>", re.DOTALL
    )

    def __init__(self):
        super().__init__()
        self.tool_call_start_token: str = "<minimax:tool_call>"
        self.tool_call_end_token: str = "</minimax:tool_call>"
        self.tool_call_prefix: str = '<invoke name="'
        self.tool_call_function_end_token: str = "</invoke>"
        self.tool_call_regex = re.compile(
            r"<minimax:tool_call>(.*?)</minimax:tool_call>|<minimax:tool_call>(.*?)$",
            re.DOTALL,
        )
        self.tool_call_function_regex = re.compile(
            r"<invoke name=\"(.*?)</invoke>|<invoke name=\"(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter name=\"(.*?)</parameter>|<parameter name=\"(.*?)$", re.DOTALL
        )
        # Text received but not yet consumed by the streaming parser.
        self._buf: str = ""

        # Streaming state variables
        self._current_function_name: str = ""
        self._current_parameters: Dict[str, Any] = {}
        # Track what parameter content we've streamed
        self._streamed_parameters: Dict[str, str] = {}
        self._in_tool_call: bool = False
        self._function_name_sent: bool = False

    def has_tool_call(self, text: str) -> bool:
        """Return True if ``text`` contains a tool-call start token."""
        return self.tool_call_start_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse complete (non-streamed) model output."""
        normal, calls = self._extract(text, tools)
        return StreamingParseResult(normal_text=normal, calls=calls)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """Consume ``new_text`` and return whatever can be emitted so far.

        Text outside ``<minimax:tool_call>`` blocks is returned as normal
        text. Inside a block, incomplete markup stays in ``self._buf`` until
        more text arrives. ``</minimax:tool_call>`` ends the block, so text
        after it is reported as normal text again.
        """
        self._buf += new_text
        normal = ""
        calls: List[ToolCallItem] = []

        # Tool indices are built once, from the tools passed on the first call.
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        while True:
            if not self._in_tool_call:
                start = self._buf.find(self.tool_call_start_token)
                if start == -1:
                    normal += self._buf
                    self._buf = ""
                    break

                normal += self._buf[:start]
                # Drop everything up to and including the start token.
                self._buf = self._buf[start + len(self.tool_call_start_token) :]
                self._in_tool_call = True
                self._function_name_sent = False
                self._current_function_name = ""
                self._current_parameters = {}
                self._streamed_parameters = {}
                continue

            if not self._function_name_sent:
                function_match = self._INVOKE_START_RE.search(self._buf)
                block_end = self._buf.find(self.tool_call_end_token)
                if block_end != -1 and (
                    function_match is None or block_end < function_match.start()
                ):
                    # The tool-call block closes before another <invoke>
                    # starts. Consume the end token and leave tool-call mode.
                    # Otherwise all following text would be swallowed.
                    self._buf = self._buf[block_end + len(self.tool_call_end_token) :]
                    self._reset_streaming_state()
                    continue

                if function_match is None:
                    # Function name not complete yet, wait for more text
                    break

                function_name = function_match.group(1).strip()
                if function_name not in self._tool_indices:
                    logger.warning(f"Invalid function name: {function_name}")
                    self._reset_streaming_state()
                    normal += self._buf
                    self._buf = ""
                    break

                self._start_tool_call(function_name)
                # Send tool name with empty parameters
                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        name=function_name,
                        parameters="",
                    )
                )
                # Remove the processed function declaration
                self._buf = self._buf[function_match.end() :]
                continue

            # Name already sent: stream any newly completed parameters.
            calls.extend(self._parse_and_stream_parameters(self._buf))

            end_pos = self._buf.find(self.tool_call_function_end_token)
            if end_pos == -1:
                # Tool call not complete yet, wait for more text
                break

            calls.append(self._close_arguments())
            self._buf = self._buf[end_pos + len(self.tool_call_function_end_token) :]
            # Stay inside the block: another <invoke> may follow.
            self._reset_streaming_state(True)
            self.current_tool_id += 1

        return StreamingParseResult(normal_text=normal, calls=calls)

    def _start_tool_call(self, function_name: str) -> None:
        """Record a newly opened ``<invoke>`` in the tracking state."""
        self._current_function_name = function_name
        self._function_name_sent = True

        if self.current_tool_id == -1:
            self.current_tool_id = 0

        # Ensure tracking arrays are large enough
        while len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})
        while len(self.streamed_args_for_tool) <= self.current_tool_id:
            self.streamed_args_for_tool.append("")

        self.prev_tool_call_arr[self.current_tool_id] = {
            "name": function_name,
            "arguments": {},
        }

    def _close_arguments(self) -> ToolCallItem:
        """Terminate the current tool's streamed JSON arguments.

        ``_parse_and_stream_parameters`` streams an open object
        (``{"a": 1, "b": 2``) and never closes it, so exactly one ``}`` is
        appended. No brace counting is done, because string values may
        themselves contain braces. An invoke without parameters has streamed
        nothing, so it gets ``{}`` and the concatenated fragments are always
        valid JSON.
        """
        streamed = self.streamed_args_for_tool[self.current_tool_id]
        closing = "}" if streamed else "{}"
        self.streamed_args_for_tool[self.current_tool_id] = streamed + closing
        return ToolCallItem(
            tool_index=self.current_tool_id,
            name=None,
            parameters=closing,
        )

    def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]:
        """
        Parse complete parameter blocks from text and return any tool call items to emit.

        This method:
        1. Finds all complete <parameter> blocks
        2. Parses them into a dictionary
        3. Emits a JSON fragment for parameter names not streamed before,
           in the order the model produced them
        4. Updates internal state

        Args:
            text_to_parse: The text to search for parameter blocks

        Returns:
            List of ToolCallItem objects to emit (may be empty)
        """
        calls: List[ToolCallItem] = []

        new_params: Dict[str, Any] = {}
        for match in self._COMPLETE_PARAM_RE.finditer(text_to_parse):
            new_params[match.group(1).strip()] = _safe_val(match.group(2))

        if new_params == self._current_parameters:
            return calls

        # Preserve emission order (a set would make the key order
        # nondeterministic when several parameters complete at once).
        new_keys = [key for key in new_params if key not in self._current_parameters]
        if new_keys:
            items = ", ".join(
                f"{json.dumps(key, ensure_ascii=False)}: "
                f"{json.dumps(new_params[key], ensure_ascii=False)}"
                for key in new_keys
            )
            previous_args_json = self.streamed_args_for_tool[self.current_tool_id]
            if self._current_parameters:
                # Continue the already-open object; no closing brace yet.
                json_fragment = ", " + items
                self.streamed_args_for_tool[self.current_tool_id] = (
                    previous_args_json + json_fragment
                )
            else:
                # First parameter(s): open the object but don't close it yet.
                json_fragment = "{" + items
                self.streamed_args_for_tool[self.current_tool_id] = json_fragment

            calls.append(
                ToolCallItem(
                    tool_index=self.current_tool_id,
                    name=None,
                    parameters=json_fragment,
                )
            )

        # Update current state
        self._current_parameters = new_params
        self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params

        return calls

    def _reset_streaming_state(self, still_in_tool_call: bool = False):
        """Reset streaming state for the next tool call"""
        self._in_tool_call = still_in_tool_call
        self._function_name_sent = False
        self._current_function_name = ""
        self._current_parameters = {}
        self._streamed_parameters = {}
        self.current_tool_name_sent = False

    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
        """Split complete text into normal text and parsed tool calls.

        Each ``<minimax:tool_call>...</minimax:tool_call>`` block is parsed
        with ``_parse_block``. A start token without a matching end token is
        left in the normal text unparsed.
        """
        normal_parts: List[str] = []
        calls: List[ToolCallItem] = []
        cursor = 0
        while True:
            s = text.find(self.tool_call_start_token, cursor)
            if s == -1:
                normal_parts.append(text[cursor:])
                break
            normal_parts.append(text[cursor:s])
            e = text.find(self.tool_call_end_token, s)
            if e == -1:
                normal_parts.append(text[s:])
                break
            block = text[s : e + len(self.tool_call_end_token)]
            cursor = e + len(self.tool_call_end_token)
            calls.extend(self._parse_block(block, tools))
        return "".join(normal_parts), calls

    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
        """Parse every ``<invoke>`` in one tool-call block.

        Invokes rejected by ``parse_base_json`` (it raises) are logged and
        dropped.
        """
        res: List[ToolCallItem] = []
        for m in self.tool_call_function_regex.findall(block):
            txt = m[0] if m[0] else m[1]
            if '">' not in txt:
                continue
            idx = txt.index('">')
            fname = txt[:idx].strip()
            body = txt[idx + 2 :]
            params: Dict[str, Any] = {}
            for pm in self.tool_call_parameter_regex.findall(body):
                ptxt = pm[0] if pm[0] else pm[1]
                if '">' not in ptxt:
                    continue
                pidx = ptxt.index('">')
                pname = ptxt[:pidx].strip()
                pval = ptxt[pidx + 2 :].lstrip("\n").rstrip("\n")
                params[pname] = _safe_val(pval)
            raw = {"name": fname, "arguments": params}
            try:
                # TODO: fix idx in function call, the index for a function
                # call will always be -1 in parse_base_json
                res.extend(self.parse_base_json(raw, tools))
            except Exception:
                logger.warning("invalid tool call for %s dropped", fname)
        return res

    def supports_structural_tag(self) -> bool:
        return False

    def structure_info(self) -> _GetInfoFunc:
        raise NotImplementedError

    def build_ebnf(self, tools: List[Tool]):
        return EBNFComposer.build_ebnf(
            tools,
            individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
            individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
            tool_call_separator="\\n",
            function_format="xml",
            call_rule_fmt='"<invoke name=\\"{name}\\">\\n" {arguments_rule} "\\n</invoke>"',
            key_value_rule_fmt='"<parameter name=\\"{key}\\">\\n" {valrule} "\\n</parameter>"',
            key_value_separator='"\\n"',
        )
@@ -18,6 +18,9 @@ Options:
18
18
  ### Install Dependencies
19
19
  pip install "grpcio==1.75.1" "grpcio-tools==1.75.1"
20
20
 
21
+ Make sure the installed grpcio and grpcio-tools versions match those pinned in pyproject.toml;
22
+ if you use different versions, update pyproject.toml accordingly.
23
+
21
24
  ### Run Script
22
25
  cd python/sglang/srt/grpc
23
26
  python compile_proto.py
@@ -29,6 +29,7 @@ from sglang.srt.distributed import (
29
29
  get_tensor_model_parallel_world_size,
30
30
  )
31
31
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
32
+ from sglang.srt.server_args import get_global_server_args
32
33
  from sglang.srt.utils import (
33
34
  cpu_has_amx_support,
34
35
  is_cpu,
@@ -59,6 +60,11 @@ logger = logging.getLogger(__name__)
59
60
 
60
61
 
61
62
  class SiluAndMul(CustomOp):
63
def __init__(self, *args, **kwargs):
    """Build the op, optionally pinning it to the native PyTorch path.

    When the global server args have ``rl_on_policy_target == "fsdp"``,
    ``self._forward_method`` is overridden with ``self.forward_native``.
    NOTE(review): presumably this keeps the activation numerically aligned
    with an FSDP-trained reference -- confirm with the RL integration.
    """
    super().__init__(*args, **kwargs)
    on_policy_target = get_global_server_args().rl_on_policy_target
    if on_policy_target == "fsdp":
        self._forward_method = self.forward_native
67
+
62
68
  def forward_native(self, x: torch.Tensor) -> torch.Tensor:
63
69
  d = x.shape[-1] // 2
64
70
  return F.silu(x[..., :d]) * x[..., d:]
@@ -59,6 +59,19 @@ class AscendAttnBackend(AttentionBackend):
59
59
  )
60
60
  self.mask_len = max_seq_len
61
61
 
62
def get_verify_buffers_to_fill_after_draft(self):
    """
    Return the verify-attention buffers that need to be filled after drafting.

    The two slots are typically the tree-mask and position buffers. This
    backend reports None for both.
    """
    tree_mask_buffer = None
    positions_buffer = None
    return [tree_mask_buffer, positions_buffer]
69
+
70
def update_verify_buffers_to_fill_after_draft(
    self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
):
    """
    Refresh the post-draft verify buffers for ``spec_info``.

    No-op for this backend: it reports no such buffers (both slots are None),
    so there is nothing to update.
    """
74
+
62
75
  def __init__(self, model_runner: ModelRunner):
63
76
  super().__init__()
64
77
  self.forward_metadata = None
@@ -87,15 +100,22 @@ class AscendAttnBackend(AttentionBackend):
87
100
  device=model_runner.device,
88
101
  )
89
102
  )
103
+ self.speculative_num_draft_tokens = (
104
+ model_runner.server_args.speculative_num_draft_tokens
105
+ )
106
+ self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
107
+ self.mtp_mask = ~self.mtp_mask
90
108
 
91
109
  def init_forward_metadata(self, forward_batch: ForwardBatch):
92
110
  """Init the metadata for a forward pass."""
93
111
  tp_size = get_attention_tp_size()
94
112
  self.forward_metadata = ForwardMetadata()
95
-
113
+ seq_lens_max = forward_batch.seq_lens.max()
114
+ if forward_batch.forward_mode.is_target_verify():
115
+ seq_lens_max += self.speculative_num_draft_tokens
96
116
  self.forward_metadata.block_tables = (
97
117
  forward_batch.req_to_token_pool.req_to_token[
98
- forward_batch.req_pool_indices, : forward_batch.seq_lens.max()
118
+ forward_batch.req_pool_indices, :seq_lens_max
99
119
  ][:, :: self.page_size]
100
120
  // self.page_size
101
121
  )
@@ -104,16 +124,23 @@ class AscendAttnBackend(AttentionBackend):
104
124
  forward_batch.extend_seq_lens.cpu().int()
105
125
  )
106
126
  self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
127
+ if (
128
+ not forward_batch.forward_mode.is_draft_extend_v2()
129
+ and not forward_batch.forward_mode.is_draft_extend()
130
+ and not forward_batch.forward_mode.is_target_verify()
131
+ ):
132
+ seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
133
+ self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
107
134
 
108
- seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
109
- self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
135
+ if forward_batch.forward_mode.is_target_verify():
136
+ self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens
110
137
 
111
138
  self.graph_mode = False
112
139
 
113
140
  def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
114
141
  self.graph_metadata = {
115
142
  "block_tables": torch.empty(
116
- (max_bs, self.max_context_len // self.page_size),
143
+ (max_bs, (self.max_context_len + self.page_size - 1) // self.page_size),
117
144
  dtype=torch.int32,
118
145
  device=self.device,
119
146
  ),
@@ -156,6 +183,8 @@ class AscendAttnBackend(AttentionBackend):
156
183
  ):
157
184
  metadata = self.graph_metadata[bs]
158
185
  max_len = seq_lens_cpu[:bs].max().item()
186
+ if forward_mode.is_target_verify():
187
+ max_len += self.speculative_num_draft_tokens
159
188
  max_seq_pages = (max_len + self.page_size - 1) // self.page_size
160
189
 
161
190
  metadata.block_tables[:bs, :max_seq_pages].copy_(
@@ -257,6 +286,25 @@ class AscendAttnBackend(AttentionBackend):
257
286
  k_rope,
258
287
  topk_indices,
259
288
  )
289
+ if (
290
+ forward_batch.forward_mode.is_target_verify()
291
+ or forward_batch.forward_mode.is_draft_extend()
292
+ or forward_batch.forward_mode.is_draft_extend_v2()
293
+ ):
294
+
295
+ if is_mla_preprocess_enabled():
296
+ save_kv_cache = False
297
+ return self.forward_mtp(
298
+ q,
299
+ k,
300
+ v,
301
+ layer,
302
+ forward_batch,
303
+ save_kv_cache,
304
+ q_rope=q_rope,
305
+ k_rope=k_rope,
306
+ )
307
+
260
308
  if not self.use_mla:
261
309
  if save_kv_cache:
262
310
  forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -393,6 +441,118 @@ class AscendAttnBackend(AttentionBackend):
393
441
  )
394
442
  return attn_output
395
443
 
444
def forward_mtp(
    self,
    q,
    k,
    v,
    layer: RadixAttention,
    forward_batch: ForwardBatch,
    save_kv_cache: bool,
    q_rope: Optional[torch.Tensor] = None,
    k_rope: Optional[torch.Tensor] = None,
):
    """Fused paged attention for speculative (MTP) forward passes.

    Optionally writes the new K/V into the pool. It then runs
    ``torch_npu.npu_fused_infer_attention_score`` against the paged cache,
    using ``self.forward_metadata.block_tables`` and per-request query/KV
    lengths.

    Args:
        q: Query; viewed as (-1, tp_q_head_num, kv_lora_rank).
        k, v: New key/value to store when ``save_kv_cache`` is set.
        layer: Attention layer providing head counts, scaling and layer_id.
        forward_batch: Batch providing the KV pool, cache locations, the
            forward mode and (eager mode) the non-padded token count.
        save_kv_cache: Whether to write k/v (or k/k_rope for MLA) into the
            pool before attending.
        q_rope: Rotary part of the query; viewed as
            (-1, tp_q_head_num, qk_rope_head_dim).
        k_rope: Rotary part of the key, stored alongside k for MLA. It is
            rebound below to the pool's rope cache.

    Returns:
        Attention output of shape (-1, tp_q_head_num * v_head_dim), padded
        back with zero rows to the original token count in eager mode.
    """
    if save_kv_cache:
        if self.use_mla:
            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, k_rope
            )
        else:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )

    # Read the whole paged cache back as (num_pages, heads, page_size, dim).
    # NOTE(review): the pool is treated as MLA (c_kv + k_rope) even when
    # self.use_mla is False above -- confirm this path is MLA-only.
    c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
    k_rope_cache = k_rope.view(
        -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
    )
    c_kv_cache = c_kv.view(
        -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
    )

    q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank)
    q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
    if not self.graph_mode:
        # Eager mode: drop padded tokens before the kernel. The output is
        # zero-padded back to num_token_padding at the end.
        num_token_padding = q.shape[0]
        q_nope = q_nope[: forward_batch.num_token_non_padded_cpu]
        q_rope = q_rope[: forward_batch.num_token_non_padded_cpu]
    # Per-request KV lengths as a Python list for the kernel.
    if self.forward_metadata.seq_lens_cpu_int is None:
        actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_list
    else:
        actual_seq_lengths_kv = (
            self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
        )
    # Cumulative query end offsets per request (TND layout).
    if forward_batch.forward_mode.is_draft_extend():
        actual_seq_lengths = (
            np.array(forward_batch.extend_seq_lens_cpu).cumsum().tolist()
        )
    else:
        # Appears to assume every request contributes exactly
        # speculative_num_draft_tokens query tokens (offsets n, 2n, ...).
        actual_seq_lengths = np.arange(
            self.speculative_num_draft_tokens,
            self.speculative_num_draft_tokens + q_nope.shape[0],
            self.speculative_num_draft_tokens,
        )

    # The workspace is sized with the same arguments as the real call below.
    # self.mtp_mask is the negation of a 2048x2048 lower-triangular bool
    # matrix. NOTE(review): sparse_mode=3 is assumed to select the kernel's
    # causal masking mode -- confirm against the torch_npu docs.
    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
        q_nope,
        c_kv_cache,
        c_kv_cache,
        query_rope=q_rope,
        key_rope=k_rope_cache,
        num_heads=layer.tp_q_head_num,
        num_key_value_heads=layer.tp_k_head_num,
        input_layout="TND",
        scale=layer.scaling,
        antiquant_mode=0,
        antiquant_scale=None,
        block_table=self.forward_metadata.block_tables,
        block_size=self.page_size,
        sparse_mode=3,
        atten_mask=self.mtp_mask,
        actual_seq_lengths=actual_seq_lengths,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
    )
    attn_output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
    # One-element output buffer for the kernel's second output slot.
    softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
    torch_npu.npu_fused_infer_attention_score.out(
        q_nope,
        c_kv_cache,
        c_kv_cache,
        query_rope=q_rope,
        key_rope=k_rope_cache,
        num_heads=layer.tp_q_head_num,
        num_key_value_heads=layer.tp_k_head_num,
        input_layout="TND",
        scale=layer.scaling,
        antiquant_mode=0,
        antiquant_scale=None,
        block_table=self.forward_metadata.block_tables,
        block_size=self.page_size,
        sparse_mode=3,
        atten_mask=self.mtp_mask,
        actual_seq_lengths=actual_seq_lengths,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        workspace=workspace,
        out=[attn_output, softmax_lse],
    )
    attn_output = attn_output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
    if (
        not self.graph_mode
        and forward_batch.num_token_non_padded_cpu != num_token_padding
    ):
        # Restore the padded token count with zero rows.
        attn_output = torch.cat(
            [
                attn_output,
                attn_output.new_zeros(
                    num_token_padding - attn_output.shape[0], *attn_output.shape[1:]
                ),
            ],
            dim=0,
        )
    return attn_output
555
+
396
556
  def forward_decode_graph(
397
557
  self,
398
558
  q: torch.Tensor,
@@ -690,3 +850,71 @@ class AscendAttnBackend(AttentionBackend):
690
850
  out=attn_output,
691
851
  )
692
852
  return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
853
+
854
+
855
class AscendAttnMultiStepDraftBackend:
    """
    Wrap multiple Ascend attention backends as one for multiple consecutive
    draft decoding steps.

    One ``AscendAttnBackend`` is created per speculative step.
    ``init_cuda_graph_state`` initializes all of them. The per-batch metadata
    hooks go through ``common_template``, which only visits the first
    ``speculative_num_steps - 1`` backends.
    """

    def __init__(
        self,
        model_runner: ModelRunner,
        topk: int,
        speculative_num_steps: int,
    ):
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps

        self.attn_backends = []
        for _ in range(self.speculative_num_steps):
            self.attn_backends.append(AscendAttnBackend(model_runner))

    def common_template(
        self,
        forward_batch: ForwardBatch,
        call_fn: "Callable[[int, ForwardBatch], None]",
    ):
        """Call ``call_fn(i, forward_batch)`` for every draft step but the last."""
        assert forward_batch.spec_info is not None

        for i in range(self.speculative_num_steps - 1):
            call_fn(i, forward_batch)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize eager-mode metadata on each per-step backend."""

        def call_fn(i, forward_batch):
            assert forward_batch.spec_info is not None
            self.attn_backends[i].init_forward_metadata(forward_batch)

        self.common_template(forward_batch, call_fn)

    def init_cuda_graph_state(self, max_bs, max_num_tokens):
        """Allocate graph buffers on every per-step backend (including the last)."""
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        """Prepare graph-capture metadata as DECODE with bs * topk tokens."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
                forward_batch.batch_size,
                forward_batch.batch_size * self.topk,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
            )

        self.common_template(forward_batch, call_fn)

    def init_forward_metadata_replay_cuda_graph(
        self, forward_batch: ForwardBatch, bs: int
    ):
        """Refresh graph metadata for replay with batch size ``bs``."""

        def call_fn(i, forward_batch):
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                seq_lens_sum=-1,
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
                # AscendAttnBackend's replay computes seq_lens_cpu[:bs].max()
                # to size the block table, so it needs the real CPU lengths.
                # Passing None would raise a TypeError.
                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )

        self.common_template(forward_batch, call_fn)