sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,10 @@
1
- from typing import Any, Dict, List, Optional, Union
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
2
4
 
3
5
  from fastapi import Request
4
6
  from fastapi.responses import ORJSONResponse
5
7
 
6
- from sglang.srt.conversation import generate_embedding_convs
7
8
  from sglang.srt.entrypoints.openai.protocol import (
8
9
  EmbeddingObject,
9
10
  EmbeddingRequest,
@@ -14,8 +15,11 @@ from sglang.srt.entrypoints.openai.protocol import (
14
15
  )
15
16
  from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
16
17
  from sglang.srt.managers.io_struct import EmbeddingReqInput
17
- from sglang.srt.managers.template_manager import TemplateManager
18
- from sglang.srt.managers.tokenizer_manager import TokenizerManager
18
+ from sglang.srt.parser.conversation import generate_embedding_convs
19
+
20
+ if TYPE_CHECKING:
21
+ from sglang.srt.managers.template_manager import TemplateManager
22
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
19
23
 
20
24
 
21
25
  class OpenAIServingEmbedding(OpenAIServingBase):
@@ -1,6 +1,7 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
  # Adapted from vLLM's OpenAIServingResponses
3
3
  """Handler for /v1/responses requests"""
4
+ from __future__ import annotations
4
5
 
5
6
  import asyncio
6
7
  import copy
@@ -9,7 +10,7 @@ import logging
9
10
  import time
10
11
  from contextlib import AsyncExitStack
11
12
  from http import HTTPStatus
12
- from typing import Any, AsyncGenerator, AsyncIterator, Optional, Union
13
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional, Union
13
14
 
14
15
  import jinja2
15
16
  import openai.types.responses as openai_responses_types
@@ -54,11 +55,13 @@ from sglang.srt.entrypoints.openai.protocol import (
54
55
  from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
55
56
  from sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer
56
57
  from sglang.srt.managers.io_struct import GenerateReqInput
57
- from sglang.srt.managers.template_manager import TemplateManager
58
- from sglang.srt.managers.tokenizer_manager import TokenizerManager
59
- from sglang.srt.reasoning_parser import ReasoningParser
58
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
60
59
  from sglang.srt.utils import random_uuid
61
60
 
61
+ if TYPE_CHECKING:
62
+ from sglang.srt.managers.template_manager import TemplateManager
63
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
64
+
62
65
  logger = logging.getLogger(__name__)
63
66
 
64
67
 
@@ -55,12 +55,21 @@ class EPLBManager:
55
55
  enable_timing = self._rebalance_layers_per_chunk is None
56
56
 
57
57
  if enable_timing:
58
- torch.cuda.synchronize()
58
+ torch.get_device_module().synchronize()
59
59
  time_start = time.time()
60
60
 
61
- logical_count = get_global_expert_distribution_recorder().dump_record(
61
+ dump_record_output = get_global_expert_distribution_recorder().dump_record(
62
62
  output_mode="object"
63
- )["logical_count"]
63
+ )
64
+ logical_count = dump_record_output["logical_count"]
65
+ average_utilization_rate_over_window = dump_record_output[
66
+ "average_utilization_rate_over_window"
67
+ ]
68
+
69
+ # Check whether rebalancing is needed
70
+ if not self._check_rebalance_needed(average_utilization_rate_over_window):
71
+ return
72
+
64
73
  expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
65
74
  self._server_args, self._model_runner.model_config, logical_count
66
75
  )
@@ -76,11 +85,26 @@ class EPLBManager:
76
85
 
77
86
  msg = f"[EPLBManager] rebalance end"
78
87
  if enable_timing:
79
- torch.cuda.synchronize()
88
+ torch.get_device_module().synchronize()
80
89
  time_end = time.time()
81
90
  msg += f" time={time_end - time_start:.3f}s"
82
91
  logger.info(msg)
83
92
 
93
+ def _check_rebalance_needed(self, average_utilization_rate_over_window):
94
+ if average_utilization_rate_over_window is None:
95
+ return True
96
+
97
+ if (
98
+ average_utilization_rate_over_window
99
+ > self._server_args.eplb_min_rebalancing_utilization_threshold
100
+ ):
101
+ logger.info(
102
+ f"[EPLBManager] Skipped ep rebalancing: current GPU utilization {average_utilization_rate_over_window:.2f} > minimum rebalance threshold {self._server_args.eplb_min_rebalancing_utilization_threshold:.2f}"
103
+ )
104
+ return False
105
+
106
+ return True
107
+
84
108
  def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
85
109
  all_layer_ids = sorted(
86
110
  list(self._model_runner.model.routed_experts_weights_of_layer.keys())
@@ -11,23 +11,31 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # ==============================================================================
14
+
15
+ from __future__ import annotations
16
+
14
17
  import logging
18
+ import math
15
19
  import os
16
20
  import time
17
21
  from abc import ABC
18
22
  from collections import deque
19
23
  from contextlib import contextmanager
20
24
  from pathlib import Path
21
- from typing import Any, Dict, List, Literal, Optional, Tuple, Type
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
22
26
 
23
27
  import einops
24
28
  import torch
25
29
  import torch.distributed
26
30
 
27
- from sglang.srt.eplb.expert_location import ExpertLocationMetadata
28
31
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
29
32
  from sglang.srt.server_args import ServerArgs
30
- from sglang.srt.utils import Withable, get_bool_env_var
33
+ from sglang.srt.utils import Withable, get_bool_env_var, is_npu
34
+
35
+ _is_npu = is_npu()
36
+
37
+ if TYPE_CHECKING:
38
+ from sglang.srt.eplb.expert_location import ExpertLocationMetadata
31
39
 
32
40
  logger = logging.getLogger(__name__)
33
41
 
@@ -42,7 +50,7 @@ class ExpertDistributionRecorder(ABC):
42
50
  @staticmethod
43
51
  def init_new(
44
52
  server_args: ServerArgs,
45
- expert_location_metadata: "ExpertLocationMetadata",
53
+ expert_location_metadata: ExpertLocationMetadata,
46
54
  rank: int,
47
55
  ):
48
56
  if server_args.expert_distribution_recorder_mode is not None:
@@ -117,7 +125,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
117
125
  def __init__(
118
126
  self,
119
127
  server_args: ServerArgs,
120
- expert_location_metadata: "ExpertLocationMetadata",
128
+ expert_location_metadata: ExpertLocationMetadata,
121
129
  rank: int,
122
130
  ):
123
131
  self._server_args = server_args
@@ -210,7 +218,9 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
210
218
  def _on_hook(self, hook_name: str, **kwargs):
211
219
  if self._disable_all:
212
220
  return
213
- if not (self._recording or torch.cuda.is_current_stream_capturing()):
221
+ if not (
222
+ self._recording or torch.get_device_module().is_current_stream_capturing()
223
+ ):
214
224
  return
215
225
  gatherer = self._single_pass_gatherers[
216
226
  self._accumulator.get_single_pass_gatherer_key(
@@ -278,7 +288,7 @@ class _SinglePassGatherer(ABC):
278
288
  @staticmethod
279
289
  def init_new(
280
290
  server_args: ServerArgs,
281
- expert_location_metadata: "ExpertLocationMetadata",
291
+ expert_location_metadata: ExpertLocationMetadata,
282
292
  rank: int,
283
293
  ) -> "_SinglePassGatherer":
284
294
  if server_args.expert_distribution_recorder_mode == "per_token":
@@ -306,7 +316,7 @@ class _SinglePassGatherer(ABC):
306
316
 
307
317
  return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
308
318
 
309
- def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
319
+ def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):
310
320
  self._expert_location_metadata = expert_location_metadata
311
321
  self._rank = rank
312
322
 
@@ -345,7 +355,7 @@ class _DetailSinglePassGatherer(_SinglePassGatherer):
345
355
  def __init__(
346
356
  self,
347
357
  server_args: ServerArgs,
348
- expert_location_metadata: "ExpertLocationMetadata",
358
+ expert_location_metadata: ExpertLocationMetadata,
349
359
  rank: int,
350
360
  ):
351
361
  super().__init__(expert_location_metadata, rank)
@@ -445,6 +455,10 @@ def _list_sum(a: List, b: List) -> List:
445
455
  class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
446
456
  def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
447
457
  super().__init__(*args, **kwargs)
458
+ if not _is_npu:
459
+ device = "cuda"
460
+ else:
461
+ device = "npu"
448
462
  self._enable_global_physical_experts = enable_global_physical_experts
449
463
  self._data = torch.zeros(
450
464
  (
@@ -456,7 +470,7 @@ class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
456
470
  ),
457
471
  ),
458
472
  dtype=torch.int,
459
- device="cuda",
473
+ device=device,
460
474
  )
461
475
 
462
476
  def reset(self):
@@ -560,7 +574,7 @@ class _Accumulator(ABC):
560
574
  @staticmethod
561
575
  def init_new(
562
576
  server_args: ServerArgs,
563
- expert_location_metadata: "ExpertLocationMetadata",
577
+ expert_location_metadata: ExpertLocationMetadata,
564
578
  rank: int,
565
579
  ) -> "_Accumulator":
566
580
  return _Accumulator.get_class(server_args)(
@@ -579,7 +593,7 @@ class _Accumulator(ABC):
579
593
  def __init__(
580
594
  self,
581
595
  server_args: ServerArgs,
582
- expert_location_metadata: "ExpertLocationMetadata",
596
+ expert_location_metadata: ExpertLocationMetadata,
583
597
  rank: int,
584
598
  ):
585
599
  self._server_args = server_args
@@ -614,8 +628,8 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
614
628
  self._enable = self._server_args.enable_expert_distribution_metrics
615
629
 
616
630
  if self._enable:
617
- window_sizes = [10, 100, 1000]
618
- self._history = _DequeCollection(maxlens=window_sizes)
631
+ self.window_sizes = [10, 100, 1000]
632
+ self._history = _DequeCollection(maxlens=self.window_sizes)
619
633
  self._rank = torch.distributed.get_rank()
620
634
 
621
635
  def append(
@@ -778,7 +792,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
778
792
 
779
793
  if self._first_dump:
780
794
  self._first_dump = False
781
- torch.cuda.empty_cache()
795
+ torch.get_device_module().empty_cache()
782
796
 
783
797
  torch.distributed.all_reduce(
784
798
  logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
@@ -787,6 +801,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
787
801
  output = dict(
788
802
  rank=self._rank,
789
803
  logical_count=logical_count_of_buffered_step,
804
+ average_utilization_rate_over_window=self._get_global_average_utilization_rate(),
790
805
  )
791
806
 
792
807
  if output_mode == "file":
@@ -797,6 +812,31 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
797
812
  else:
798
813
  raise NotImplementedError
799
814
 
815
+ def _get_global_average_utilization_rate(self):
816
+ if not self._enable or math.isclose(
817
+ self._server_args.eplb_min_rebalancing_utilization_threshold, 1.0
818
+ ):
819
+ return None
820
+
821
+ if self._rank == 0:
822
+ utilization_mean_rates = self._history.mean()
823
+ window_index = self.window_sizes[-1]
824
+ average_utilization_rate_over_window = (
825
+ utilization_mean_rates[window_index]
826
+ if window_index in utilization_mean_rates
827
+ else 0
828
+ )
829
+
830
+ avg_rate_tensor = torch.tensor(
831
+ [average_utilization_rate_over_window],
832
+ dtype=torch.float32,
833
+ device="cuda",
834
+ )
835
+ else:
836
+ avg_rate_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
837
+ torch.distributed.broadcast(avg_rate_tensor, src=0)
838
+ return avg_rate_tensor.item()
839
+
800
840
 
801
841
  def _dump_to_file(name, data):
802
842
  save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
@@ -11,21 +11,26 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # ==============================================================================
14
+
15
+ from __future__ import annotations
16
+
14
17
  import json
15
18
  import logging
16
19
  import random
17
20
  from dataclasses import dataclass
18
21
  from pathlib import Path
19
- from typing import List, Optional
22
+ from typing import TYPE_CHECKING, List, Optional
20
23
 
21
24
  import torch
22
25
  import torch.distributed
23
26
  import torch.nn.functional as F
24
27
 
25
- from sglang.srt.configs.model_config import ModelConfig
26
28
  from sglang.srt.eplb import eplb_algorithms
27
29
  from sglang.srt.model_loader import get_model_architecture
28
- from sglang.srt.server_args import ServerArgs
30
+
31
+ if TYPE_CHECKING:
32
+ from sglang.srt.configs.model_config import ModelConfig
33
+ from sglang.srt.server_args import ServerArgs
29
34
 
30
35
  logger = logging.getLogger(__name__)
31
36
 
@@ -47,7 +47,7 @@ class ExpertLocationUpdater:
47
47
  ):
48
48
  if self._first_execution:
49
49
  self._first_execution = False
50
- torch.cuda.empty_cache()
50
+ torch.get_device_module().empty_cache()
51
51
 
52
52
  old_expert_location_metadata = get_global_expert_location_metadata()
53
53
  assert old_expert_location_metadata is not None
@@ -50,19 +50,19 @@ class EBNFComposer:
50
50
 
51
51
  CALL_RULE_MAP = {
52
52
  "pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
53
- "json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
53
+ "json": 'call_{name} ::= "{{" ws "\\"name\\"" ws ":" ws "\\"{name}\\"" ws "," ws "\\"arguments\\"" ws ":" ws {arguments_rule} ws "}}"',
54
54
  "xml": 'call_{name} ::= "<function={name}>\\n" {arguments_rule} "\\n</function>"',
55
55
  }
56
56
 
57
57
  ARGUMENTS_RULE_MAP = {
58
58
  "pythonic": "{arg_rules}",
59
- "json": '"{{" {arg_rules} "}}"',
59
+ "json": '"{{" ws {arg_rules} ws "}}"',
60
60
  "xml": "{arg_rules}",
61
61
  }
62
62
 
63
63
  KEY_VALUE_RULE_MAP = {
64
64
  "pythonic": '"{key}" "=" {valrule}',
65
- "json": '"\\"{key}\\"" ":" {valrule}',
65
+ "json": '"\\"{key}\\"" ws ":" ws {valrule}',
66
66
  "xml": '"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
67
67
  }
68
68
 
@@ -165,7 +165,7 @@ class EBNFComposer:
165
165
  tool_call_separator: Optional[str] = None,
166
166
  call_rule_fmt: Optional[str] = None,
167
167
  key_value_rule_fmt: Optional[str] = None,
168
- key_value_separator: str = ",",
168
+ key_value_separator: str = 'ws "," ws',
169
169
  ):
170
170
  """
171
171
  Generalized EBNF builder for all detectors.
@@ -183,6 +183,10 @@ class EBNFComposer:
183
183
  key_value_rule_fmt: Optional custom format string for key-value pairs. It should define how each parameter is formatted,
184
184
  with placeholders {key} for the parameter name and {valrule} for the value rule. If None, a default format
185
185
  based on function_format will be used.
186
+ key_value_separator: Raw EBNF fragment inserted between key-value pairs.
187
+ This string is used verbatim (not auto-quoted). Pass:
188
+ - Quoted terminals when you need a literal token (e.g. '","' or '"\\n"').
189
+ - Raw/non-terminals when you need grammar tokens (e.g. 'ws "," ws').
186
190
  """
187
191
  # =================================================================
188
192
  # Step 1: Determine the root tool calls rule
@@ -281,9 +285,7 @@ class EBNFComposer:
281
285
  # Add required properties joined by commas
282
286
  if required:
283
287
  rule_parts.append(
284
- f' "{key_value_separator}" '.join(
285
- prop_kv_pairs[k] for k in required
286
- )
288
+ f" {key_value_separator} ".join(prop_kv_pairs[k] for k in required)
287
289
  )
288
290
 
289
291
  # Add optional properties with flexible ordering
@@ -298,14 +300,14 @@ class EBNFComposer:
298
300
  opt_parts.append(prop_kv_pairs[optional[j]])
299
301
  else:
300
302
  opt_parts.append(
301
- f' ( "{key_value_separator}" {prop_kv_pairs[optional[j]]} )?'
303
+ f" ( {key_value_separator} {prop_kv_pairs[optional[j]]} )?"
302
304
  )
303
305
  opt_alternatives.append("".join(opt_parts))
304
306
 
305
307
  # Wrap with appropriate comma handling based on whether we have required properties
306
308
  if required:
307
309
  # Required properties exist, so optional group needs outer comma
308
- rule_parts.append(f' ( "{key_value_separator}" ( ')
310
+ rule_parts.append(f" ( {key_value_separator} ( ")
309
311
  rule_parts.append(" | ".join(opt_alternatives))
310
312
  rule_parts.append(" ) )?")
311
313
  else:
@@ -160,5 +160,5 @@ class Glm4MoeDetector(BaseFormatDetector):
160
160
  function_format="xml",
161
161
  call_rule_fmt='"{name}" "\\n" ( {arguments_rule} "\\n" )?',
162
162
  key_value_rule_fmt='"<arg_key>{key}</arg_key>" "\\n" "<arg_value>" {valrule} "</arg_value>"',
163
- key_value_separator="\\n",
163
+ key_value_separator='"\\n"',
164
164
  )
@@ -10,7 +10,7 @@ from sglang.srt.function_call.core_types import (
10
10
  ToolCallItem,
11
11
  _GetInfoFunc,
12
12
  )
13
- from sglang.srt.harmony_parser import HarmonyParser
13
+ from sglang.srt.parser.harmony_parser import HarmonyParser
14
14
 
15
15
  logger = logging.getLogger(__name__)
16
16
 
@@ -358,5 +358,5 @@ class Qwen3CoderDetector(BaseFormatDetector):
358
358
  function_format="xml",
359
359
  call_rule_fmt='"<function={name}>\\n" {arguments_rule} "\\n</function>"',
360
360
  key_value_rule_fmt='"<parameter={key}>\\n" {valrule} "\\n</parameter>"',
361
- key_value_separator="\\n",
361
+ key_value_separator='"\\n"',
362
362
  )
@@ -40,7 +40,9 @@ from sglang.srt.configs import (
40
40
  DeepseekVL2Config,
41
41
  ExaoneConfig,
42
42
  KimiVLConfig,
43
+ LongcatFlashConfig,
43
44
  MultiModalityConfig,
45
+ Qwen3NextConfig,
44
46
  Step3VLConfig,
45
47
  )
46
48
  from sglang.srt.configs.internvl import InternVLChatConfig
@@ -56,6 +58,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
56
58
  KimiVLConfig.model_type: KimiVLConfig,
57
59
  InternVLChatConfig.model_type: InternVLChatConfig,
58
60
  Step3VLConfig.model_type: Step3VLConfig,
61
+ LongcatFlashConfig.model_type: LongcatFlashConfig,
62
+ Qwen3NextConfig.model_type: Qwen3NextConfig,
59
63
  }
