sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_score.py
@@ -0,0 +1,61 @@
+import logging
+from typing import Union
+
+from fastapi import Request
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ErrorResponse,
+    ScoringRequest,
+    ScoringResponse,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingScore(OpenAIServingBase):
+    """Handler for /v1/score requests"""
+
+    # NOTE: /v1/rerank is not an official OpenAI endpoint. This module may be moved
+    # to another module in the future.
+
+    def _request_id_prefix(self) -> str:
+        return "score-"
+
+    def _convert_to_internal_request(
+        self,
+        request: ScoringRequest,
+    ) -> tuple[ScoringRequest, ScoringRequest]:
+        """Convert OpenAI scoring request to internal format"""
+        # For scoring, we pass the request directly as the tokenizer_manager
+        # has a specialized score_request method that doesn't use GenerateReqInput
+
+        return request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: ScoringRequest,
+        request: ScoringRequest,
+        raw_request: Request,
+    ) -> Union[ScoringResponse, ErrorResponse]:
+        """Handle the scoring request"""
+        try:
+            # Use tokenizer_manager's score_request method directly
+            scores = await self.tokenizer_manager.score_request(
+                query=request.query,
+                items=request.items,
+                label_token_ids=request.label_token_ids,
+                apply_softmax=request.apply_softmax,
+                item_first=request.item_first,
+                request=raw_request,
+            )
+
+            # Create response with just the scores, without usage info
+            response = ScoringResponse(
+                scores=scores,
+                model=request.model,
+            )
+            return response
+
+        except ValueError as e:
+            return self.create_error_response(str(e))
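For orientation, a hypothetical client-side sketch (not part of the diff) of a request this handler would serve. The endpoint path comes from the class docstring, the field names mirror the ScoringRequest attributes used above, and the model name, token ids, and URL are invented:

import requests  # assumption: any HTTP client would do

payload = {
    "model": "my-model",                    # echoed back in ScoringResponse.model
    "query": "Is this review positive?",    # ScoringRequest.query
    "items": ["Great product!", "Awful."],  # ScoringRequest.items
    "label_token_ids": [9454, 2822],        # hypothetical ids of "Yes"/"No" tokens
    "apply_softmax": True,                  # normalize scores over the label tokens
    "item_first": False,                    # whether items precede the query
}
resp = requests.post("http://localhost:30000/v1/score", json=payload)
print(resp.json()["scores"])  # one score entry per item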
sglang/srt/entrypoints/openai/usage_processor.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional, final
+
+from sglang.srt.entrypoints.openai.protocol import UsageInfo
+
+
+@final
+class UsageProcessor:
+    """Stateless helpers that turn raw token counts into a UsageInfo."""
+
+    @staticmethod
+    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
+        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
+        return {"cached_tokens": count} if count > 0 else None
+
+    @staticmethod
+    def calculate_response_usage(
+        responses: List[Dict[str, Any]],
+        n_choices: int = 1,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)
+
+        prompt_tokens = sum(
+            responses[i]["meta_info"]["prompt_tokens"]
+            for i in range(0, len(responses), n_choices)
+        )
+
+        cached_details = None
+        if enable_cache_report:
+            cached_total = sum(
+                r["meta_info"].get("cached_tokens", 0) for r in responses
+            )
+            cached_details = UsageProcessor._details_if_cached(cached_total)
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_streaming_usage(
+        prompt_tokens: Mapping[int, int],
+        completion_tokens: Mapping[int, int],
+        cached_tokens: Mapping[int, int],
+        n_choices: int,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        # index % n_choices == 0 marks the first choice of a prompt
+        total_prompt_tokens = sum(
+            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
+        )
+        total_completion_tokens = sum(completion_tokens.values())
+
+        cached_details = (
+            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
+            if enable_cache_report
+            else None
+        )
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=total_prompt_tokens,
+            completion_tokens=total_completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_token_usage(
+        prompt_tokens: int,
+        completion_tokens: int,
+        cached_tokens: Optional[Dict[str, int]] = None,
+    ) -> UsageInfo:
+        """Calculate token usage information"""
+        return UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=cached_tokens,
+        )
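A minimal usage sketch of the helpers above (token counts are illustrative): two sampled choices (n_choices=2) for a single prompt, so prompt tokens are counted once while completion tokens are summed across choices.

responses = [
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 40, "cached_tokens": 8}},
    {"meta_info": {"prompt_tokens": 12, "completion_tokens": 35, "cached_tokens": 8}},
]
usage = UsageProcessor.calculate_response_usage(
    responses, n_choices=2, enable_cache_report=True
)
# usage.prompt_tokens == 12 (counted once per prompt, not per choice)
# usage.completion_tokens == 75, usage.total_tokens == 87
# usage.prompt_tokens_details == {"cached_tokens": 16}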
sglang/srt/entrypoints/openai/utils.py
@@ -0,0 +1,72 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def to_openai_style_logprobs(
+    input_token_logprobs=None,
+    output_token_logprobs=None,
+    input_top_logprobs=None,
+    output_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if input_token_logprobs is not None:
+        append_token_logprobs(input_token_logprobs)
+    if output_token_logprobs is not None:
+        append_token_logprobs(output_token_logprobs)
+    if input_top_logprobs is not None:
+        append_top_logprobs(input_top_logprobs)
+    if output_top_logprobs is not None:
+        append_top_logprobs(output_top_logprobs)
+
+    return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
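An illustrative call to to_openai_style_logprobs (token ids and texts invented): each input entry is a (logprob, token_id, token_text) tuple, matching the unpacking in append_token_logprobs, and each top-logprobs step is a list of such tuples or None.

lp = to_openai_style_logprobs(
    output_token_logprobs=[(-0.1, 42, "Hello"), (-2.3, 7, "!")],
    output_top_logprobs=[
        [(-0.1, 42, "Hello"), (-1.9, 13, "Hi")],  # alternatives for step 1
        None,                                     # nothing recorded for step 2
    ],
)
# lp.tokens == ["Hello", "!"], lp.token_logprobs == [-0.1, -2.3]
# lp.top_logprobs == [{"Hello": -0.1, "Hi": -1.9}, None]
# lp.text_offset == [-1, -1]  (offsets not supported yet)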
sglang/srt/function_call/base_format_detector.py
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
 from partial_json_parser.core.exceptions import MalformedJSON
 from partial_json_parser.core.options import Allow
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
     ToolCallItem,
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
     _is_complete_json,
     _partial_json_loads,
 )
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -111,11 +111,10 @@ class BaseFormatDetector(ABC):
         # The current_text has tool_call if it is the start of a new tool call sequence
         # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
         if not (
-            self.bot_token in current_text
-            or current_text.startswith("{")
+            self.has_tool_call(current_text)
             or (
                 self.current_tool_id > 0
-                and current_text.startswith(self.tool_call_separator + "{")
+                and current_text.startswith(self.tool_call_separator)
             )
         ):
             # Only clear buffer if we're sure no tool call is starting
@@ -143,6 +142,10 @@ class BaseFormatDetector(ABC):
         try:
             if current_text.startswith(self.bot_token):
                 start_idx = len(self.bot_token)
+            elif self.current_tool_id > 0 and current_text.startswith(
+                self.tool_call_separator + self.bot_token
+            ):
+                start_idx = len(self.tool_call_separator + self.bot_token)
             elif self.current_tool_id > 0 and current_text.startswith(
                 self.tool_call_separator
             ):
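A hypothetical illustration of the new elif branch (the bot_token and separator values are examples, not taken from any specific detector): with bot_token = "<tool_call>" and tool_call_separator = ";", a second tool call can now arrive with the separator and the start token fused into one chunk.

chunks = ['<tool_call>{"name": "f", "arguments": {}}', ';<tool_call>{"name": "g"']
# First chunk: startswith(bot_token) -> start_idx = len(bot_token)
# Second chunk: current_tool_id > 0 and startswith(separator + bot_token)
#   -> start_idx = len(separator + bot_token), so JSON parsing starts at '{'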
sglang/srt/function_call/deepseekv3_detector.py
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
 from sglang.srt.function_call.utils import _is_complete_json
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/ebnf_composer.py
@@ -211,20 +211,74 @@ class EBNFComposer:
         properties = params.get("properties", {})
         required_props = set(params.get("required", []))
 
-        # Build argument rules for this tool
-        arg_rules = []
+        # The generated pattern ensures:
+        # 1. Required properties appear first, joined by commas
+        # 2. Optional properties are wrapped with comma included: ( "," ( "prop" : value )? )?
+        # 3. For multiple optional properties, we allow flexible ordering:
+        #    - Each optional can be skipped entirely
+        #    - They can appear in any combination
+        #
+        # Example patterns generated:
+        # - One required, one optional:
+        #   "{" "location" ":" string ( "," ( "unit" ":" enum ) )? "}"
+        #   Allows: {"location": "Paris"} or {"location": "Paris", "unit": "celsius"}
+        #
+        # - Multiple optional properties with flexible ordering:
+        #   "{" "req" ":" string ( "," ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value ) )? "}"
+        #   Allows: {"req": "x"}, {"req": "x", "opt1": "y"}, {"req": "x", "opt2": "z"},
+        #           {"req": "x", "opt1": "y", "opt2": "z"}
+        #
+        # - All optional properties with flexible ordering:
+        #   "{" ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value )? "}"
+        #   Allows: {}, {"opt1": "x"}, {"opt2": "y"}, {"opt1": "x", "opt2": "y"}
+
+        prop_kv_pairs = {}
+        ordered_props = list(properties.keys())
+
         for prop_name, prop_schema in properties.items():
             value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
             # Create key=value pair
             pair = key_value_template.format(key=prop_name, valrule=value_rule)
-
-            if prop_name not in required_props:
-                pair = f"[ {pair} ]"
-
-            arg_rules.append(pair)
-
-        # Combine all argument rules
-        combined_args = ' "," '.join(arg_rules) if arg_rules else ""
+            prop_kv_pairs[prop_name] = pair
+
+        # Separate into required and optional while preserving order
+        required = [p for p in ordered_props if p in required_props]
+        optional = [p for p in ordered_props if p not in required_props]
+
+        # Build the combined rule
+        rule_parts = []
+
+        # Add required properties joined by commas
+        if required:
+            rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required))
+
+        # Add optional properties with flexible ordering
+        if optional:
+            # Build alternatives where any optional property can appear first
+            opt_alternatives = []
+            for i in range(len(optional)):
+                # Build pattern for optional[i] appearing first
+                opt_parts = []
+                for j in range(i, len(optional)):
+                    if j == i:
+                        opt_parts.append(prop_kv_pairs[optional[j]])
+                    else:
+                        opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
+                opt_alternatives.append("".join(opt_parts))
+
+            # Wrap with appropriate comma handling based on whether we have required properties
+            if required:
+                # Required properties exist, so optional group needs outer comma
+                rule_parts.append(' ( "," ( ')
+                rule_parts.append(" | ".join(opt_alternatives))
+                rule_parts.append(" ) )?")
+            else:
+                # All properties are optional
+                rule_parts.append("( ")
+                rule_parts.append(" | ".join(opt_alternatives))
+                rule_parts.append(" )?")
+
+        combined_args = "".join(rule_parts)
         arguments_rule = args_template.format(arg_rules=combined_args)
 
         # Add the function call rule and its arguments rule
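The optional-ordering loop above can be read in isolation. A standalone sketch reproducing the alternatives it emits; here the key/value pairs are pre-rendered strings, whereas the composer builds them from key_value_template and get_value_rule:

def optional_alternatives(pairs):
    # pairs: pre-rendered '"name" ":" rule' strings for the optional properties
    alts = []
    for i in range(len(pairs)):
        parts = [pairs[i]]  # optional[i] appears first
        for j in range(i + 1, len(pairs)):
            parts.append(f' ( "," {pairs[j]} )?')  # later optionals may follow
        alts.append("".join(parts))
    return " | ".join(alts)

print(optional_alternatives(['"opt1" ":" value', '"opt2" ":" value']))
# -> "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value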
sglang/srt/function_call/function_call_parser.py
@@ -1,6 +1,12 @@
 import logging
 from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
 
+from sglang.srt.entrypoints.openai.protocol import (
+    StructuralTagResponseFormat,
+    StructuresResponseFormat,
+    Tool,
+    ToolChoice,
+)
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
-from sglang.srt.openai_api.protocol import (
-    StructuralTagResponseFormat,
-    StructuresResponseFormat,
-    Tool,
-    ToolChoice,
-)
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/llama32_detector.py
@@ -2,6 +2,7 @@ import json
 import logging
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/mistral_detector.py
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/pythonic_detector.py
@@ -4,6 +4,7 @@ import logging
 import re
 from typing import List, Optional
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/function_call/qwen25_detector.py
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
sglang/srt/{openai_api/utils.py → jinja_template_utils.py}
@@ -1,11 +1,12 @@
-"""
-Utility functions for OpenAI API adapter.
+"""Template utilities for Jinja template processing.
+
+This module provides utilities for analyzing and processing Jinja chat templates,
+including content format detection and message processing.
 """
 
 import logging
-from typing import Dict, List
 
-import jinja2.nodes
+import jinja2
 import transformers.utils.chat_template_utils as hf_chat_utils
 
 logger = logging.getLogger(__name__)
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
     return None
 
 
-def detect_template_content_format(chat_template: str) -> str:
+def detect_jinja_template_content_format(chat_template: str) -> str:
     """
     Detect whether a chat template expects 'string' or 'openai' content format.
 
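A hypothetical call to the renamed helper; the two possible return values are taken from its docstring, and the toy template is invented:

from sglang.srt.jinja_template_utils import detect_jinja_template_content_format

# A template that renders message content directly expects plain strings;
# one that iterates content parts expects OpenAI-style lists.
template = '{% for m in messages %}{{ m["content"] }}{% endfor %}'
fmt = detect_jinja_template_content_format(template)
assert fmt in ("string", "openai")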
sglang/srt/layers/activation.py
@@ -29,16 +29,28 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_npu,
+    set_weight_attrs,
+)
 from sglang.utils import resolve_obj_by_qualname
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
 logger = logging.getLogger(__name__)
 
+if is_npu():
+    import torch_npu
+
 
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -52,6 +64,19 @@ class SiluAndMul(CustomOp):
         silu_and_mul(x, out)
         return out
 
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available:
+            d = x.shape[-1] // 2
+            output_shape = x.shape[:-1] + (d,)
+            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
+            return out
+        else:
+            return self.forward_native(x)
+
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch_npu.npu_swiglu(x)
+        return out
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
@@ -184,8 +209,8 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()
 
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
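For reference, a sketch of the semantics the fused CPU/NPU paths above must match (mirroring the d = x.shape[-1] // 2 computation in forward_cpu): split the last dimension in half and gate the first half by SiLU before multiplying with the second. This is an assumed restatement of forward_native, not code from the diff.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(2, 8)
assert silu_and_mul_reference(x).shape == (2, 4)  # output keeps half the width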
sglang/srt/layers/attention/aiter_backend.py
@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
@@ -338,7 +341,7 @@
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
            )
sglang/srt/layers/attention/base_attn_backend.py
@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Init the global shared states for cuda graph."""
         raise NotImplementedError()
 
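A hypothetical subclass sketch of what the widened base-class signature implies: per-token buffers (such as custom attention masks) should now be sized by max_num_tokens rather than max_bs, as the AiterAttnBackend change above does. The class and buffer names here are invented.

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

class MyAttnBackend(AttentionBackend):  # hypothetical backend
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        # per-request buffer: still sized by batch size
        self.kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
        # per-token buffer: sized by token count, which can exceed max_bs
        self.custom_mask = torch.zeros(
            max_num_tokens * 4096, dtype=torch.uint8  # 4096: assumed context len
        )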
sglang/srt/layers/attention/cutlass_mla_backend.py
@@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None: