sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/{utils.py → utils/common.py}

@@ -22,6 +22,7 @@ import ctypes
 import dataclasses
 import functools
 import importlib
+import inspect
 import io
 import ipaddress
 import itertools
@@ -82,11 +83,9 @@ from packaging import version as pkg_version
 from PIL import Image
 from starlette.routing import Mount
 from torch import nn
-from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import FileCacheManager
 from typing_extensions import Literal

 from sglang.srt.metrics.func_timer import enable_func_timer
@@ -167,6 +166,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)


+@lru_cache(maxsize=1)
 def is_blackwell():
     if not is_cuda():
         return False
@@ -175,6 +175,8 @@ def is_blackwell():

 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
@@ -182,6 +184,8 @@ def is_sm100_supported(device=None) -> bool:

 @lru_cache(maxsize=1)
 def is_sm90_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
     return (torch.cuda.get_device_capability(device)[0] == 9) and (
         torch.version.cuda >= "12.3"
     )
@@ -191,6 +195,7 @@ _warned_bool_env_var_keys = set()


 def get_bool_env_var(name: str, default: str = "false") -> bool:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name, default)
     value = value.lower()

@@ -208,6 +213,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:


 def get_int_env_var(name: str, default: int = 0) -> int:
+    # FIXME: move your environment variable to sglang.srt.environ
     value = os.getenv(name)
     if value is None or not value.strip():
         return default
@@ -465,7 +471,7 @@ def is_pin_memory_available() -> bool:

 class LayerFn(Protocol):

-    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


 def make_layers(
@@ -476,7 +482,7 @@ def make_layers(
     prefix: str = "",
     return_tuple: bool = False,
     offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[int, int, torch.nn.ModuleList]:
+) -> Tuple[torch.nn.Module, int, int]:
     """Make a list of layers with the given layer function"""
     # circular imports
     from sglang.srt.distributed import get_pp_indices
@@ -512,6 +518,50 @@
     return modules, start_layer, end_layer


+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation (CMO).
+    Launch a new stream to prefetch matmul weights while other AIV or
+    communication kernels run, aiming to overlap the memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch the entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
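The intended call pattern for these CMO helpers, as a minimal sketch: kick off the prefetch on the CMO stream, keep doing other work on the default stream, then join before the matmul that reads the weight. This only applies on Ascend NPUs with `torch_npu` installed; `comm_op` and `layer` below are hypothetical stand-ins, and the exact semantics of the `handle` argument follow `torch_npu.npu_prefetch`.

```python
# Sketch only: assumes an Ascend NPU with torch_npu; `comm_op` and `layer`
# are hypothetical stand-ins for a communication kernel and a linear layer.
def forward_with_cmo(layer, hidden_states):
    # Start prefetching the layer's weight on the CMO stream while the
    # default stream keeps running other kernels.
    prepare_weight_cache(hidden_states, layer.weight)
    hidden_states = comm_op(hidden_states)  # e.g. an all-reduce
    # Join streams so the matmul below sees the prefetched weight.
    wait_cmo_stream()
    return layer(hidden_states)
```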
@@ -749,6 +799,25 @@ def load_image(
     return image, image_size


+def get_image_bytes(image_file: Union[str, bytes]):
+    if isinstance(image_file, bytes):
+        return image_file
+    elif image_file.startswith("http://") or image_file.startswith("https://"):
+        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
+        response = requests.get(image_file, timeout=timeout)
+        return response.content
+    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
+        with open(image_file, "rb") as f:
+            return f.read()
+    elif image_file.startswith("data:"):
+        image_file = image_file.split(",")[1]
+        return pybase64.b64decode(image_file)
+    elif isinstance(image_file, str):
+        return pybase64.b64decode(image_file)
+    else:
+        raise NotImplementedError(f"Invalid image: {image_file}")
+
+
 def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
     from decord import VideoReader, cpu, gpu
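A quick sketch of the input forms `get_image_bytes` dispatches on, mirroring the branches above (the import path assumes the renamed module `sglang/srt/utils/common.py`):

```python
# Sketch: the input forms get_image_bytes accepts.
import base64
from sglang.srt.utils.common import get_image_bytes  # post-rename path

raw = b"\x89PNG\r\n\x1a\n"                  # bytes pass through unchanged
assert get_image_bytes(raw) == raw

b64 = base64.b64encode(raw).decode()
assert get_image_bytes(f"data:image/png;base64,{b64}") == raw  # data URI
assert get_image_bytes(b64) == raw           # bare base64 string fallback

# http(s) URLs are fetched via requests (REQUEST_TIMEOUT, default 3 s),
# and paths ending in png/jpg/jpeg/webp/gif are read from disk.
```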
@@ -804,6 +873,33 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
     os.unlink(tmp_file.name)


+def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
+    if not os.path.exists(video_path):
+        logger.error(f"Video {video_path} does not exist")
+        return []
+
+    if frame_count_limit == 0:
+        return []
+
+    def uniform_sample(l, n):
+        gap = len(l) / n
+        idxs = [int(i * gap + gap / 2) for i in range(n)]
+        return [l[i] for i in idxs]
+
+    vr = VideoReader(video_path, ctx=cpu(0))
+    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+    frame_indices = [i for i in range(0, len(vr), sample_fps)]
+    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
+        frame_indices = uniform_sample(frame_indices, frame_count_limit)
+
+    frames = vr.get_batch(frame_indices).asnumpy()
+    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+    return frames
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -946,6 +1042,13 @@ def set_ulimit(target_soft_limit=65535):
         logger.warning(f"Fail to set RLIMIT_STACK: {e}")


+def rank0_log(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        logger.info(msg)
+
+
 def add_api_key_middleware(app, api_key: str):
    @app.middleware("http")
    async def authentication(request, call_next):
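Stepping back to `encode_video` two hunks up: its `uniform_sample` helper picks `n` indices at the centers of `n` equal-width buckets over the frame list, and that arithmetic can be checked standalone (pure Python, no decord required):

```python
# Standalone check of encode_video's uniform_sample arithmetic:
# index i maps to int(i * len(l)/n + len(l)/(2n)), the i-th bucket center.
def uniform_sample(l, n):
    gap = len(l) / n
    idxs = [int(i * gap + gap / 2) for i in range(n)]
    return [l[i] for i in idxs]

frame_indices = list(range(0, 300, 30))  # 10 candidate frames
print(uniform_sample(frame_indices, 4))  # -> [30, 90, 180, 240]
```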
@@ -1404,6 +1507,32 @@ def get_npu_memory_capacity():
     raise ImportError("torch_npu is required when run on npu device.")


+def get_cpu_memory_capacity():
+    # Per-rank memory capacity cannot be determined for customized core settings
+    if os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", ""):
+        return None
+    n_numa_node: int = len(get_cpu_ids_by_node())
+    if n_numa_node == 0:
+        # Cannot determine the NUMA config; fall back to total memory and avoid ZeroDivisionError.
+        return float(psutil.virtual_memory().total // (1 << 20))
+    try:
+        numa_mem_list = list()
+        file_prefix = "/sys/devices/system/node/"
+        for numa_id in range(n_numa_node):
+            file_meminfo = f"node{numa_id}/meminfo"
+            with open(os.path.join(file_prefix, file_meminfo), "r") as f:
+                # 1st line contains 'MemTotal'
+                line = f.read().split("\n")[0]
+            numa_mem_list.append(int(line.split()[3]))
+        # Retrieved value is in KB, need MB
+        numa_mem = float(min(numa_mem_list) // 1024)
+        return numa_mem
+    except FileNotFoundError:
+        numa_mem = psutil.virtual_memory().total / n_numa_node
+        # Retrieved value is in bytes, need MB
+        return float(numa_mem // (1 << 20))
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
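The sysfs parse above relies on the first line of each node's meminfo having the form `Node 0 MemTotal: <kB> kB`, so the value is whitespace token 3. A standalone sketch of the same read (Linux only; assumes the sysfs NUMA layout is exposed):

```python
# Sketch (Linux only): read per-NUMA-node MemTotal the same way
# get_cpu_memory_capacity does, taking the minimum across nodes.
import glob

totals_mb = []
for path in sorted(glob.glob("/sys/devices/system/node/node*/meminfo")):
    with open(path) as f:
        first_line = f.readline()  # "Node 0 MemTotal:  131724044 kB"
    totals_mb.append(int(first_line.split()[3]) // 1024)  # kB -> MB
print(min(totals_mb) if totals_mb else "no NUMA nodes exposed")
```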
@@ -1413,6 +1542,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_hpu_memory_capacity()
     elif device == "npu":
         gpu_mem = get_npu_memory_capacity()
+    elif device == "cpu":
+        gpu_mem = get_cpu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1951,50 +2082,6 @@ def set_uvicorn_logging_configs():
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"


-def get_ip() -> str:
-    # SGLANG_HOST_IP env can be ignore
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
-
-    # IP is not set, try to get it from the network interface
-
-    # try ipv4
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try ipv6
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        # Google's public DNS server, see
-        # https://developers.google.com/speed/public-dns/docs/using#addresses
-        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try using hostname
-    hostname = socket.gethostname()
-    try:
-        ip_addr = socket.gethostbyname(hostname)
-        warnings.warn("using local ip address: {}".format(ip_addr))
-        return ip_addr
-    except Exception:
-        pass
-
-    warnings.warn(
-        "Failed to get the IP address, using 0.0.0.0 by default."
-        "The value can be set by the environment variable"
-        " SGLANG_HOST_IP or HOST_IP.",
-        stacklevel=2,
-    )
-    return "0.0.0.0"
-
-
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
@@ -2251,16 +2338,9 @@ def bind_or_assign(target, source):
     return source


-def get_local_ip_auto() -> str:
-    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
-    return (
-        get_local_ip_by_nic(interface)
-        if interface is not None
-        else get_local_ip_by_remote()
-    )
-
-
-def get_local_ip_by_nic(interface: str) -> str:
+def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
+    if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
+        return None
     try:
         import netifaces
     except ImportError as e:
@@ -2281,15 +2361,13 @@ def get_local_ip_by_nic(interface: str) -> str:
                 if ip and not ip.startswith("fe80::") and ip != "::1":
                     return ip.split("%")[0]
     except (ValueError, OSError) as e:
-        raise ValueError(
-            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        logger.warning(
+            f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
         )
-
-    # Fallback
-    return get_local_ip_by_remote()
+    return None


-def get_local_ip_by_remote() -> str:
+def get_local_ip_by_remote() -> Optional[str]:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
     try:
@@ -2314,7 +2392,51 @@ def get_local_ip_by_remote() -> str:
         s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
         return s.getsockname()[0]
     except Exception:
-        raise ValueError("Can not get local ip")
+        logger.warning("Can not get local ip by remote")
+        return None
+
+
+def get_local_ip_auto(fallback: str = None) -> str:
+    """
+    Automatically detect the local IP address using multiple fallback strategies.
+
+    This function attempts to obtain the local IP address through several methods.
+    If all methods fail, it returns the specified fallback value or raises an exception.
+
+    Args:
+        fallback (str, optional): Fallback IP address to return if all detection
+            methods fail. For server applications, explicitly set this to
+            "0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
+            Defaults to None.
+
+    Returns:
+        str: The detected local IP address, or the fallback value if detection fails.
+
+    Raises:
+        ValueError: If IP detection fails and no fallback value is provided.
+
+    Note:
+        The function tries detection methods in the following order:
+        1. The SGLANG_HOST_IP / HOST_IP environment variables
+        2. Network interface enumeration via get_local_ip_by_nic()
+        3. Remote connection method via get_local_ip_by_remote()
+    """
+    # Try environment variables
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+    logger.debug("get_ip failed")
+    # Fallback
+    if ip := get_local_ip_by_nic():
+        return ip
+    logger.debug("get_local_ip_by_nic failed")
+    # Fallback
+    if ip := get_local_ip_by_remote():
+        return ip
+    logger.debug("get_local_ip_by_remote failed")
+    if fallback:
+        return fallback
+    raise ValueError("Can not get local ip")


 def is_page_size_one(server_args):
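Call sites that previously leaned on `get_ip()`'s silent `0.0.0.0` default now have to opt in explicitly. A sketch of the new contract (the import path assumes the renamed module `sglang/srt/utils/common.py`):

```python
# Sketch of the new fallback contract for get_local_ip_auto.
from sglang.srt.utils.common import get_local_ip_auto  # post-rename path

# Server-style usage: never raises, binds to all interfaces on failure.
host = get_local_ip_auto(fallback="0.0.0.0")

# Strict usage: detection failure is an error, not a silent 0.0.0.0.
try:
    peer_ip = get_local_ip_auto()
except ValueError:
    raise RuntimeError("Set SGLANG_HOST_IP, HOST_IP, or SGLANG_LOCAL_IP_NIC")
```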
@@ -2366,7 +2488,7 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

-    if get_tensor_model_parallel_rank() == 0:
+    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
         logger.info(msg)

@@ -2496,14 +2618,6 @@ def read_system_prompt_from_file(model_name: str) -> str:
     return ""


-def bind_or_assign(target, source):
-    if target is not None:
-        target.copy_(source)
-        return target
-    else:
-        return source
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
@@ -3042,6 +3156,44 @@ def check_cuda_result(raw_output):
     return results


+def get_physical_device_id(pytorch_device_id: int) -> int:
+    """
+    Convert PyTorch logical device ID to physical device ID.
+    """
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    assert (
+        cuda_visible_devices is not None
+    ), "CUDA_VISIBLE_DEVICES should be set in a scheduler"
+    device_list = cuda_visible_devices.split(",")
+    assert (
+        len(device_list) == 1
+    ), "CUDA_VISIBLE_DEVICES should be set to a single device in a scheduler"
+    return int(device_list[0])
+
+
+def get_device_sm_nvidia_smi():
+    try:
+        # Run nvidia-smi command and capture output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+
+        # Get the first line of output (assuming at least one GPU exists)
+        compute_cap_str = result.stdout.strip().split("\n")[0]
+
+        # Convert string (e.g., "9.0") to tuple of integers (9, 0)
+        major, minor = map(int, compute_cap_str.split("."))
+        return (major, minor)
+
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
+        # Handle cases where nvidia-smi isn't available or output is unexpected
+        print(f"Error getting compute capability: {e}")
+        return (0, 0)  # Default/fallback value
+
+
 def numa_bind_to_node(node: int):
     libnuma = ctypes.CDLL("libnuma.so")
     if libnuma.numa_available() < 0:
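Because `get_device_sm_nvidia_smi()` shells out to `nvidia-smi` rather than touching `torch.cuda`, it can gate capability-specific paths before any CUDA context exists. A sketch (the branches here are hypothetical, not sglang behavior):

```python
# Sketch: gate an SM-specific path without initializing CUDA.
major, minor = get_device_sm_nvidia_smi()
if (major, minor) == (0, 0):
    print("nvidia-smi unavailable or unparsable; using generic path")
elif (major, minor) >= (9, 0):
    print("Hopper (SM90) or newer detected")  # e.g. enable FP8-only kernels
```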
@@ -3058,3 +3210,176 @@ def json_list_type(value):
     raise argparse.ArgumentTypeError(
         f"Invalid JSON list: {value}. Please provide a valid JSON list."
     )
+
+
+@contextmanager
+def temp_set_cuda_visible_devices(gpu_id: int):
+    original_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if original_cuda_visible_devices:
+        cuda_visible_devices = original_cuda_visible_devices.split(",")
+    else:
+        cuda_visible_devices = []
+
+    str_gpu_id = cuda_visible_devices[gpu_id] if cuda_visible_devices else str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str_gpu_id
+    yield
+    if original_cuda_visible_devices:
+        os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_visible_devices
+    else:
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+
+
+def get_extend_input_len_swa_limit(
+    sliding_window_size: int, chunked_prefill_size: int, page_size: int
+) -> int:
+    # 1. a factor of 2x is because each prefill contains chunked_prefill_size tokens,
+    #    and between prefills, we run swa_radix_cache.cache_unfinished_req(),
+    #    so we unlock the previously locked nodes.
+    # 2. max is to handle the case that chunked_prefill_size is larger than sliding_window_size.
+    #    in that case, each prefill contains chunked_prefill_size tokens,
+    #    and we can only free out-of-sliding-window kv indices after each prefill.
+    # 3. page_size is because we want to have 1 token extra for generated tokens.
+    return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
+
+
+def get_num_new_pages(
+    seq_lens: torch.Tensor,
+    page_size: int,
+    prefix_lens: Optional[torch.Tensor] = None,
+    decode: bool = False,
+) -> torch.Tensor:
+    """
+    Get the number of new pages for the given prefix and sequence lengths.
+    We use cpu tensors to avoid blocking kernel launch.
+    """
+    cpu_device = torch.device("cpu")
+    assert seq_lens.device == cpu_device
+
+    if prefix_lens is None or decode:
+        # NOTE: Special case for handling decode, where prefix lens is `seq_lens - 1`.
+        assert decode
+        return (seq_lens % page_size == 1).int().sum().item()
+
+    assert prefix_lens.device == cpu_device
+    num_pages_after = (seq_lens + page_size - 1) // page_size
+    num_pages_before = (prefix_lens + page_size - 1) // page_size
+    num_new_pages = num_pages_after - num_pages_before
+    sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
+    return sum_num_new_pages.item()
+
+
+class CachedKernel:
+    """
+    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+    This wrapper caches compiled Triton kernels based on keys extracted by a
+    user-provided key function to avoid redundant compilations.
+    """
+
+    def __init__(self, fn, key_fn=None):
+        self.fn = fn
+        assert isinstance(fn, triton.runtime.jit.JITFunction)
+
+        original_fn = fn.fn
+        self.signature = inspect.signature(original_fn)
+        self.param_names = tuple(self.signature.parameters.keys())
+        self.num_args = len(self.param_names)
+
+        # Check that no parameters have default values
+        for name, param in self.signature.parameters.items():
+            assert (
+                param.default is inspect.Parameter.empty
+            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+        functools.update_wrapper(self, original_fn)
+        self.kernel_cache = {}
+
+        # Store the key function
+        self.key_fn = key_fn
+
+    def __getitem__(self, grid):
+        """
+        Index with grid to get a launcher function.
+        Returns a launcher that will handle caching based on the key function.
+        """
+        assert (
+            isinstance(grid, tuple) and len(grid) <= 3
+        ), "Grid must be a tuple with at most 3 dimensions."
+
+        # Normalize grid once
+        if len(grid) < 3:
+            grid = grid + (1,) * (3 - len(grid))
+
+        def launcher(*args, **kwargs):
+            cache_key = self.key_fn(args, kwargs)
+
+            cached_kernel = self.kernel_cache.get(cache_key)
+
+            if cached_kernel is None:
+                # First time: compile and cache the kernel
+                cached_kernel = self.fn[grid](*args, **kwargs)
+                self.kernel_cache[cache_key] = cached_kernel
+                return cached_kernel
+            else:
+                # Use cached kernel
+                all_args = self._build_args(args, kwargs)
+                cached_kernel[grid](*all_args)
+                return cached_kernel

+        return launcher
+
+    def _build_args(self, args, kwargs):
+        """
+        Build the complete argument list for kernel invocation.
+        """
+        complete_args = list(args)
+
+        for i in range(len(args), self.num_args):
+            name = self.param_names[i]
+            value = kwargs.get(name, inspect.Parameter.empty)
+            if value is not inspect.Parameter.empty:
+                complete_args.append(value)
+            else:
+                raise ValueError(f"Missing argument: {name}")
+
+        return complete_args
+
+    def _clear_cache(self):
+        """
+        Clear the kernel cache for testing purposes.
+        """
+        self.kernel_cache.clear()
+
+
+def cached_triton_kernel(key_fn=None):
+    """
+    Decorator that enables key-based caching for Triton kernels using a key function.
+
+    It essentially bypasses Triton's built-in caching mechanism, allowing users to
+    define their own caching strategy based on kernel parameters. This helps reduce
+    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+    is simple.
+
+    Usage:
+        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+        @triton.jit
+        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+            ...
+
+        # Invoke normally
+        my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+    Args:
+        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+            The key can be a single value or a tuple of values.
+
+    Returns:
+        A decorator that wraps the kernel with caching functionality.
+
+    Note: Kernels with default parameter values are not supported and will raise an assertion error.
+    """
+
+    def decorator(fn):
+        return CachedKernel(fn, key_fn)
+
+    return decorator
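A worked check of the page arithmetic in `get_num_new_pages` (runnable with torch alone; CPU tensors, as the function asserts):

```python
# Worked check of get_num_new_pages with page_size = 4.
import torch

page_size = 4
prefix_lens = torch.tensor([5, 8, 0])  # tokens already cached per request
seq_lens = torch.tensor([9, 9, 3])     # total tokens after the extend

# pages_before = ceil(prefix/4) = [2, 2, 0]; pages_after = ceil(seq/4) = [3, 3, 1]
assert get_num_new_pages(seq_lens, page_size, prefix_lens=prefix_lens) == 3

# Decode special case: prefix is seq_lens - 1, so a request needs a new
# page exactly when seq_len % page_size == 1 (here: 9 and 13, not 10).
decode_lens = torch.tensor([9, 10, 13])
assert get_num_new_pages(decode_lens, page_size, decode=True) == 2
```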
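And the docstring's usage expanded into a fuller sketch (requires triton and a CUDA GPU; the copy kernel is a hypothetical example, cached on `BLOCK_SIZE` alone so repeated launches skip Triton's per-call specialization dispatch):

```python
# Sketch: a trivial copy kernel cached on BLOCK_SIZE alone.
import torch
import triton
import triton.language as tl

@cached_triton_kernel(key_fn=lambda args, kwargs: kwargs["BLOCK_SIZE"])
@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
copy_kernel[grid](x, y, x.numel(), BLOCK_SIZE=1024)  # compiles and caches
copy_kernel[grid](x, y, x.numel(), BLOCK_SIZE=1024)  # cache hit: replays kernel
```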