sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/multimodal/processors/internvl.py CHANGED
@@ -1,9 +1,13 @@
 # Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
 
+from functools import lru_cache
+
 import numpy as np
 import torch
-from decord import VideoReader, cpu
+import torchvision.transforms as T
+from decord import VideoReader, cpu, gpu
 from PIL import Image
+from torchvision.transforms import InterpolationMode
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
@@ -17,6 +21,20 @@ from sglang.srt.multimodal.processors.base_processor import (
 class InternVLImageProcessor(BaseMultimodalProcessor):
     models = [InternVLChatModel, InternS1ForConditionalGeneration]
 
+    IMAGENET_MEAN = [0.485, 0.456, 0.406]
+    IMAGENET_STD = [0.229, 0.224, 0.225]
+
+    @staticmethod
+    @lru_cache(maxsize=1)
+    def _get_normalize_tensors(device="cuda", dtype=torch.float32):
+        mean = torch.tensor(
+            InternVLImageProcessor.IMAGENET_MEAN, device=device, dtype=dtype
+        ).view(-1, 1, 1)
+        std = torch.tensor(
+            InternVLImageProcessor.IMAGENET_STD, device=device, dtype=dtype
+        ).view(-1, 1, 1)
+        return mean, std
+
     def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
         image_size = (
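
Note: `_get_normalize_tensors` memoizes the ImageNet mean/std tensors with `lru_cache`, so they are allocated once per process instead of being rebuilt for every frame the way the old per-call `normalize` closure (deleted in the next hunk) did. A minimal standalone sketch of the same pattern; the function name and the CPU default here are illustrative, not sglang API:

    from functools import lru_cache

    import torch

    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    @lru_cache(maxsize=1)
    def get_normalize_tensors(device="cpu", dtype=torch.float32):
        # Built on the first call; later calls with the same arguments are
        # cache hits, so there is no per-frame tensor allocation.
        mean = torch.tensor(IMAGENET_MEAN, device=device, dtype=dtype).view(-1, 1, 1)
        std = torch.tensor(IMAGENET_STD, device=device, dtype=dtype).view(-1, 1, 1)
        return mean, std

    img = torch.rand(3, 448, 448)        # CHW image scaled to [0, 1]
    mean, std = get_normalize_tensors()  # cached after the first call
    normalized = (img - mean) / std      # broadcasts over H and W

With `maxsize=1`, switching device or dtype evicts the cached pair; with a single inference device that is exactly the intended behavior.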
@@ -48,99 +66,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
         ).build(_image_processor)
 
-    @staticmethod
-    def build_transform(input_size):
-        IMAGENET_MEAN = (0.485, 0.456, 0.406)
-        IMAGENET_STD = (0.229, 0.224, 0.225)
-
-        def resize_image(img, size):
-            return img.resize((size, size), Image.Resampling.BICUBIC)
-
-        def to_tensor(img):
-            # Convert PIL Image to numpy array
-            img_array = np.array(img).astype(np.float32) / 255.0
-            # Convert HWC to CHW format
-            img_array = img_array.transpose(2, 0, 1)
-            return torch.from_numpy(img_array)
-
-        def normalize(tensor, mean, std):
-            mean = torch.tensor(mean).view(-1, 1, 1)
-            std = torch.tensor(std).view(-1, 1, 1)
-            return (tensor - mean) / std
-
-        def transform(img):
-            img = img.convert("RGB") if img.mode != "RGB" else img
-            img = resize_image(img, input_size)
-            tensor = to_tensor(img)
-            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
-            return tensor
-
-        return transform
-
-    @staticmethod
-    def dynamic_preprocess(
-        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
-    ):
-
-        def find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, width, height, image_size
-        ):
-            best_ratio_diff = float("inf")
-            best_ratio = (1, 1)
-            area = width * height
-            for ratio in target_ratios:
-                target_aspect_ratio = ratio[0] / ratio[1]
-                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-                if ratio_diff < best_ratio_diff:
-                    best_ratio_diff = ratio_diff
-                    best_ratio = ratio
-                elif ratio_diff == best_ratio_diff:
-                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                        best_ratio = ratio
-            return best_ratio
-
-        orig_width, orig_height = image.size
-        aspect_ratio = orig_width / orig_height
-
-        # calculate the existing image aspect ratio
-        target_ratios = set(
-            (i, j)
-            for n in range(min_num, max_num + 1)
-            for i in range(1, n + 1)
-            for j in range(1, n + 1)
-            if i * j <= max_num and i * j >= min_num
-        )
-        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-        # find the closest aspect ratio to the target
-        target_aspect_ratio = find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, orig_width, orig_height, image_size
-        )
-
-        # calculate the target width and height
-        target_width = image_size * target_aspect_ratio[0]
-        target_height = image_size * target_aspect_ratio[1]
-        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-        # resize the image
-        resized_img = image.resize((target_width, target_height))
-        processed_images = []
-        for i in range(blocks):
-            box = (
-                (i % (target_width // image_size)) * image_size,
-                (i // (target_width // image_size)) * image_size,
-                ((i % (target_width // image_size)) + 1) * image_size,
-                ((i // (target_width // image_size)) + 1) * image_size,
-            )
-            # split the image
-            split_img = resized_img.crop(box)
-            processed_images.append(split_img)
-        assert len(processed_images) == blocks
-        if use_thumbnail and len(processed_images) != 1:
-            thumbnail_img = image.resize((image_size, image_size))
-            processed_images.append(thumbnail_img)
-        return processed_images
-
     @staticmethod
     def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
         if bound:
@@ -160,27 +85,110 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
 
     @staticmethod
     def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
-        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        try:
+            vr = VideoReader(video_path, ctx=gpu(0), num_threads=1)
+            use_gpu = True
+        except (RuntimeError, OSError) as e:
+            print(
+                f"[WARNING] Load video on gpu decoding failed: {e}. Falling back to CPU."
+            )
+            vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+            use_gpu = False
+
         max_frame = len(vr) - 1
         fps = float(vr.get_avg_fps())
 
-        pixel_values_list, num_patches_list = [], []
-        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        pixel_values_list = []
+        num_patches_list = []
         frame_indices = InternVLImageProcessor.get_index(
             bound, fps, max_frame, first_idx=0, num_segments=num_segments
         )
+
+        mean, std = InternVLImageProcessor._get_normalize_tensors(device="cuda")
+
         for frame_index in frame_indices:
-            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
-            img = InternVLImageProcessor.dynamic_preprocess(
-                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            # Load frame
+            frame = vr[frame_index]
+            if use_gpu:
+                img = frame.cuda().permute(2, 0, 1).float() / 255.0
+            else:
+                img_np = frame.asnumpy()
+                img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+
+            img = (img - mean) / std
+
+            tiles = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, max_num=max_num, use_thumbnail=True
             )
-            pixel_values = [transform(tile) for tile in img]
-            pixel_values = torch.stack(pixel_values)
-            num_patches_list.append(pixel_values.shape[0])
-            pixel_values_list.append(pixel_values)
-        pixel_values = torch.cat(pixel_values_list)
+
+            pixel_values_list.append(tiles)
+            num_patches_list.append(tiles.shape[0])
+
+        pixel_values = torch.cat(pixel_values_list, dim=0)
         return pixel_values, num_patches_list
 
+    @staticmethod
+    def dynamic_preprocess(tensor, image_size=448, max_num=12, use_thumbnail=False):
+        C, H, W = tensor.shape
+        aspect_ratio = W / H
+
+        # Generate all possible aspect ratios
+        target_ratios = set(
+            (i, j)
+            for n in range(1, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # Find closest ratio
+        best_ratio_diff = float("inf")
+        best_ratio = (1, 1)
+
+        for x, y in target_ratios:
+            target_ar = x / y
+            diff = abs(aspect_ratio - target_ar)
+            blocks = x * y
+            best_blocks = best_ratio[0] * best_ratio[1]
+
+            if diff < best_ratio_diff:
+                best_ratio_diff = diff
+                best_ratio = (x, y)
+            elif diff == best_ratio_diff and blocks > best_blocks:
+                best_ratio = (x, y)
+
+        target_w, target_h = image_size * best_ratio[0], image_size * best_ratio[1]
+        blocks = best_ratio[0] * best_ratio[1]
+
+        # Resize on GPU
+        resized = torch.nn.functional.interpolate(
+            tensor.unsqueeze(0),
+            size=(target_h, target_w),
+            mode="bicubic",
+            align_corners=False,
+        ).squeeze(0)
+
+        # Split into tiles
+        tiles = []
+        for i in range(blocks):
+            x = (i % best_ratio[0]) * image_size
+            y = (i // best_ratio[0]) * image_size
+            tile = resized[:, y : y + image_size, x : x + image_size]
+            tiles.append(tile)
+
+        # Add thumbnail if needed
+        if use_thumbnail and len(tiles) > 1:
+            thumb = torch.nn.functional.interpolate(
+                tensor.unsqueeze(0),
+                size=(image_size, image_size),
+                mode="bicubic",
+                align_corners=False,
+            ).squeeze(0)
+            tiles.append(thumb)
+
+        return torch.stack(tiles).to(torch.bfloat16)
+
     async def process_mm_data_async(
         self, image_data, input_text, request_obj, **kwargs
     ):
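
Note: two things change in the hunks above. `load_video` now tries decord's NVDEC path (`gpu(0)`) and falls back to CPU decoding on failure, and the new `dynamic_preprocess` tiles a normalized CHW tensor with `torch.nn.functional.interpolate` instead of PIL resizes and crops. A rough sketch of the fallback half, assuming a CUDA-enabled decord build and a hypothetical `sample.mp4`:

    from decord import VideoReader, cpu, gpu

    def open_video(path: str):
        # decord raises RuntimeError/OSError when it was built without CUDA
        # support or no GPU is present; fall back to CPU decoding then.
        try:
            return VideoReader(path, ctx=gpu(0), num_threads=1), True
        except (RuntimeError, OSError) as e:
            print(f"[WARNING] GPU decode failed: {e}. Falling back to CPU.")
            return VideoReader(path, ctx=cpu(0), num_threads=1), False

    vr, use_gpu = open_video("sample.mp4")  # hypothetical input file
    print(len(vr), vr.get_avg_fps(), use_gpu)

Note that the CPU branch of `load_video` still moves each decoded frame to CUDA before normalization, so the preprocessing math runs on the GPU in both cases; only the decode step differs.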
@@ -191,53 +199,69 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             discard_alpha_channel=True,
         )
 
-        def process_image_internvl(image, input_size=448, max_num=12):
-            transform = InternVLImageProcessor.build_transform(input_size=input_size)
-            images = InternVLImageProcessor.dynamic_preprocess(
-                image, image_size=input_size, use_thumbnail=True, max_num=max_num
-            )
-            pixel_values = [transform(image) for image in images]
-            pixel_values = torch.stack(pixel_values)
-            return pixel_values
-
         num_patches_list = []
         pixel_values = []
+
+        mean, std = InternVLImageProcessor._get_normalize_tensors(device="cuda")
+
         # Process each input with allocated frames
-        for image_index, (image) in enumerate(base_output.images):
+        for image_index, image in enumerate(base_output.images):
             try:
                 # TODO: video input
-                raw_image = process_image_internvl(image)
-                pixel_value = [raw_image.to(torch.bfloat16)]
-                pixel_values += pixel_value
-                num_patches = raw_image.shape[0]
-                num_patches_list += [num_patches]
-
-            except FileNotFoundError as e:
-                print(e)
+                # Convert PIL to GPU tensor
+                if isinstance(image, Image.Image):
+                    img_np = np.array(image.convert("RGB"))
+                    tensor = (
+                        torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+                    )
+                else:
+                    tensor = image.cuda()  # assume already tensor
+
+                tensor = (tensor - mean) / std
+                tiles = self.dynamic_preprocess(
+                    tensor, image_size=448, max_num=12, use_thumbnail=True
+                )
+
+                pixel_values.append(tiles)
+                num_patches_list.append(tiles.shape[0])
+
+            except Exception as e:
+                print(f"[Error] Failed to process image {image_index}: {e}")
                 return None
 
+        # Concatenate all
         pixel_values = torch.cat(pixel_values, dim=0)
 
         original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
         input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
 
-        for idx, num_patches in enumerate(num_patches_list):
+        input_text_updated = input_text
+        for num_patches in num_patches_list:
             image_tokens = (
                 self.IMG_START_TOKEN
                 + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                 + self.IMG_END_TOKEN
             )
-            input_text = input_text.replace(original_placeholder, image_tokens, 1)
+            input_text_updated = input_text_updated.replace(
+                original_placeholder, image_tokens, 1
+            )
 
-        input_text = input_text.replace(original_placeholder, self.IMG_CONTEXT_TOKEN)
+        input_text_updated = input_text_updated.replace(
+            original_placeholder, self.IMG_CONTEXT_TOKEN
+        )
 
-        input_ids = self.tokenizer(input_text, return_tensors="pt")[
+        # Tokenize
+        input_ids_tensor = self.tokenizer(input_text_updated, return_tensors="pt")[
             "input_ids"
         ].flatten()
+        input_ids = input_ids_tensor.tolist()
+
+        # Get image token offsets
         image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids,
+            input_ids=input_ids_tensor.to("cuda"),
             mm_token_id=self.mm_tokens.image_token_id,
         )
+
         items = [
             MultimodalDataItem(
                 feature=pixel_values,
@@ -247,7 +271,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         ]
 
         return {
-            "input_ids": input_ids.tolist(),
+            "input_ids": input_ids,
             "mm_items": items,
             "im_start_id": self.img_start_token_id,
             "im_end_id": self.img_end_token_id,
sglang/srt/multimodal/processors/qwen_vl.py CHANGED
@@ -12,6 +12,8 @@ from torchvision.transforms import InterpolationMode
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
+from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
+from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -67,10 +69,15 @@ def smart_resize(
     return h_bar, w_bar
 
 
-def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+def resize_image(
+    image,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
+    size_factor: int = IMAGE_FACTOR,
+) -> Image.Image:
     width, height = image.size
-    min_pixels = MIN_PIXELS
-    max_pixels = MAX_PIXELS
+    min_pixels = min_pixels
+    max_pixels = max_pixels
     resized_height, resized_width = smart_resize(
         height,
         width,
@@ -97,8 +104,13 @@ def floor_by_factor(number: int, factor: int) -> int:
     return math.floor(number / factor) * factor
 
 
-async def resize_image_async(image):
-    return resize_image(image)
+async def resize_image_async(
+    image,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
+    size_factor: int = IMAGE_FACTOR,
+):
+    return resize_image(image, min_pixels, max_pixels, size_factor)
 
 
 def smart_nframes(
@@ -199,7 +211,12 @@ async def preprocess_video(
 
 # Compatible with Qwen2VL and Qwen2_5VL
 class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
-    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
+    models = [
+        Qwen2VLForConditionalGeneration,
+        Qwen2_5_VLForConditionalGeneration,
+        Qwen3VLForConditionalGeneration,
+        Qwen3VLMoeForConditionalGeneration,
+    ]
 
     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
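
Note: `resize_image` and `resize_image_async` now expose `min_pixels`/`max_pixels`/`size_factor` instead of hard-coding the module constants, so callers can trade resolution against token budget per request. For intuition, a simplified sketch of the area-budget-plus-factor-rounding that `smart_resize` performs; the constants are illustrative and this is not the exact sglang implementation:

    import math

    IMAGE_FACTOR = 28  # vision patches come in multiples of this
    MIN_PIXELS = 4 * 28 * 28
    MAX_PIXELS = 16384 * 28 * 28

    def round_by_factor(n: float, factor: int) -> int:
        return round(n / factor) * factor

    def scale_to_budget(
        height: int,
        width: int,
        min_pixels: int = MIN_PIXELS,
        max_pixels: int = MAX_PIXELS,
        factor: int = IMAGE_FACTOR,
    ) -> tuple[int, int]:
        # Scale the area into [min_pixels, max_pixels], then snap both
        # sides to multiples of `factor`.
        if height * width > max_pixels:
            beta = math.sqrt(height * width / max_pixels)
            height, width = height / beta, width / beta
        elif height * width < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            height, width = height * beta, width * beta
        return (
            max(factor, round_by_factor(height, factor)),
            max(factor, round_by_factor(width, factor)),
        )

    h, w = scale_to_budget(1080, 1920)
    assert h % IMAGE_FACTOR == 0 and w % IMAGE_FACTOR == 0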
sglang/srt/multimodal/processors/sarashina2_vision.py ADDED
@@ -0,0 +1,81 @@
+from typing import List, Union
+
+from sglang.srt.models.sarashina2_vision import Sarashina2VisionForCausalLM
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+
+
+class Sarashina2VisionProcessor(BaseMultimodalProcessor):
+    models = [Sarashina2VisionForCausalLM]
+
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+        # Sarashina2Vision specific tokens (default is <|file|>)
+        self.IMAGE_TOKEN = "<|file|>"
+        self.IM_TOKEN_ID = getattr(hf_config, "image_token_index", 14)
+        self.IM_START_ID = getattr(hf_config, "start_image_token_index", 102397)
+        self.IM_END_ID = getattr(hf_config, "end_image_token_index", 102398)
+
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_id=self.IM_TOKEN_ID,
+        ).build(_processor)
+
+        # Patch the processor's image processor to handle parameter compatibility
+        if hasattr(_processor, "image_processor") and hasattr(
+            _processor.image_processor, "_preprocess"
+        ):
+            original_preprocess = _processor.image_processor._preprocess
+
+            def patched_preprocess(*args, **kwargs):
+                # Filter kwargs to only include parameters that the custom _preprocess method accepts
+                # Based on Sarashina2VisionImageProcessor._preprocess signature
+                allowed_params = {
+                    "do_resize",
+                    "resample",
+                    "do_rescale",
+                    "rescale_factor",
+                    "do_normalize",
+                    "image_mean",
+                    "image_std",
+                    "do_convert_rgb",
+                    "data_format",
+                    "input_data_format",
+                }
+                filtered_kwargs = {
+                    k: v for k, v in kwargs.items() if k in allowed_params
+                }
+                return original_preprocess(*args, **filtered_kwargs)
+
+            _processor.image_processor._preprocess = patched_preprocess
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        """Process image data for Sarashina2Vision model using standard SGLang pattern."""
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(
+            base_output=base_output,
+            mm_tokens=self.mm_tokens,
+        )
+
+        return {
+            "mm_items": mm_items,
+            "input_ids": input_ids.tolist(),
+            "im_token_id": self.mm_tokens.image_token_id,
+            "im_start_id": self.IM_START_ID,
+            "im_end_id": self.IM_END_ID,
+        }
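
Note: the constructor above monkey-patches `_preprocess` with a hard-coded whitelist so that generic HF processing kwargs (e.g. `return_tensors`) do not crash the custom image processor. The same shim can be written generically by deriving the whitelist with `inspect.signature`; a sketch of that alternative (this is not what the diff ships):

    import functools
    import inspect

    def accept_known_kwargs(func):
        """Wrap `func` so unknown keyword arguments are silently dropped."""
        allowed = {
            name
            for name, p in inspect.signature(func).parameters.items()
            if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
        }

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **{k: v for k, v in kwargs.items() if k in allowed})

        return wrapper

    def preprocess(image, do_resize=True, do_normalize=True):
        return image, do_resize, do_normalize

    patched = accept_known_kwargs(preprocess)
    patched("img", do_resize=False, return_tensors="pt")  # extra kwarg dropped

The hard-coded set in the diff is arguably safer here, since `_preprocess` is a private HF method whose signature could gain a `**kwargs` of its own.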
sglang/srt/offloader.py CHANGED
@@ -38,6 +38,10 @@ class BaseOffloader(ABC):
     def post_init(self):
         pass
 
+    @property
+    def forbid_copy_engine_usage(self):
+        return False
+
 
 class NoopOffloader(BaseOffloader):
     pass
@@ -233,6 +237,10 @@ class OffloaderV2(BaseOffloader):
         for i in range(self.prefetch_step):
            self.offloaders[i].start_onload()
 
+    @property
+    def forbid_copy_engine_usage(self):
+        return self.mode == "cpu"
+
 
 def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
     def _on_forward_end():
@@ -398,14 +406,30 @@ class _ShmCpuParamOffloader(_BaseParamOffloader):
         return self.shm_cpu_data.to("cuda", non_blocking=True)
 
 
+def update_param(param, new_tensor):
+    """Update parameter while keeping properties needed by Offloader (e.g. pinned host memory)."""
+
+    if param.device == new_tensor.device:
+        param.data = new_tensor
+    else:
+        assert param.device == torch.device(
+            "cpu"
+        ), f"{param.device=} {new_tensor.device=}"
+        param.data = _create_cpu_data(new_tensor, pin_memory=True)
+
+
 def _move_param_to_cpu(param, pin_memory: bool):
+    param.data = _create_cpu_data(param.data, pin_memory=pin_memory)
+
+
+def _create_cpu_data(data, pin_memory: bool):
     cpu_data = _empty_strided_like(
-        param.data,
+        data,
         device="cpu",
         pin_memory=pin_memory,
     )
-    cpu_data.copy_(param.data)
-    param.data = cpu_data
+    cpu_data.copy_(data)
+    return cpu_data
 
 
 def _move_param_to_meta(module, param_name):
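
Note: `update_param` and `_create_cpu_data` exist so that weight updates keep the host-side copy in pinned (page-locked) memory; that property is what lets the `.to("cuda", non_blocking=True)` onload above run as an asynchronous DMA copy rather than a blocking one. A minimal illustration in plain torch (`_empty_strided_like` is sglang-internal and not used here; pinned allocation requires a CUDA-enabled torch build):

    import torch

    def to_pinned_cpu(t: torch.Tensor) -> torch.Tensor:
        # Allocate page-locked host memory, then copy. A plain `.cpu()`
        # yields pageable memory, which silently makes non_blocking
        # host-to-device copies synchronous.
        out = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
        out.copy_(t)
        return out

    w = torch.randn(1024, 1024)
    pinned = to_pinned_cpu(w)
    assert pinned.is_pinned()
    # On a CUDA machine this overlaps with compute on the current stream:
    # gpu_w = pinned.to("cuda", non_blocking=True)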
sglang/srt/parser/jinja_template_utils.py CHANGED
@@ -89,6 +89,12 @@ def detect_jinja_template_content_format(chat_template: str) -> str:
     - If template has loops like {%- for content in message['content'] -%} → 'openai'
     - Otherwise → 'string'
     """
+    # Shortcut for multimodal templates
+    if any(
+        keyword in chat_template for keyword in ["image", "audio", "video", "vision"]
+    ):
+        return "openai"
+
     jinja_ast = _try_extract_ast(chat_template)
     if jinja_ast is None:
         return "string"
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -60,6 +60,9 @@ class SamplingBatchInfo:
         Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
     ] = None
 
+    # Used for deterministic sampling
+    sampling_seed: Optional[torch.Tensor] = None
+
     # Device
     device: str = "cuda"
 
@@ -67,28 +70,41 @@ class SamplingBatchInfo:
     logit_bias: Optional[torch.Tensor] = None
 
     @classmethod
-    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+    def _get_global_server_args_dict(cls):
         from sglang.srt.managers.schedule_batch import global_server_args_dict
 
+        return global_server_args_dict
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        global_server_args_dict = cls._get_global_server_args_dict()
+        enable_deterministic = global_server_args_dict["enable_deterministic_inference"]
+
         reqs = batch.reqs
         device = batch.device
-        temperatures = (
-            torch.tensor(
-                [r.sampling_params.temperature for r in reqs],
-                dtype=torch.float,
-            )
-            .view(-1, 1)
-            .to(device, non_blocking=True)
-        )
+        temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
         top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float
-        ).to(device, non_blocking=True)
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
         top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
-        ).to(device, non_blocking=True)
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device
+        )
         min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float
-        ).to(device, non_blocking=True)
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
+        sampling_seed = (
+            torch.tensor(
+                [r.sampling_params.sampling_seed for r in reqs],
+                dtype=torch.int32,
+                device=device,
+            )
+            if enable_deterministic
+            else None
+        )
 
         logit_bias = None
         if any(r.sampling_params.logit_bias is not None for r in reqs):
@@ -154,6 +170,7 @@ class SamplingBatchInfo:
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
+            sampling_seed=sampling_seed,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
             need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
@@ -235,9 +252,11 @@ class SamplingBatchInfo:
             "top_ps",
             "top_ks",
             "min_ps",
+            "sampling_seed",
         ]:
             value = getattr(self, item, None)
-            setattr(self, item, value[keep_indices_device])
+            if value is not None:
+                setattr(self, item, value[keep_indices_device])
 
         if self.logit_bias is not None:
             self.logit_bias = self.logit_bias[keep_indices_device]
@@ -339,10 +358,12 @@ class SamplingBatchInfo:
             "top_ps",
             "top_ks",
             "min_ps",
+            "sampling_seed",
         ]:
             self_val = getattr(self, item, None)
             other_val = getattr(other, item, None)
-            setattr(self, item, torch.cat([self_val, other_val]))
+            if self_val is not None and other_val is not None:
+                setattr(self, item, torch.cat([self_val, other_val]))
 
         self.is_all_greedy &= other.is_all_greedy
         self.need_top_p_sampling |= other.need_top_p_sampling
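
Note: `sampling_seed` is only materialized when `enable_deterministic_inference` is set, which is why the `filter`/`merge` paths above need `None` guards that the always-present temperature/top-p/top-k/min-p tensors do not. One way a per-request seed can drive reproducible sampling is a seeded `torch.Generator` per batch row; a hedged sketch, not sglang's actual sampler:

    import torch

    def sample_with_seeds(probs: torch.Tensor, seeds: torch.Tensor) -> torch.Tensor:
        """Draw one token per row, each row seeded independently."""
        out = torch.empty(probs.shape[0], dtype=torch.long)
        for i, seed in enumerate(seeds.tolist()):
            gen = torch.Generator(device=probs.device).manual_seed(seed)
            out[i] = torch.multinomial(probs[i], num_samples=1, generator=gen)
        return out

    probs = torch.softmax(torch.randn(2, 8), dim=-1)
    seeds = torch.tensor([1234, 1234], dtype=torch.int32)
    assert torch.equal(sample_with_seeds(probs, seeds), sample_with_seeds(probs, seeds))

Seeding per row rather than per batch keeps a request's sampled tokens stable regardless of which other requests it happens to be batched with.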