sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/openai/serving_chat.py +82 -26

@@ -1,14 +1,15 @@
+from __future__ import annotations
+
 import copy
 import json
 import logging
 import time
 import uuid
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
 
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
 
-from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -33,13 +34,16 @@ from sglang.srt.entrypoints.openai.utils import (
     to_openai_style_logprobs,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.jinja_template_utils import process_content_for_template_format
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.template_manager import TemplateManager
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.conversation import generate_chat_conv
+from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.utils import convert_json_schema_to_str
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.template_manager import TemplateManager
+    from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
 logger = logging.getLogger(__name__)
 
 
@@ -53,6 +57,7 @@ class OpenAIServingChat(OpenAIServingBase):
     ):
         super().__init__(tokenizer_manager)
         self.template_manager = template_manager
+        self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
 
     def _request_id_prefix(self) -> str:
         return "chatcmpl-"
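The import churn in this and the other serving_* files follows one pattern: from __future__ import annotations plus a TYPE_CHECKING block, so that TemplateManager and TokenizerManager are only imported for static type checking and never at runtime, avoiding circular imports and import-time cost. A minimal sketch of that pattern, using made-up module names rather than sglang's own:

from __future__ import annotations  # annotations become plain strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so a circular or
    # expensive import here costs nothing when the module is actually loaded.
    from my_package.heavy_module import HeavyManager  # hypothetical module


class Service:
    def __init__(self, manager: HeavyManager) -> None:
        # At runtime the annotation is just the string "HeavyManager".
        self.manager = manager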
@@ -148,6 +153,16 @@ class OpenAIServingChat(OpenAIServingBase):
         self, request: ChatCompletionRequest, is_multimodal: bool
     ) -> MessageProcessingResult:
         """Process chat messages and apply chat template"""
+        is_gpt_oss = (
+            hasattr(self.tokenizer_manager.model_config, "hf_config")
+            and hasattr(self.tokenizer_manager.model_config.hf_config, "model_type")
+            and self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss"
+        )
+
+        # GptOss model needs to keep special tokens for harmony parsing
+        if is_gpt_oss:
+            request.skip_special_tokens = False
+
         tool_call_constraint = None
 
         # Apply chat template and its stop strings
@@ -162,10 +177,11 @@
                 ]
             else:
                 tools = [item.function.model_dump() for item in request.tools]
-
-            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
-            parser = FunctionCallParser(request.tools, tool_call_parser)
-            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)
+            if self.tool_call_parser:
+                parser = FunctionCallParser(request.tools, self.tool_call_parser)
+                tool_call_constraint = parser.get_structure_constraint(
+                    request.tool_choice
+                )
 
         # Use chat template
         if self.template_manager.chat_template_name is None:
@@ -207,6 +223,25 @@
                 audio_data,
                 modalities,
             )
+
+            # per the Transformers docs & maintainers, tool call arguments in
+            # assistant-role messages with tool_calls need to be dicts not JSON str -
+            # this is how tool-use chat templates will expect them moving forwards
+            # so, for messages that have tool_calls, parse the string (which we get
+            # from openAI format) to dict
+            if (
+                processed_msg["role"] == "assistant"
+                and "tool_calls" in processed_msg
+                and isinstance(processed_msg["tool_calls"], list)
+            ):
+                for item in processed_msg["tool_calls"]:
+                    if "arguments" in item["function"] and isinstance(
+                        item["function"]["arguments"], str
+                    ):
+                        item["function"]["arguments"] = json.loads(
+                            item["function"]["arguments"]
+                        )
+
             openai_compatible_messages.append(processed_msg)
 
         # Handle assistant prefix for continue_final_message
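The block above normalizes assistant messages that carry tool calls before the chat template is applied: OpenAI's wire format sends function arguments as a JSON string, while tool-use chat templates iterate over a dict. A small self-contained illustration of the same conversion (the message content is made up):

import json

# Assistant turn in OpenAI wire format: "arguments" arrives as a JSON string.
processed_msg = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Paris", "unit": "celsius"}',
            },
        }
    ],
}

# Same normalization as above: parse the JSON string into a dict so the
# chat template can render the arguments directly.
for item in processed_msg["tool_calls"]:
    if isinstance(item["function"]["arguments"], str):
        item["function"]["arguments"] = json.loads(item["function"]["arguments"])

assert processed_msg["tool_calls"][0]["function"]["arguments"]["city"] == "Paris"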
@@ -508,7 +543,11 @@
                     yield f"data: {chunk.model_dump_json()}\n\n"
 
                 # Handle tool calls
-                if request.tool_choice != "none" and request.tools:
+                if (
+                    request.tool_choice != "none"
+                    and request.tools
+                    and self.tool_call_parser
+                ):
                     async for chunk in self._process_tool_call_stream(
                         index,
                         delta,
@@ -698,10 +737,13 @@
 
         # Handle tool calls
         tool_calls = None
-        if request.tool_choice != "none" and request.tools:
-            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+        if (
+            request.tool_choice != "none"
+            and request.tools
+            and self.tool_call_parser
+        ):
             tool_calls, text, finish_reason = self._process_tool_calls(
-                text, request.tools, tool_call_parser, finish_reason
+                text, request.tools, finish_reason
             )
 
         choice_data = ChatCompletionResponseChoice(
@@ -795,26 +837,36 @@
         self,
         text: str,
         tools: List[Any],
-        tool_call_parser: Optional[str],
         finish_reason: Dict[str, Any],
     ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
         """Process tool calls in the response"""
-        parser = FunctionCallParser(tools, tool_call_parser)
+        parser = FunctionCallParser(tools, self.tool_call_parser)
        if parser.has_tool_call(text):
             if finish_reason["type"] == "stop":
                 finish_reason["type"] = "tool_calls"
                 finish_reason["matched"] = None
             try:
                 text, call_info_list = parser.parse_non_stream(text)
-                tool_calls = [
-                    ToolCall(
-                        id=f"call_{uuid.uuid4().hex[:24]}",
-                        function=FunctionResponse(
-                            name=call_info.name, arguments=call_info.parameters
-                        ),
+                tool_calls = []
+                for call_info in call_info_list:
+                    # For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
+                    if (
+                        self.tool_call_parser == "kimi_k2"
+                        and call_info.name is not None
+                    ):
+                        tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
+                    else:
+                        tool_id = f"call_{uuid.uuid4().hex[:24]}"
+
+                    tool_calls.append(
+                        ToolCall(
+                            id=tool_id,
+                            index=getattr(call_info, "tool_index", None),
+                            function=FunctionResponse(
+                                name=call_info.name, arguments=call_info.parameters
+                            ),
+                        )
                     )
-                    for call_info in call_info_list
-                ]
                 return tool_calls, text, finish_reason
             except Exception as e:
                 logger.error(f"Tool call parsing error: {e}")
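The non-stream path above and the streaming path further down now derive tool call ids the same way. A hypothetical standalone helper (not sglang's API) showing the two id formats involved:

import uuid
from typing import Optional


def make_tool_call_id(parser: Optional[str], name: Optional[str], tool_index: int) -> str:
    # Mirrors the branch above: Kimi-K2 uses the model's own id format,
    # everything else keeps the random OpenAI-style id.
    if parser == "kimi_k2" and name is not None:
        # Kimi-K2 format: functions.{name}:{index}
        return f"functions.{name}:{tool_index}"
    # Default: "call_" followed by 24 hex characters.
    return f"call_{uuid.uuid4().hex[:24]}"


print(make_tool_call_id("kimi_k2", "get_weather", 0))  # functions.get_weather:0
print(make_tool_call_id(None, "get_weather", 0))       # call_<24 random hex chars>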
@@ -896,7 +948,7 @@
         if index not in parser_dict:
             parser_dict[index] = FunctionCallParser(
                 tools=request.tools,
-                tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
+                tool_call_parser=self.tool_call_parser,
             )
         parser = parser_dict[index]
 
@@ -925,7 +977,11 @@
                 # Tool call ID should be generated only once per tool call
                 if call_item.name:
                     # First chunk: include ID and function name
-                    tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+                    if self.tool_call_parser == "kimi_k2":
+                        # Align with Kimi-K2 format: functions.{name}:{index}
+                        tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
+                    else:
+                        tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
                     function_name = call_item.name
                 else:
                     # Subsequent chunks: null ID and name for argument deltas
sglang/srt/entrypoints/openai/serving_completions.py +25 -4

@@ -1,11 +1,12 @@
+from __future__ import annotations
+
 import logging
 import time
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
 
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
 
-from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
 from sglang.srt.entrypoints.openai.protocol import (
     CompletionRequest,
     CompletionResponse,
@@ -21,8 +22,14 @@ from sglang.srt.entrypoints.openai.utils import (
     to_openai_style_logprobs,
 )
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.template_manager import TemplateManager
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.parser.code_completion_parser import (
+    generate_completion_prompt_from_request,
+)
+from sglang.utils import convert_json_schema_to_str
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.template_manager import TemplateManager
+    from sglang.srt.managers.tokenizer_manager import TokenizerManager
 
 logger = logging.getLogger(__name__)
 
@@ -125,6 +132,20 @@ class OpenAIServingCompletion(OpenAIServingBase):
             "logit_bias": request.logit_bias,
         }
 
+        # Handle response_format constraints
+        if request.response_format and request.response_format.type == "json_schema":
+            sampling_params["json_schema"] = convert_json_schema_to_str(
+                request.response_format.json_schema.schema_
+            )
+        elif request.response_format and request.response_format.type == "json_object":
+            sampling_params["json_schema"] = '{"type": "object"}'
+        elif (
+            request.response_format and request.response_format.type == "structural_tag"
+        ):
+            sampling_params["structural_tag"] = convert_json_schema_to_str(
+                request.response_format.model_dump(by_alias=True)
+            )
+
         return sampling_params
 
     async def _handle_streaming_request(
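With the block above, /v1/completions now honors response_format like the chat endpoint does: json_schema and json_object map to the json_schema sampling parameter, structural_tag to structural_tag. A rough client-side sketch, assuming a locally running sglang server on the default port; the model name and schema are illustrative, not from the package:

import json
import urllib.request

# Illustrative payload; the payload shape follows the OpenAI response_format convention.
payload = {
    "model": "default",
    "prompt": "Describe one city as JSON.",
    "max_tokens": 128,
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "city",
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            },
        },
    },
}

req = urllib.request.Request(
    "http://localhost:30000/v1/completions",  # assumed local sglang server
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["text"])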
sglang/srt/entrypoints/openai/serving_embedding.py +8 -4

@@ -1,9 +1,10 @@
-from typing import Any, Dict, List, Optional, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from fastapi import Request
 from fastapi.responses import ORJSONResponse
 
-from sglang.srt.conversation import generate_embedding_convs
 from sglang.srt.entrypoints.openai.protocol import (
     EmbeddingObject,
     EmbeddingRequest,
@@ -14,8 +15,11 @@ from sglang.srt.entrypoints.openai.protocol import (
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
 from sglang.srt.managers.io_struct import EmbeddingReqInput
-from sglang.srt.managers.template_manager import TemplateManager
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.parser.conversation import generate_embedding_convs
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.template_manager import TemplateManager
+    from sglang.srt.managers.tokenizer_manager import TokenizerManager
 
 
 class OpenAIServingEmbedding(OpenAIServingBase):
sglang/srt/entrypoints/openai/serving_responses.py +7 -4

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # Adapted from vLLM's OpenAIServingResponses
 """Handler for /v1/responses requests"""
+from __future__ import annotations
 
 import asyncio
 import copy
@@ -9,7 +10,7 @@ import logging
 import time
 from contextlib import AsyncExitStack
 from http import HTTPStatus
-from typing import Any, AsyncGenerator, AsyncIterator, Optional, Union
+from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional, Union
 
 import jinja2
 import openai.types.responses as openai_responses_types
@@ -54,11 +55,13 @@ from sglang.srt.entrypoints.openai.protocol import (
 from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
 from sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.template_manager import TemplateManager
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.utils import random_uuid
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.template_manager import TemplateManager
+    from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
 logger = logging.getLogger(__name__)
 
 
sglang/srt/eplb/eplb_manager.py +28 -4

@@ -55,12 +55,21 @@ class EPLBManager:
         enable_timing = self._rebalance_layers_per_chunk is None
 
         if enable_timing:
-            torch.cuda.synchronize()
+            torch.get_device_module().synchronize()
             time_start = time.time()
 
-        logical_count = get_global_expert_distribution_recorder().dump_record(
+        dump_record_output = get_global_expert_distribution_recorder().dump_record(
             output_mode="object"
-        )["logical_count"]
+        )
+        logical_count = dump_record_output["logical_count"]
+        average_utilization_rate_over_window = dump_record_output[
+            "average_utilization_rate_over_window"
+        ]
+
+        # Check whether rebalancing is needed
+        if not self._check_rebalance_needed(average_utilization_rate_over_window):
+            return
+
         expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
             self._server_args, self._model_runner.model_config, logical_count
         )
@@ -76,11 +85,26 @@
 
         msg = f"[EPLBManager] rebalance end"
         if enable_timing:
-            torch.cuda.synchronize()
+            torch.get_device_module().synchronize()
             time_end = time.time()
             msg += f" time={time_end - time_start:.3f}s"
         logger.info(msg)
 
+    def _check_rebalance_needed(self, average_utilization_rate_over_window):
+        if average_utilization_rate_over_window is None:
+            return True
+
+        if (
+            average_utilization_rate_over_window
+            > self._server_args.eplb_min_rebalancing_utilization_threshold
+        ):
+            logger.info(
+                f"[EPLBManager] Skipped ep rebalancing: current GPU utilization {average_utilization_rate_over_window:.2f} > minimum rebalance threshold {self._server_args.eplb_min_rebalancing_utilization_threshold:.2f}"
+            )
+            return False
+
+        return True
+
     def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
         all_layer_ids = sorted(
             list(self._model_runner.model.routed_experts_weights_of_layer.keys())
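The new _check_rebalance_needed guard reduces to one comparison: rebalancing is skipped while the recent average utilization is still above eplb_min_rebalancing_utilization_threshold, and always allowed when no metric is available. A standalone sketch of that decision (values are illustrative, not from sglang):

from typing import Optional


def should_rebalance(avg_utilization: Optional[float], threshold: float) -> bool:
    # Illustrative restatement of the EPLBManager check above.
    if avg_utilization is None:
        # No utilization metric recorded: keep the old behavior and rebalance.
        return True
    # Above the threshold the experts are already busy enough; skip the
    # (expensive) rebalance in that case.
    return avg_utilization <= threshold


assert should_rebalance(None, 0.8) is True
assert should_rebalance(0.95, 0.8) is False  # well utilized, skip
assert should_rebalance(0.40, 0.8) is True   # underutilized, rebalance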
sglang/srt/eplb/expert_distribution.py +55 -15

@@ -11,23 +11,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from __future__ import annotations
+
 import logging
+import math
 import os
 import time
 from abc import ABC
 from collections import deque
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
 
 import einops
 import torch
 import torch.distributed
 
-from sglang.srt.eplb.expert_location import ExpertLocationMetadata
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import Withable, get_bool_env_var
+from sglang.srt.utils import Withable, get_bool_env_var, is_npu
+
+_is_npu = is_npu()
+
+if TYPE_CHECKING:
+    from sglang.srt.eplb.expert_location import ExpertLocationMetadata
 
 logger = logging.getLogger(__name__)
 
@@ -42,7 +50,7 @@ class ExpertDistributionRecorder(ABC):
     @staticmethod
     def init_new(
         server_args: ServerArgs,
-        expert_location_metadata: "ExpertLocationMetadata",
+        expert_location_metadata: ExpertLocationMetadata,
         rank: int,
     ):
         if server_args.expert_distribution_recorder_mode is not None:
@@ -117,7 +125,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
     def __init__(
         self,
         server_args: ServerArgs,
-        expert_location_metadata: "ExpertLocationMetadata",
+        expert_location_metadata: ExpertLocationMetadata,
         rank: int,
     ):
         self._server_args = server_args
@@ -210,7 +218,9 @@
     def _on_hook(self, hook_name: str, **kwargs):
         if self._disable_all:
             return
-        if not (self._recording or torch.cuda.is_current_stream_capturing()):
+        if not (
+            self._recording or torch.get_device_module().is_current_stream_capturing()
+        ):
             return
         gatherer = self._single_pass_gatherers[
             self._accumulator.get_single_pass_gatherer_key(
@@ -278,7 +288,7 @@ class _SinglePassGatherer(ABC):
     @staticmethod
     def init_new(
         server_args: ServerArgs,
-        expert_location_metadata: "ExpertLocationMetadata",
+        expert_location_metadata: ExpertLocationMetadata,
         rank: int,
     ) -> "_SinglePassGatherer":
         if server_args.expert_distribution_recorder_mode == "per_token":
@@ -306,7 +316,7 @@ class _SinglePassGatherer(ABC):
 
         return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
 
-    def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
+    def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):
         self._expert_location_metadata = expert_location_metadata
         self._rank = rank
 
@@ -345,7 +355,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
     def __init__(
         self,
         server_args: ServerArgs,
-        expert_location_metadata: "ExpertLocationMetadata",
+        expert_location_metadata: ExpertLocationMetadata,
         rank: int,
     ):
         super().__init__(expert_location_metadata, rank)
@@ -445,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
 class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
     def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
         super().__init__(*args, **kwargs)
+        if not _is_npu:
+            device = "cuda"
+        else:
+            device = "npu"
         self._enable_global_physical_experts = enable_global_physical_experts
         self._data = torch.zeros(
             (
@@ -456,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
                 ),
             ),
             dtype=torch.int,
-            device="cuda",
+            device=device,
         )
 
     def reset(self):
@@ -560,7 +574,7 @@ class _Accumulator(ABC):
     @staticmethod
     def init_new(
         server_args: ServerArgs,
-        expert_location_metadata: "ExpertLocationMetadata",
+        expert_location_metadata: ExpertLocationMetadata,
         rank: int,
     ) -> "_Accumulator":
         return _Accumulator.get_class(server_args)(
@@ -579,7 +593,7 @@
     def __init__(
         self,
         server_args: ServerArgs,
-        expert_location_metadata: "ExpertLocationMetadata",
+        expert_location_metadata: ExpertLocationMetadata,
         rank: int,
     ):
         self._server_args = server_args
@@ -614,8 +628,8 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
         self._enable = self._server_args.enable_expert_distribution_metrics
 
         if self._enable:
-            window_sizes = [10, 100, 1000]
-            self._history = _DequeCollection(maxlens=window_sizes)
+            self.window_sizes = [10, 100, 1000]
+            self._history = _DequeCollection(maxlens=self.window_sizes)
             self._rank = torch.distributed.get_rank()
 
     def append(
@@ -778,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
 
         if self._first_dump:
             self._first_dump = False
-            torch.cuda.empty_cache()
+            torch.get_device_module().empty_cache()
 
         torch.distributed.all_reduce(
             logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
@@ -787,6 +801,7 @@
         output = dict(
             rank=self._rank,
             logical_count=logical_count_of_buffered_step,
+            average_utilization_rate_over_window=self._get_global_average_utilization_rate(),
         )
 
         if output_mode == "file":
@@ -797,6 +812,31 @@
         else:
             raise NotImplementedError
 
+    def _get_global_average_utilization_rate(self):
+        if not self._enable or math.isclose(
+            self._server_args.eplb_min_rebalancing_utilization_threshold, 1.0
+        ):
+            return None
+
+        if self._rank == 0:
+            utilization_mean_rates = self._history.mean()
+            window_index = self.window_sizes[-1]
+            average_utilization_rate_over_window = (
+                utilization_mean_rates[window_index]
+                if window_index in utilization_mean_rates
+                else 0
+            )
+
+            avg_rate_tensor = torch.tensor(
+                [average_utilization_rate_over_window],
+                dtype=torch.float32,
+                device="cuda",
+            )
+        else:
+            avg_rate_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
+        torch.distributed.broadcast(avg_rate_tensor, src=0)
+        return avg_rate_tensor.item()
+
 
 def _dump_to_file(name, data):
     save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
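_get_global_average_utilization_rate above uses a common distributed idiom: rank 0 computes the scalar, every rank allocates a one-element tensor, and a broadcast makes the value visible everywhere. A minimal sketch, assuming an already-initialized torch.distributed process group:

import torch
import torch.distributed as dist


def broadcast_scalar_from_rank0(value: float, device: str = "cuda") -> float:
    # Rank 0 fills a one-element tensor with the metric; all other ranks
    # allocate an empty buffer of the same shape and receive it via broadcast.
    if dist.get_rank() == 0:
        buf = torch.tensor([value], dtype=torch.float32, device=device)
    else:
        buf = torch.empty(1, dtype=torch.float32, device=device)
    dist.broadcast(buf, src=0)
    return buf.item()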
sglang/srt/eplb/expert_location.py +8 -3

@@ -11,21 +11,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
+from __future__ import annotations
+
 import json
 import logging
 import random
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.distributed
 import torch.nn.functional as F
 
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.eplb import eplb_algorithms
 from sglang.srt.model_loader import get_model_architecture
-from sglang.srt.server_args import ServerArgs
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.server_args import ServerArgs
 
 logger = logging.getLogger(__name__)
 
sglang/srt/eplb/expert_location_updater.py +1 -1

@@ -47,7 +47,7 @@ class ExpertLocationUpdater:
     ):
         if self._first_execution:
             self._first_execution = False
-            torch.cuda.empty_cache()
+            torch.get_device_module().empty_cache()
 
         old_expert_location_metadata = get_global_expert_location_metadata()
         assert old_expert_location_metadata is not None
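Several hunks above swap torch.cuda.synchronize / empty_cache for the backend-neutral torch.get_device_module(), which resolves to torch.cuda on NVIDIA GPUs and torch.npu on Ascend. A small sketch of the pattern, guarded so it is a no-op on CPU-only machines:

import torch

# Backend-neutral replacement for the torch.cuda.* calls shown above.
if torch.cuda.is_available() or hasattr(torch, "npu"):
    device_mod = torch.get_device_module()  # torch.cuda, torch.npu, ...
    device_mod.synchronize()   # was: torch.cuda.synchronize()
    device_mod.empty_cache()   # was: torch.cuda.empty_cache()
else:
    pass  # CPU-only environment: nothing to synchronize or free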