60
64
 
61
65
  for name, cls in _CONFIG_REGISTRY.items():
@@ -126,6 +130,14 @@ def get_config(
126
130
  kwargs["gguf_file"] = model
127
131
  model = Path(model).parent
128
132
 
133
+ if is_remote_url(model):
134
+ # BaseConnector implements __del__() to clean up the local dir.
135
+ # Since config files need to exist all the time, so we DO NOT use
136
+ # with statement to avoid closing the client.
137
+ client = create_remote_connector(model)
138
+ client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
139
+ model = client.get_local_dir()
140
+
129
141
  config = AutoConfig.from_pretrained(
130
142
  model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
131
143
  )
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
35
35
  is_cuda,
36
36
  is_hip,
37
37
  is_npu,
38
+ is_xpu,
38
39
  set_weight_attrs,
39
40
  )
40
41
  from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
44
45
  _is_cpu_amx_available = cpu_has_amx_support()
45
46
  _is_cpu = is_cpu()
46
47
  _is_hip = is_hip()
48
+ _is_xpu = is_xpu()
47
49
 
48
- if _is_cuda:
50
+ if _is_cuda or _is_xpu:
49
51
  from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
50
52
  elif _is_hip:
51
53
  from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):
70
72
 
71
73
  def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
72
74
  if _is_cpu_amx_available:
