sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,61 @@
+import logging
+from typing import Union
+
+from fastapi import Request
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ErrorResponse,
+    ScoringRequest,
+    ScoringResponse,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingScore(OpenAIServingBase):
+    """Handler for /v1/score requests"""
+
+    # NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
+    # to another module in the future.
+
+    def _request_id_prefix(self) -> str:
+        return "score-"
+
+    def _convert_to_internal_request(
+        self,
+        request: ScoringRequest,
+    ) -> tuple[ScoringRequest, ScoringRequest]:
+        """Convert OpenAI scoring request to internal format"""
+        # For scoring, we pass the request directly as the tokenizer_manager
+        # has a specialized score_request method that doesn't use GenerateReqInput
+
+        return request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: ScoringRequest,
+        request: ScoringRequest,
+        raw_request: Request,
+    ) -> Union[ScoringResponse, ErrorResponse]:
+        """Handle the scoring request"""
+        try:
+            # Use tokenizer_manager's score_request method directly
+            scores = await self.tokenizer_manager.score_request(
+                query=request.query,
+                items=request.items,
+                label_token_ids=request.label_token_ids,
+                apply_softmax=request.apply_softmax,
+                item_first=request.item_first,
+                request=raw_request,
+            )
+
+            # Create response with just the scores, without usage info
+            response = ScoringResponse(
+                scores=scores,
+                model=request.model,
+            )
+            return response
+
+        except ValueError as e:
+            return self.create_error_response(str(e))
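The new module above (sglang/srt/entrypoints/openai/serving_score.py per the file list) backs the /v1/score route and simply forwards the request fields to tokenizer_manager.score_request. A rough client-side sketch of what such a request might look like; the payload keys mirror the ScoringRequest attributes read above, while the model name, token ids, port, and default values are illustrative assumptions, not part of the diff:

import requests

payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",   # whatever model the server was launched with
    "query": "Is the following review positive?",
    "items": ["Great product!", "Terrible experience."],
    "label_token_ids": [9642, 2822],               # hypothetical token ids for "Yes" / "No"
    "apply_softmax": True,
    "item_first": False,
}

resp = requests.post("http://localhost:30000/v1/score", json=payload)
print(resp.json())  # expected to carry "scores" and "model" per ScoringResponse above, with no usage block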
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional, final
+
+from sglang.srt.entrypoints.openai.protocol import UsageInfo
+
+
+@final
+class UsageProcessor:
+    """Stateless helpers that turn raw token counts into a UsageInfo."""
+
+    @staticmethod
+    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
+        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
+        return {"cached_tokens": count} if count > 0 else None
+
+    @staticmethod
+    def calculate_response_usage(
+        responses: List[Dict[str, Any]],
+        n_choices: int = 1,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)
+
+        prompt_tokens = sum(
+            responses[i]["meta_info"]["prompt_tokens"]
+            for i in range(0, len(responses), n_choices)
+        )
+
+        cached_details = None
+        if enable_cache_report:
+            cached_total = sum(
+                r["meta_info"].get("cached_tokens", 0) for r in responses
+            )
+            cached_details = UsageProcessor._details_if_cached(cached_total)
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_streaming_usage(
+        prompt_tokens: Mapping[int, int],
+        completion_tokens: Mapping[int, int],
+        cached_tokens: Mapping[int, int],
+        n_choices: int,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        # index % n_choices == 0 marks the first choice of a prompt
+        total_prompt_tokens = sum(
+            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
+        )
+        total_completion_tokens = sum(completion_tokens.values())
+
+        cached_details = (
+            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
+            if enable_cache_report
+            else None
+        )
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=total_prompt_tokens,
+            completion_tokens=total_completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_token_usage(
+        prompt_tokens: int,
+        completion_tokens: int,
+        cached_tokens: Optional[Dict[str, int]] = None,
+    ) -> UsageInfo:
+        """Calculate token usage information"""
+        return UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=cached_tokens,
+        )
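A short sketch of how these helpers compose for a non-streaming response with two choices per prompt; the token counts are made up, and only the dictionary shape (meta_info with prompt_tokens / completion_tokens / cached_tokens) follows the code above:

from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

responses = [
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 40, "cached_tokens": 8}},
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 35, "cached_tokens": 8}},
]

usage = UsageProcessor.calculate_response_usage(
    responses, n_choices=2, enable_cache_report=True
)
# The prompt is counted once per request (12), completions are summed across both
# choices (75), and prompt_tokens_details becomes {"cached_tokens": 16}.
print(usage)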
@@ -0,0 +1,72 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def to_openai_style_logprobs(
+    input_token_logprobs=None,
+    output_token_logprobs=None,
+    input_top_logprobs=None,
+    output_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if input_token_logprobs is not None:
+        append_token_logprobs(input_token_logprobs)
+    if output_token_logprobs is not None:
+        append_token_logprobs(output_token_logprobs)
+    if input_top_logprobs is not None:
+        append_top_logprobs(input_top_logprobs)
+    if output_top_logprobs is not None:
+        append_top_logprobs(output_top_logprobs)
+
+    return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
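The logprob helper above expects (logprob, token_id, token_text) triples, as the tuple unpacking shows. A minimal sketch with fabricated token ids and probabilities (the module path follows the file list above):

from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs

output_token_logprobs = [(-0.11, 9906, "Hello"), (-1.73, 11, ",")]
output_top_logprobs = [
    [(-0.11, 9906, "Hello"), (-2.35, 13347, "Hi")],  # alternatives for the first token
    None,                                            # no alternatives reported for the second
]

lp = to_openai_style_logprobs(
    output_token_logprobs=output_token_logprobs,
    output_top_logprobs=output_top_logprobs,
)
# lp.tokens         -> ["Hello", ","]
# lp.token_logprobs -> [-0.11, -1.73]
# lp.top_logprobs   -> [{"Hello": -0.11, "Hi": -2.35}, None]
# lp.text_offset    -> [-1, -1]   (offsets not supported yet)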
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
 from partial_json_parser.core.exceptions import MalformedJSON
 from partial_json_parser.core.options import Allow
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
     ToolCallItem,
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
     _is_complete_json,
     _partial_json_loads,
 )
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -111,11 +111,10 @@ class BaseFormatDetector(ABC):
        # The current_text has tool_call if it is the start of a new tool call sequence
        # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
        if not (
-            self.bot_token in current_text
-            or current_text.startswith("{")
+            self.has_tool_call(current_text)
            or (
                self.current_tool_id > 0
-                and current_text.startswith(self.tool_call_separator + "{")
+                and current_text.startswith(self.tool_call_separator)
            )
        ):
            # Only clear buffer if we're sure no tool call is starting
@@ -143,6 +142,10 @@
        try:
            if current_text.startswith(self.bot_token):
                start_idx = len(self.bot_token)
+            elif self.current_tool_id > 0 and current_text.startswith(
+                self.tool_call_separator + self.bot_token
+            ):
+                start_idx = len(self.tool_call_separator + self.bot_token)
            elif self.current_tool_id > 0 and current_text.startswith(
                self.tool_call_separator
            ):
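In streaming mode the detector now also recognizes a follow-up tool call when the separator is followed by the detector's bot_token rather than a bare JSON object. Purely illustrative chunks; the "<TOOL>" bot_token and ";" separator are hypothetical, since each concrete detector defines its own markers:

first_call  = '<TOOL>{"name": "get_weather", "arguments": {"location": "Paris"}}'
second_call = ';<TOOL>{"name": "get_time", "arguments": {}}'  # matched by the new elif branch
bare_second = ';{"name": "get_time", "arguments": {}}'        # still matched by the existing separator branch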
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
 from sglang.srt.function_call.utils import _is_complete_json
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -211,20 +211,74 @@ class EBNFComposer:
             properties = params.get("properties", {})
             required_props = set(params.get("required", []))
 
-            # Build argument rules for this tool
-            arg_rules = []
+            # The generated pattern ensures:
+            # 1. Required properties appear first, joined by commas
+            # 2. Optional properties are wrapped with comma included: ( "," ( "prop" : value )? )?
+            # 3. For multiple optional properties, we allow flexible ordering:
+            #    - Each optional can be skipped entirely
+            #    - They can appear in any combination
+            #
+            # Example patterns generated:
+            # - One required, one optional:
+            #   "{" "location" ":" string ( "," ( "unit" ":" enum ) )? "}"
+            #   Allows: {"location": "Paris"} or {"location": "Paris", "unit": "celsius"}
+            #
+            # - Multiple optional properties with flexible ordering:
+            #   "{" "req" ":" string ( "," ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value ) )? "}"
+            #   Allows: {"req": "x"}, {"req": "x", "opt1": "y"}, {"req": "x", "opt2": "z"},
+            #           {"req": "x", "opt1": "y", "opt2": "z"}
+            #
+            # - All optional properties with flexible ordering:
+            #   "{" ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value )? "}"
+            #   Allows: {}, {"opt1": "x"}, {"opt2": "y"}, {"opt1": "x", "opt2": "y"}
+
+            prop_kv_pairs = {}
+            ordered_props = list(properties.keys())
+
             for prop_name, prop_schema in properties.items():
                 value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
                 # Create key=value pair
                 pair = key_value_template.format(key=prop_name, valrule=value_rule)
-
-                if prop_name not in required_props:
-                    pair = f"[ {pair} ]"
-
-                arg_rules.append(pair)
-
-            # Combine all argument rules
-            combined_args = ' "," '.join(arg_rules) if arg_rules else ""
+                prop_kv_pairs[prop_name] = pair
+
+            # Separate into required and optional while preserving order
+            required = [p for p in ordered_props if p in required_props]
+            optional = [p for p in ordered_props if p not in required_props]
+
+            # Build the combined rule
+            rule_parts = []
+
+            # Add required properties joined by commas
+            if required:
+                rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required))
+
+            # Add optional properties with flexible ordering
+            if optional:
+                # Build alternatives where any optional property can appear first
+                opt_alternatives = []
+                for i in range(len(optional)):
+                    # Build pattern for optional[i] appearing first
+                    opt_parts = []
+                    for j in range(i, len(optional)):
+                        if j == i:
+                            opt_parts.append(prop_kv_pairs[optional[j]])
+                        else:
+                            opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
+                    opt_alternatives.append("".join(opt_parts))
+
+                # Wrap with appropriate comma handling based on whether we have required properties
+                if required:
+                    # Required properties exist, so optional group needs outer comma
+                    rule_parts.append(' ( "," ( ')
+                    rule_parts.append(" | ".join(opt_alternatives))
+                    rule_parts.append(" ) )?")
+                else:
+                    # All properties are optional
+                    rule_parts.append("( ")
+                    rule_parts.append(" | ".join(opt_alternatives))
+                    rule_parts.append(" )?")
+
+            combined_args = "".join(rule_parts)
             arguments_rule = args_template.format(arg_rules=combined_args)
 
             # Add the function call rule and its arguments rule
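To make the flexible-ordering branch concrete, here is a condensed standalone sketch of the same alternative-building loop; the function name and the toy key/value fragments are illustrative and not part of the diff:

def build_optional_alternatives(pairs):
    # pairs: EBNF "key : value" fragments for the optional properties, in schema order.
    # Each alternative lets one optional property lead; every later optional may then
    # follow, each wrapped in its own optional comma group.
    alternatives = []
    for i in range(len(pairs)):
        parts = [pairs[i]]
        for j in range(i + 1, len(pairs)):
            parts.append(f' ( "," {pairs[j]} )?')
        alternatives.append("".join(parts))
    return " | ".join(alternatives)

pairs = ['"opt1" ":" value', '"opt2" ":" value', '"opt3" ":" value']
print(build_optional_alternatives(pairs))
# Output (wrapped here for readability):
# "opt1" ":" value ( "," "opt2" ":" value )? ( "," "opt3" ":" value )?
#  | "opt2" ":" value ( "," "opt3" ":" value )? | "opt3" ":" value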
@@ -1,6 +1,12 @@
 import logging
 from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
 
+from sglang.srt.entrypoints.openai.protocol import (
+    StructuralTagResponseFormat,
+    StructuresResponseFormat,
+    Tool,
+    ToolChoice,
+)
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
-from sglang.srt.openai_api.protocol import (
-    StructuralTagResponseFormat,
-    StructuresResponseFormat,
-    Tool,
-    ToolChoice,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -2,6 +2,7 @@ import json
 import logging
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -4,6 +4,7 @@ import logging
 import re
 from typing import List, Optional
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -1,11 +1,12 @@
-"""
-Utility functions for OpenAI API adapter.
+"""Template utilities for Jinja template processing.
+
+This module provides utilities for analyzing and processing Jinja chat templates,
+including content format detection and message processing.
 """
 
 import logging
-from typing import Dict, List
 
-import jinja2.nodes
+import jinja2
 import transformers.utils.chat_template_utils as hf_chat_utils
 
 logger = logging.getLogger(__name__)
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
     return None
 
 
-def detect_template_content_format(chat_template: str) -> str:
+def detect_jinja_template_content_format(chat_template: str) -> str:
     """
     Detect whether a chat template expects 'string' or 'openai' content format.
 
@@ -29,10 +29,19 @@ from sglang.srt.distributed import (
29
29
  get_tensor_model_parallel_world_size,
30
30
  )
31
31
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
32
- from sglang.srt.utils import is_cuda, set_weight_attrs
32
+ from sglang.srt.utils import (
33
+ cpu_has_amx_support,
34
+ is_cpu,
35
+ is_cuda,
36
+ is_npu,
37
+ set_weight_attrs,
38
+ )
33
39
  from sglang.utils import resolve_obj_by_qualname
34
40
 
35
41
  _is_cuda = is_cuda()
42
+ _is_npu = is_npu()
43
+ _is_cpu_amx_available = cpu_has_amx_support()
44
+ _is_cpu = is_cpu()
36
45
 
37
46
  if _is_cuda:
38
47
  from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -52,6 +61,15 @@ class SiluAndMul(CustomOp):
52
61
  silu_and_mul(x, out)
53
62
  return out
54
63
 
64
+ def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
65
+ if _is_cpu_amx_available:
66
+ d = x.shape[-1] // 2
67
+ output_shape = x.shape[:-1] + (d,)
68
+ out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
69
+ return out
70
+ else:
71
+ return self.forward_native(x)
72
+
55
73
 
56
74
  class GeluAndMul(CustomOp):
57
75
  def __init__(self, approximate="tanh"):
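For reference, the native fallback that forward_cpu takes when the AMX kernel is unavailable computes the standard SwiGLU-style gating that the fused CPU kernel accelerates: SiLU over the first half of the last dimension, multiplied by the second half. A minimal standalone sketch, not taken from the diff:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x has shape [..., 2 * d]: the first d channels are the gate, the last d the up projection.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]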
@@ -184,8 +202,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return nn.Identity()
 
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Init the global shared states for cuda graph."""
         raise NotImplementedError()
 
@@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
        if block_kv_indices is None:
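The same signature change recurs across the attention backends in this release: init_cuda_graph_state now receives max_num_tokens in addition to max_bs, so buffers that scale with the number of captured tokens (such as the custom mask in the aiter backend above) are no longer sized from the batch size alone. A sketch of what an out-of-tree backend override would now look like; the buffer names and shapes are illustrative only:

import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

class MyAttnBackend(AttentionBackend):
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        # Per-sequence metadata still scales with the captured batch size ...
        self.cuda_graph_seq_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        # ... while token-level scratch buffers are now sized by max_num_tokens.
        self.cuda_graph_scratch = torch.zeros(max_num_tokens, dtype=torch.int32, device="cuda")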
@@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.
 
         Args:
@@ -1704,14 +1704,15 @@
 
            # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
            metadata_expand = self.target_verify_metadata_topk_expand[bs]
+
            # metadata_expand.max_seq_len_q = 1, already set in capture
            # metadata_expand.cu_seqlens_q already set in capture
-
            offsets = torch.arange(
                self.speculative_num_draft_tokens, device=device
            ).unsqueeze(
                0
            )  # shape: (1, self.speculative_num_draft_tokens)
+
            cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
            cum_len = torch.nn.functional.pad(
                torch.cumsum(
@@ -1728,17 +1729,20 @@
            ).view(1, -1)
            # avoid extracting padded seq indices which will be out of boundary
            mask_extraction_indices[
-                :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+                :,
+                spec_info.positions.numel() * self.speculative_num_draft_tokens :,
            ].fill_(0)
-
            mask = spec_info.custom_mask[mask_extraction_indices].view(
                -1, self.speculative_num_draft_tokens
            )  # (bsz * draft_num, draft_num)
+
            col_indices = offsets.expand(
                mask.shape[0], self.speculative_num_draft_tokens
            )
            keys = torch.where(
-                mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                mask,
+                col_indices,
+                col_indices + self.speculative_num_draft_tokens,
            )
            _, sort_order = torch.sort(keys, dim=1)
 
@@ -1747,6 +1751,7 @@
                .gather(1, cols)
                .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
            )  # (bsz, draft_num)
+
            metadata_expand.page_table.copy_(
                non_masked_page_table.gather(1, sort_order)
            )
@@ -1758,6 +1763,7 @@
                    dtype=torch.int32,
                )
            )
+
        elif forward_mode.is_draft_extend():
            metadata = self.draft_extend_metadata[bs]
            metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -1767,7 +1773,11 @@
                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
            )
            accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
            metadata.cu_seqlens_q[1:].copy_(
                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
            )
@@ -1807,7 +1817,7 @@
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
-        return 0
+        return 1
 
     def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
@@ -1999,9 +2009,9 @@ class FlashAttentionMultiStepBackend:
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata(forward_batch)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,