sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,61 @@
+import logging
+from typing import Union
+
+from fastapi import Request
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ErrorResponse,
+    ScoringRequest,
+    ScoringResponse,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingScore(OpenAIServingBase):
+    """Handler for /v1/score requests"""
+
+    # NOTE: /v1/score is not an official OpenAI endpoint. This module may be moved
+    # to another module in the future.
+
+    def _request_id_prefix(self) -> str:
+        return "score-"
+
+    def _convert_to_internal_request(
+        self,
+        request: ScoringRequest,
+    ) -> tuple[ScoringRequest, ScoringRequest]:
+        """Convert OpenAI scoring request to internal format"""
+        # For scoring, we pass the request directly as the tokenizer_manager
+        # has a specialized score_request method that doesn't use GenerateReqInput
+
+        return request, request
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: ScoringRequest,
+        request: ScoringRequest,
+        raw_request: Request,
+    ) -> Union[ScoringResponse, ErrorResponse]:
+        """Handle the scoring request"""
+        try:
+            # Use tokenizer_manager's score_request method directly
+            scores = await self.tokenizer_manager.score_request(
+                query=request.query,
+                items=request.items,
+                label_token_ids=request.label_token_ids,
+                apply_softmax=request.apply_softmax,
+                item_first=request.item_first,
+                request=raw_request,
+            )
+
+            # Create response with just the scores, without usage info
+            response = ScoringResponse(
+                scores=scores,
+                model=request.model,
+            )
+            return response
+
+        except ValueError as e:
+            return self.create_error_response(str(e))
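For reference, a minimal client-side sketch of calling the /v1/score endpoint this handler serves. The payload fields mirror the ScoringRequest attributes read above; the host, port, model name, and label token ids are placeholders, not values from the diff:

import requests

# Hypothetical server address; adjust to your deployment.
resp = requests.post(
    "http://localhost:30000/v1/score",
    json={
        "model": "my-model",              # placeholder model name
        "query": "Is this review positive?",
        "items": ["Great product!", "Terrible service."],
        "label_token_ids": [9454, 2753],  # placeholder ids, e.g. for "Yes"/"No"
        "apply_softmax": True,
        "item_first": False,
    },
)
print(resp.json()["scores"])  # ScoringResponse carries only scores and model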
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional, final
+
+from sglang.srt.entrypoints.openai.protocol import UsageInfo
+
+
+@final
+class UsageProcessor:
+    """Stateless helpers that turn raw token counts into a UsageInfo."""
+
+    @staticmethod
+    def _details_if_cached(count: int) -> Optional[Dict[str, int]]:
+        """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim)."""
+        return {"cached_tokens": count} if count > 0 else None
+
+    @staticmethod
+    def calculate_response_usage(
+        responses: List[Dict[str, Any]],
+        n_choices: int = 1,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses)
+
+        prompt_tokens = sum(
+            responses[i]["meta_info"]["prompt_tokens"]
+            for i in range(0, len(responses), n_choices)
+        )
+
+        cached_details = None
+        if enable_cache_report:
+            cached_total = sum(
+                r["meta_info"].get("cached_tokens", 0) for r in responses
+            )
+            cached_details = UsageProcessor._details_if_cached(cached_total)
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_streaming_usage(
+        prompt_tokens: Mapping[int, int],
+        completion_tokens: Mapping[int, int],
+        cached_tokens: Mapping[int, int],
+        n_choices: int,
+        enable_cache_report: bool = False,
+    ) -> UsageInfo:
+        # index % n_choices == 0 marks the first choice of a prompt
+        total_prompt_tokens = sum(
+            tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0
+        )
+        total_completion_tokens = sum(completion_tokens.values())
+
+        cached_details = (
+            UsageProcessor._details_if_cached(sum(cached_tokens.values()))
+            if enable_cache_report
+            else None
+        )
+
+        return UsageProcessor.calculate_token_usage(
+            prompt_tokens=total_prompt_tokens,
+            completion_tokens=total_completion_tokens,
+            cached_tokens=cached_details,
+        )
+
+    @staticmethod
+    def calculate_token_usage(
+        prompt_tokens: int,
+        completion_tokens: int,
+        cached_tokens: Optional[Dict[str, int]] = None,
+    ) -> UsageInfo:
+        """Calculate token usage information"""
+        return UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+            prompt_tokens_details=cached_tokens,
+        )
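A small self-contained sketch of how these helpers combine. The responses list is toy data shaped like the meta_info dicts the methods read; with n_choices=2, both entries belong to one prompt:

from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor

responses = [
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 5, "cached_tokens": 4}},
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 7, "cached_tokens": 4}},
]

usage = UsageProcessor.calculate_response_usage(
    responses, n_choices=2, enable_cache_report=True
)
# The prompt is counted once (10), completions are summed (5 + 7 = 12),
# so total_tokens == 22 and prompt_tokens_details == {"cached_tokens": 8}.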
@@ -0,0 +1,72 @@
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def to_openai_style_logprobs(
+    input_token_logprobs=None,
+    output_token_logprobs=None,
+    input_top_logprobs=None,
+    output_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if input_token_logprobs is not None:
+        append_token_logprobs(input_token_logprobs)
+    if output_token_logprobs is not None:
+        append_token_logprobs(output_token_logprobs)
+    if input_top_logprobs is not None:
+        append_top_logprobs(input_top_logprobs)
+    if output_top_logprobs is not None:
+        append_top_logprobs(output_top_logprobs)
+
+    return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
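To illustrate the tuple layout these helpers expect: each token logprob is a (logprob, token_id, token_text) triple, as the unpacking above implies. The values below are made up:

from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs

out = to_openai_style_logprobs(
    output_token_logprobs=[(-0.1, 42, "Hello"), (-0.9, 7, " world")],
    output_top_logprobs=[
        [(-0.1, 42, "Hello"), (-2.3, 13, "Hi")],  # alternatives for step 1
        None,                                     # no top-logprobs for step 2
    ],
)
# out.tokens == ["Hello", " world"]; out.token_logprobs == [-0.1, -0.9]
# out.top_logprobs == [{"Hello": -0.1, "Hi": -2.3}, None]
# out.text_offset == [-1, -1] (offsets are not supported yet)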
@@ -0,0 +1 @@
+from . import reader
@@ -0,0 +1,51 @@
+from collections import defaultdict
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+
+from sglang.srt.managers.expert_distribution import (
+    _convert_global_physical_count_to_logical_count,
+)
+
+convert_global_physical_count_to_logical_count = (
+    _convert_global_physical_count_to_logical_count
+)
+
+
+def read_mode_per_pass(dir_data: Path):
+    """Read data from ExpertDistributionRecorder when recorded with mode `per_pass`"""
+
+    # gpc := global_physical_count
+    gpc_of_forward_pass_and_rank = defaultdict(lambda: defaultdict())
+    for path in tqdm(list(dir_data.glob("*.pt"))):
+        data_pack = torch.load(path, weights_only=True)
+        last_physical_to_logical_map = data_pack["last_physical_to_logical_map"]
+        for record in data_pack["records"]:
+            forward_pass_id = record["forward_pass_id"]
+            rank = record["rank"]
+            assert (
+                gpc_of_forward_pass_and_rank[forward_pass_id].get(rank) is None
+            ), f"Duplicated {forward_pass_id=} {rank=}"
+            gpc_of_forward_pass_and_rank[forward_pass_id][rank] = record[
+                "global_physical_count"
+            ]
+
+    forward_pass_ids = sorted(gpc_of_forward_pass_and_rank.keys())
+    print(f"Make {forward_pass_ids=} into array")
+
+    items = []
+    for forward_pass_id, gpc_of_rank in sorted(gpc_of_forward_pass_and_rank.items()):
+        gpc_of_rank_tensor = torch.stack(
+            [gpc for rank, gpc in sorted(gpc_of_rank.items())]
+        ).sum(dim=0)
+        items.append(gpc_of_rank_tensor)
+
+    gpc_of_forward_pass = torch.stack(items)
+    print(f"{gpc_of_forward_pass.shape=}")
+
+    return dict(
+        global_physical_count_of_forward_pass=gpc_of_forward_pass,
+        last_physical_to_logical_map=last_physical_to_logical_map,
+        forward_pass_ids=forward_pass_ids,
+    )
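A hedged usage sketch for the reader. The directory path is a placeholder; it should contain the *.pt packs written by ExpertDistributionRecorder in per_pass mode:

from pathlib import Path

from sglang.srt.eplb_simulator.reader import read_mode_per_pass

data = read_mode_per_pass(Path("/tmp/expert_distribution_dump"))  # placeholder dir
# One rank-summed global_physical_count tensor per recorded forward pass,
# stacked along dim 0 in forward_pass_id order.
gpc = data["global_physical_count_of_forward_pass"]
print(gpc.shape, len(data["forward_pass_ids"]))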
@@ -6,6 +6,7 @@ from typing import Any, Dict, List
 from partial_json_parser.core.exceptions import MalformedJSON
 from partial_json_parser.core.options import Allow
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
     ToolCallItem,
@@ -16,7 +17,6 @@ from sglang.srt.function_call.utils import (
     _is_complete_json,
     _partial_json_loads,
 )
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -111,11 +111,10 @@ class BaseFormatDetector(ABC):
         # The current_text has tool_call if it is the start of a new tool call sequence
         # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
         if not (
-            self.bot_token in current_text
-            or current_text.startswith("{")
+            self.has_tool_call(current_text)
             or (
                 self.current_tool_id > 0
-                and current_text.startswith(self.tool_call_separator + "{")
+                and current_text.startswith(self.tool_call_separator)
             )
         ):
             # Only clear buffer if we're sure no tool call is starting
@@ -143,6 +142,10 @@ class BaseFormatDetector(ABC):
         try:
             if current_text.startswith(self.bot_token):
                 start_idx = len(self.bot_token)
+            elif self.current_tool_id > 0 and current_text.startswith(
+                self.tool_call_separator + self.bot_token
+            ):
+                start_idx = len(self.tool_call_separator + self.bot_token)
             elif self.current_tool_id > 0 and current_text.startswith(
                 self.tool_call_separator
             ):
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
 from sglang.srt.function_call.utils import _is_complete_json
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -211,20 +211,74 @@ class EBNFComposer:
         properties = params.get("properties", {})
         required_props = set(params.get("required", []))
 
-        # Build argument rules for this tool
-        arg_rules = []
+        # The generated pattern ensures:
+        # 1. Required properties appear first, joined by commas
+        # 2. Optional properties are wrapped with comma included: ( "," ( "prop" : value )? )?
+        # 3. For multiple optional properties, we allow flexible ordering:
+        #    - Each optional can be skipped entirely
+        #    - They can appear in any combination
+        #
+        # Example patterns generated:
+        # - One required, one optional:
+        #   "{" "location" ":" string ( "," ( "unit" ":" enum ) )? "}"
+        #   Allows: {"location": "Paris"} or {"location": "Paris", "unit": "celsius"}
+        #
+        # - Multiple optional properties with flexible ordering:
+        #   "{" "req" ":" string ( "," ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value ) )? "}"
+        #   Allows: {"req": "x"}, {"req": "x", "opt1": "y"}, {"req": "x", "opt2": "z"},
+        #           {"req": "x", "opt1": "y", "opt2": "z"}
+        #
+        # - All optional properties with flexible ordering:
+        #   "{" ( "opt1" ":" value ( "," "opt2" ":" value )? | "opt2" ":" value )? "}"
+        #   Allows: {}, {"opt1": "x"}, {"opt2": "y"}, {"opt1": "x", "opt2": "y"}
+
+        prop_kv_pairs = {}
+        ordered_props = list(properties.keys())
+
         for prop_name, prop_schema in properties.items():
             value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
             # Create key=value pair
             pair = key_value_template.format(key=prop_name, valrule=value_rule)
-
-            if prop_name not in required_props:
-                pair = f"[ {pair} ]"
-
-            arg_rules.append(pair)
-
-        # Combine all argument rules
-        combined_args = ' "," '.join(arg_rules) if arg_rules else ""
+            prop_kv_pairs[prop_name] = pair
+
+        # Separate into required and optional while preserving order
+        required = [p for p in ordered_props if p in required_props]
+        optional = [p for p in ordered_props if p not in required_props]
+
+        # Build the combined rule
+        rule_parts = []
+
+        # Add required properties joined by commas
+        if required:
+            rule_parts.append(' "," '.join(prop_kv_pairs[k] for k in required))
+
+        # Add optional properties with flexible ordering
+        if optional:
+            # Build alternatives where any optional property can appear first
+            opt_alternatives = []
+            for i in range(len(optional)):
+                # Build pattern for optional[i] appearing first
+                opt_parts = []
+                for j in range(i, len(optional)):
+                    if j == i:
+                        opt_parts.append(prop_kv_pairs[optional[j]])
+                    else:
+                        opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
+                opt_alternatives.append("".join(opt_parts))
+
+            # Wrap with appropriate comma handling based on whether we have required properties
+            if required:
+                # Required properties exist, so the optional group needs an outer comma
+                rule_parts.append(' ( "," ( ')
+                rule_parts.append(" | ".join(opt_alternatives))
+                rule_parts.append(" ) )?")
+            else:
+                # All properties are optional
+                rule_parts.append("( ")
+                rule_parts.append(" | ".join(opt_alternatives))
+                rule_parts.append(" )?")
+
+        combined_args = "".join(rule_parts)
         arguments_rule = args_template.format(arg_rules=combined_args)
 
         # Add the function call rule and its arguments rule
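To make the flexible-ordering construction concrete, here is a standalone sketch of the same alternatives loop run on toy key/value pairs. The property names and the "string" value rule are made up; the real code builds the pairs via key_value_template:

# Toy stand-ins for the prop_kv_pairs built by the real composer.
prop_kv_pairs = {
    "opt1": '"opt1" ":" string',
    "opt2": '"opt2" ":" string',
}
optional = ["opt1", "opt2"]

opt_alternatives = []
for i in range(len(optional)):
    # Build the pattern where optional[i] appears first.
    opt_parts = []
    for j in range(i, len(optional)):
        if j == i:
            opt_parts.append(prop_kv_pairs[optional[j]])
        else:
            opt_parts.append(f' ( "," {prop_kv_pairs[optional[j]]} )?')
    opt_alternatives.append("".join(opt_parts))

print("( " + " | ".join(opt_alternatives) + " )?")
# ( "opt1" ":" string ( "," "opt2" ":" string )? | "opt2" ":" string )?

This matches the "all optional" example in the comment block above: the empty object, either property alone, or both in declaration order are all accepted.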
@@ -1,6 +1,12 @@
 import logging
 from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
 
+from sglang.srt.entrypoints.openai.protocol import (
+    StructuralTagResponseFormat,
+    StructuresResponseFormat,
+    Tool,
+    ToolChoice,
+)
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import ToolCallItem
 from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
@@ -8,12 +14,6 @@ from sglang.srt.function_call.llama32_detector import Llama32Detector
 from sglang.srt.function_call.mistral_detector import MistralDetector
 from sglang.srt.function_call.pythonic_detector import PythonicDetector
 from sglang.srt.function_call.qwen25_detector import Qwen25Detector
-from sglang.srt.openai_api.protocol import (
-    StructuralTagResponseFormat,
-    StructuresResponseFormat,
-    Tool,
-    ToolChoice,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -2,6 +2,7 @@ import json
 import logging
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -9,7 +10,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -4,6 +4,7 @@ import logging
 import re
 from typing import List, Optional
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -12,7 +13,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+from sglang.srt.entrypoints.openai.protocol import Tool
 from sglang.srt.function_call.base_format_detector import BaseFormatDetector
 from sglang.srt.function_call.core_types import (
     StreamingParseResult,
@@ -10,7 +11,6 @@ from sglang.srt.function_call.core_types import (
     _GetInfoFunc,
 )
 from sglang.srt.function_call.ebnf_composer import EBNFComposer
-from sglang.srt.openai_api.protocol import Tool
 
 logger = logging.getLogger(__name__)
 
@@ -1,11 +1,12 @@
-"""
-Utility functions for OpenAI API adapter.
+"""Template utilities for Jinja template processing.
+
+This module provides utilities for analyzing and processing Jinja chat templates,
+including content format detection and message processing.
 """
 
 import logging
-from typing import Dict, List
 
-import jinja2.nodes
+import jinja2
 import transformers.utils.chat_template_utils as hf_chat_utils
 
 logger = logging.getLogger(__name__)
@@ -75,7 +76,7 @@ def _try_extract_ast(chat_template: str):
     return None
 
 
-def detect_template_content_format(chat_template: str) -> str:
+def detect_jinja_template_content_format(chat_template: str) -> str:
     """
     Detect whether a chat template expects 'string' or 'openai' content format.
@@ -20,6 +20,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers import PretrainedConfig
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -28,9 +29,19 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_npu,
+    set_weight_attrs,
+)
+from sglang.utils import resolve_obj_by_qualname
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -50,6 +61,15 @@ class SiluAndMul(CustomOp):
         silu_and_mul(x, out)
         return out
 
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available:
+            d = x.shape[-1] // 2
+            output_shape = x.shape[:-1] + (d,)
+            out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
+            return out
+        else:
+            return self.forward_native(x)
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
@@ -165,8 +185,25 @@ def get_act_fn(
     return act_fn
 
 
-if not _is_cuda:
+def get_cross_encoder_activation_function(config: PretrainedConfig):
+    if (
+        hasattr(config, "sbert_ce_default_activation_function")
+        and config.sbert_ce_default_activation_function is not None
+    ):
+
+        function_name = config.sbert_ce_default_activation_function
+        assert function_name.startswith("torch.nn.modules."), (
+            "Loading of activation functions is restricted to "
+            "torch.nn.modules for security reasons"
+        )
+        return resolve_obj_by_qualname(function_name)()
+    else:
+        # adapt bge-reranker
+        return nn.Identity()
+
+
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
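A brief sketch of what get_cross_encoder_activation_function resolves, assuming resolve_obj_by_qualname imports the module and looks up the final attribute, as its name suggests. The config objects here are stand-ins exposing only the attribute the function inspects:

from types import SimpleNamespace

import torch.nn as nn

from sglang.srt.layers.activation import get_cross_encoder_activation_function

# A config naming a torch.nn activation by qualified path (passes the assert).
cfg = SimpleNamespace(
    sbert_ce_default_activation_function="torch.nn.modules.activation.Sigmoid"
)
act = get_cross_encoder_activation_function(cfg)  # an nn.Sigmoid() instance

# A config without a value falls back to nn.Identity() (the bge-reranker case).
cfg2 = SimpleNamespace(sbert_ce_default_activation_function=None)
assert isinstance(get_cross_encoder_activation_function(cfg2), nn.Identity)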
@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -717,6 +720,11 @@ class AiterIndicesUpdaterPrefill:
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.update = self.update_single_wrapper
 
+        # get the last index of the pool
+        self.pool_size = (
+            model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
+        ) - 1
+
         self.kv_indices = None
         self.max_q_len = 0
         self.max_kv_len = 0
@@ -754,8 +762,16 @@ class AiterIndicesUpdaterPrefill:
             # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.empty(
-                paged_kernel_lens_sum + 256,
+
+            # (TODO: Kk) WA - CI test_moe_eval_accuracy_large.py
+            # mha_batch_prefill reads 128 entries per computation;
+            # if the real data is not long enough, the original padding value 0 is used,
+            # but the 0 location will be made nan (noqa) in cuda graph capture mode,
+            # which causes the output tensor values to become nan.
+            # The WA is to ensure the last index of the pool is not changed.
+            kv_indices = torch.full(
+                (paged_kernel_lens_sum + 128,),
+                self.pool_size,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
            )
@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Init the global shared states for cuda graph."""
        raise NotImplementedError()
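As a minimal sketch, a custom backend overriding the updated hook now receives both sizes, so per-request buffers can be sized by batch and per-token buffers by token count. The class and buffer names below are illustrative, not part of sglang:

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

class MyAttnBackend(AttentionBackend):  # hypothetical backend
    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        # Per-request state scales with batch size.
        self.graph_seq_lens = torch.zeros(max_bs, dtype=torch.int32, device="cuda")
        # Per-token state scales with the token budget, mirroring the
        # max_bs -> max_num_tokens change in the aiter backend above.
        self.graph_mask_buf = torch.zeros(
            max_num_tokens, dtype=torch.uint8, device="cuda"
        )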