73
- d = x.shape[-1] // 2
74
- output_shape = x.shape[:-1] + (d,)
75
75
  out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
76
76
  return out
77
77
  else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
81
81
  out = torch_npu.npu_swiglu(x)
82
82
  return out
83
83
 
84
+ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
85
+ d = x.shape[-1] // 2
86
+ output_shape = x.shape[:-1] + (d,)
87
+ out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
88
+ silu_and_mul(x, out)
89
+ return out
90
+
84
91
 
85
92
  class GeluAndMul(CustomOp):
86
93
  def __init__(self, approximate="tanh"):
87
94
  super().__init__()
88
95
  self.approximate = approximate
89
96
 
90
- def forward_native(self, x: torch.Tensor) -> torch.Tensor:
91
- d = x.shape[-1] // 2
92
- return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
93
-
94
- def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
97
+ def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
95
98
  d = x.shape[-1] // 2
96
99
  output_shape = x.shape[:-1] + (d,)
97
100
  out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,33 @@ class GeluAndMul(CustomOp):
103
106
  raise RuntimeError("GeluAndMul only support tanh or none")
104
107
  return out
105
108
 
109
+ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
110
+ d = x.shape[-1] // 2
111
+ return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
112
+
113
+ def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
114
+ if _is_cpu_amx_available and self.approximate == "tanh":
115
+ return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
116
+ elif _is_cpu_amx_available and self.approximate == "none":
117
+ return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
118
+ else:
119
+ return self.forward_native(x)
120
+
121
+ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
122
+ return self._forward_impl(x)
123
+
124
+ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
125
+ return self._forward_impl(x)
126
+
127
+ def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
128
+ y_npu, gelu_npu = torch_npu.npu_geglu(
129
+ x,
130
+ dim=-1,
131
+ approximate=1 if self.approximate == "tanh" else 0,
132
+ activate_left=True,
133
+ )
134
+ return y_npu
135
+
106
136
 
107
137
  class NewGELU(CustomOp):
108
138
  def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -137,6 +167,9 @@ class QuickGELU(CustomOp):
137
167
  gelu_quick(x, out)
138
168
  return out
139
169
 
170
+ def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
171
+ return torch_npu.npu_fast_gelu(x)
172
+
140
173
 
141
174
  class ScaledActivation(nn.Module):
142
175
  """An activation function with post-scale parameters.
@@ -230,7 +263,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
230
263
  return nn.Identity()
231
264
 
232
265
 
233
- if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
266
+ if not (
267
+ _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
268
+ ):
234
269
  logger.info(
235
270
  "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
236
271
  )