sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,7 @@ from torch.distributed.device_mesh import DeviceMesh
6
6
  from torch.distributed.tensor import DTensor
7
7
 
8
8
  from sglang.srt.entrypoints.engine import Engine
9
- from sglang.srt.managers.tokenizer_manager import UpdateWeightsFromTensorReqInput
9
+ from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput
10
10
  from sglang.srt.model_executor.model_runner import LocalSerializedTensor
11
11
  from sglang.srt.utils import MultiprocessingSerializer
12
12
 
@@ -33,7 +33,7 @@ async def update_weights(
33
33
  """
34
34
  infer_tp_size = device_mesh[device_mesh_key].mesh.size()[0]
35
35
  infer_tp_rank = device_mesh[device_mesh_key].get_local_rank()
36
- from sglang.srt.patch_torch import monkey_patch_torch_reductions
36
+ from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
37
37
 
38
38
  monkey_patch_torch_reductions()
39
39
 
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
41
41
  "v_head_dim": 512,
42
42
  "num_kv_heads": 1,
43
43
  "layer_id": 0,
44
+ "tp_q_head_num": 128,
45
+ "tp_k_head_num": 128,
46
+ "prefill_head_dim": 192,
47
+ "prefill_v_head_dim": 128,
44
48
  }
45
49
 
46
50
  ROPE_BASE = 10000
@@ -92,7 +96,7 @@ TEST_CASES = {
92
96
  "description": "Medium-scale batch",
93
97
  },
94
98
  ],
95
- "decode_output_match": [
99
+ "output_match": [
96
100
  {
97
101
  "name": "single_fp16",
98
102
  "batch_size": 1,
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
322
326
  config.update(test_case)
323
327
  return config
324
328
 
325
- def _create_model_components(self, config):
329
+ def _create_model_components(self, config, is_prefill=False):
326
330
  """Create model runners, backends, and layer for testing."""
327
331
  # Create model runners
328
332
  model_runner_trtllm = MockModelRunner(config)
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
332
336
  trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
333
337
  reference_backend = FlashInferMLAAttnBackend(model_runner_reference)
334
338
 
339
+ head_dim = (
340
+ config["kv_lora_rank"] + config["qk_rope_head_dim"]
341
+ if not is_prefill
342
+ else config["prefill_head_dim"]
343
+ )
344
+ v_head_dim = (
345
+ config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
346
+ )
347
+
335
348
  # Create RadixAttention layer
336
349
  layer = RadixAttention(
337
350
  num_heads=config["num_attention_heads"],
338
- head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
351
+ head_dim=head_dim,
339
352
  scaling=model_runner_trtllm.model_config.scaling,
340
353
  num_kv_heads=config["num_kv_heads"],
341
354
  layer_id=config["layer_id"],
342
- v_head_dim=config["v_head_dim"],
355
+ v_head_dim=v_head_dim,
343
356
  prefix="attn_mqa",
344
357
  )
345
358
 
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
524
537
  """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
525
538
  print(f"\nRunning decode output matching tests...")
526
539
 
527
- for test_case in TEST_CASES["decode_output_match"]:
540
+ for test_case in TEST_CASES["output_match"]:
528
541
  with self.subTest(test_case=test_case["name"]):
529
542
  print(f" Testing {test_case['name']}: {test_case['description']}")
530
543
 
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
1099
1112
  self.assertIsNotNone(metadata_3.block_kv_indices)
1100
1113
  self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
1101
1114
 
1115
+ def test_prefill_output_match_self_attention(self):
1116
+ """Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
1117
+ print(f"\nRunning prefill output tests...")
1118
+
1119
+ for test_case in TEST_CASES["output_match"][:2]: # Just a subset for speed
1120
+ with self.subTest(test_case=test_case["name"]):
1121
+ print(
1122
+ f"Prefill Testing {test_case['name']}: {test_case['description']}"
1123
+ )
1124
+
1125
+ config = self._merge_config(test_case)
1126
+ batch_size = config["batch_size"]
1127
+ max_seq_len = config["max_seq_len"]
1128
+
1129
+ # Create components
1130
+ (
1131
+ model_runner_trtllm,
1132
+ model_runner_reference,
1133
+ trtllm_backend,
1134
+ reference_backend,
1135
+ layer,
1136
+ ) = self._create_model_components(config, is_prefill=True)
1137
+
1138
+ # Prefill uses full sequences
1139
+ seq_lens = torch.full(
1140
+ (batch_size,), max_seq_len, device=config["device"]
1141
+ )
1142
+
1143
+ def _create_forward_batch_prefill(
1144
+ batch_size,
1145
+ seq_lens,
1146
+ extend_prefix_lens,
1147
+ backend,
1148
+ model_runner,
1149
+ config,
1150
+ ):
1151
+ """Create a forward batch for the given backend."""
1152
+
1153
+ fb = ForwardBatch(
1154
+ batch_size=batch_size,
1155
+ input_ids=torch.randint(
1156
+ 0, 100, (batch_size, 1), device=config["device"]
1157
+ ),
1158
+ out_cache_loc=torch.arange(batch_size, device=config["device"]),
1159
+ seq_lens_sum=int(seq_lens.sum().item()),
1160
+ extend_prefix_lens=extend_prefix_lens,
1161
+ extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
1162
+ extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
1163
+ .cpu()
1164
+ .int()
1165
+ .tolist(),
1166
+ forward_mode=ForwardMode.EXTEND,
1167
+ req_pool_indices=torch.arange(
1168
+ batch_size, device=config["device"]
1169
+ ),
1170
+ seq_lens=seq_lens,
1171
+ seq_lens_cpu=seq_lens.cpu(),
1172
+ attn_attend_prefix_cache=False,
1173
+ mha_return_lse=False,
1174
+ attn_backend=backend,
1175
+ )
1176
+ fb.req_to_token_pool = model_runner.req_to_token_pool
1177
+ fb.token_to_kv_pool = model_runner.token_to_kv_pool
1178
+
1179
+ # Add position information for RoPE
1180
+ fb.positions = torch.arange(batch_size, device=config["device"])
1181
+
1182
+ return fb
1183
+
1184
+ # Create forward batches
1185
+ fb_trtllm = _create_forward_batch_prefill(
1186
+ batch_size,
1187
+ seq_lens.clone(),
1188
+ torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
1189
+ trtllm_backend,
1190
+ model_runner_trtllm,
1191
+ config,
1192
+ )
1193
+ fb_reference = _create_forward_batch_prefill(
1194
+ batch_size,
1195
+ seq_lens.clone(),
1196
+ torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
1197
+ reference_backend,
1198
+ model_runner_reference,
1199
+ config,
1200
+ )
1201
+
1202
+ # Initialize metadata for both backends
1203
+ trtllm_backend.init_forward_metadata(fb_trtllm)
1204
+ reference_backend.init_forward_metadata(fb_reference)
1205
+
1206
+ # Create Q, K, V tensors for prefill
1207
+ torch.manual_seed(config["seed_qkv"])
1208
+
1209
+ def _create_qkv_tensors_prefill(
1210
+ batch_size, seq_len, config, dtype_override=None
1211
+ ):
1212
+ """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
1213
+ device = config["device"]
1214
+ dtype = dtype_override or config["dtype"]
1215
+
1216
+ total_tokens = batch_size * seq_len
1217
+
1218
+ tp_q_head_num = config["tp_q_head_num"]
1219
+ tp_k_head_num = config["tp_k_head_num"]
1220
+ head_dim = config["prefill_head_dim"]
1221
+ v_head_dim = config["prefill_v_head_dim"]
1222
+
1223
+ q = torch.randn(
1224
+ (total_tokens, tp_q_head_num * head_dim),
1225
+ dtype=dtype,
1226
+ device=device,
1227
+ )
1228
+ k = torch.randn(
1229
+ (total_tokens, tp_k_head_num * head_dim),
1230
+ dtype=dtype,
1231
+ device=device,
1232
+ )
1233
+ v = torch.randn(
1234
+ (total_tokens, tp_k_head_num * v_head_dim),
1235
+ dtype=dtype,
1236
+ device=device,
1237
+ )
1238
+
1239
+ # Reshape as requested
1240
+ q = q.view(-1, tp_q_head_num, head_dim)
1241
+ k = k.view(-1, tp_k_head_num, head_dim)
1242
+ v = v.view(-1, tp_k_head_num, v_head_dim)
1243
+
1244
+ return q, k, v
1245
+
1246
+ q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
1247
+ # Run prefill on both backends
1248
+ out_trtllm = trtllm_backend.forward_extend(
1249
+ q, k, v, layer, fb_trtllm, False
1250
+ ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
1251
+ out_reference = reference_backend.forward_extend(
1252
+ q, k, v, layer, fb_reference, False
1253
+ )
1254
+
1255
+ tolerance = config.get("tolerance", 1e-2)
1256
+ comparison_passed = compare_outputs(
1257
+ out_trtllm, out_reference, tolerance=tolerance
1258
+ )
1259
+ self.assertTrue(
1260
+ comparison_passed,
1261
+ f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
1262
+ f"Config: {test_case['name']}, "
1263
+ f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
1264
+ )
1265
+
1102
1266
 
1103
1267
  if __name__ == "__main__":
1104
1268
  unittest.main()
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class DummyModel(nn.Module):
    """Toy module comparing two ways of scaling projected gate weights.

    Both methods project ``x`` through a linear layer with 1024 output
    features and multiply the result by ``n_heads**-0.5``, ``q_scale`` and
    ``softmax_scale``. They differ only in how the scalar factors are
    grouped before touching the projected tensor.
    """

    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
        super().__init__()
        self.weights_proj = nn.Linear(d_in, 1024)
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Scale the projection step by step (reference formulation)."""
        projected = self.weights_proj(x)
        head_normed = projected * self.n_heads**-0.5
        # For a (B, 1) q_scale this becomes (B, 1, 1) and broadcasts below.
        row_scale = q_scale.unsqueeze(1)
        return head_normed.unsqueeze(-1) * row_scale * self.softmax_scale

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        """Fold all scalar factors into one small tensor, then scale once."""
        projected = self.weights_proj(x)
        row_scale = q_scale.unsqueeze(1)
        # Combine the constants on the small tensor so the projected tensor
        # is multiplied only once.
        fused_scale = self.n_heads**-0.5 * row_scale * self.softmax_scale
        return projected.unsqueeze(-1) * fused_scale
25
+
26
+
27
def main():
    """Time both gate-scaling variants and check that they agree.

    Runs each ``DummyModel`` method 1000 times on the same random inputs,
    prints the elapsed time for each loop and the maximum absolute
    difference between the two outputs.

    Raises:
        AssertionError: If the two outputs are not ``torch.allclose``.
    """
    import time

    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    # perf_counter is monotonic and high resolution, so system clock
    # adjustments cannot distort the measured interval as they can
    # with time.time().
    start = time.perf_counter()
    for _ in range(1000):
        out_orig = model._get_logits_head_gate_orig(x, q_scale)
    print("Original version time:", time.perf_counter() - start)

    start = time.perf_counter()
    for _ in range(1000):
        out_opt = model._get_logits_head_gate_opt(x, q_scale)
    print("Optimized version time:", time.perf_counter() - start)

    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()
51
+
52
+
53
+ """
54
+ Original version time: 0.49235057830810547
55
+ Optimized version time: 0.4087331295013428
56
+ Difference: 1.4901161193847656e-08
57
+ """
sglang/test/run_eval.py CHANGED
@@ -10,11 +10,46 @@ import time
10
10
 
11
11
  from sglang.test.simple_eval_common import (
12
12
  ChatCompletionSampler,
13
+ Eval,
13
14
  make_report,
14
15
  set_ulimit,
15
16
  )
16
17
 
17
18
 
19
def get_thinking_kwargs(args):
    """Return the extra request body that enables thinking mode, if any.

    Reads ``args.thinking_mode`` (treated as ``None`` when the attribute is
    missing). If it is not one of ``THINKING_MODE_CHOICES`` an empty dict is
    returned. Otherwise the result is
    ``{"chat_template_kwargs": {<flag>: True}}``, where the flag is
    ``"thinking"`` for ``"deepseek-v3"`` and ``"enable_thinking"`` for every
    other supported mode.
    """
    mode = getattr(args, "thinking_mode", None)
    if mode not in THINKING_MODE_CHOICES:
        return {}

    flag_name = "thinking" if mode == "deepseek-v3" else "enable_thinking"
    return {"chat_template_kwargs": {flag_name: True}}
30
+
31
+
32
def run_eval_once(args, base_url: str, eval_obj: Eval) -> tuple:
    """Run ``eval_obj`` once against a freshly built sampler.

    Args:
        args: Parsed options. ``args.model`` is required; ``max_tokens``
            (default 2048), ``temperature`` (default 0.0) and
            ``reasoning_effort`` (default ``None``) are read with fallbacks,
            and ``thinking_mode`` is consumed via ``get_thinking_kwargs``.
        base_url: Base URL handed to ``ChatCompletionSampler``.
        eval_obj: Evaluation callable invoked as ``eval_obj(sampler)``.

    Returns:
        A ``(result, latency, sampler)`` tuple: the value returned by
        ``eval_obj``, the seconds spent inside that call, and the sampler
        that was used.
    """
    # Thinking-mode options chosen by the user are sent as extra body fields.
    thinking_kwargs = get_thinking_kwargs(args)

    sampler = ChatCompletionSampler(
        model=args.model,
        max_tokens=getattr(args, "max_tokens", 2048),
        base_url=base_url,
        temperature=getattr(args, "temperature", 0.0),
        reasoning_effort=getattr(args, "reasoning_effort", None),
        extra_body=thinking_kwargs,
    )

    # Only the evaluation call itself is timed, not sampler construction.
    tic = time.perf_counter()
    result = eval_obj(sampler)
    latency = time.perf_counter() - tic

    return result, latency, sampler
51
+
52
+
18
53
  def run_eval(args):
19
54
  set_ulimit()
20
55
 
@@ -60,21 +95,40 @@ def run_eval(args):
60
95
  from sglang.test.simple_eval_humaneval import HumanEval
61
96
 
62
97
  eval_obj = HumanEval(args.num_examples, args.num_threads)
98
+ elif args.eval_name == "mmmu":
99
+ # VLM MMMU evaluation with fixed 100 examples by default
100
+ from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
101
+
102
+ eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
63
103
  else:
64
104
  raise ValueError(f"Invalid eval name: {args.eval_name}")
65
105
 
66
- sampler = ChatCompletionSampler(
67
- model=args.model,
68
- max_tokens=getattr(args, "max_tokens", 2048),
69
- base_url=base_url,
70
- temperature=getattr(args, "temperature", 0.0),
71
- reasoning_effort=getattr(args, "reasoning_effort", None),
72
- )
106
+ if getattr(args, "repeat", 1) == 1:
107
+ result, latency, sampler = run_eval_once(args, base_url, eval_obj)
108
+ else:
109
+ from concurrent.futures import ThreadPoolExecutor
73
110
 
74
- # Run eval
75
- tic = time.perf_counter()
76
- result = eval_obj(sampler)
77
- latency = time.perf_counter() - tic
111
+ executor = ThreadPoolExecutor(max_workers=args.repeat)
112
+
113
+ futures = [
114
+ executor.submit(run_eval_once, args, base_url, eval_obj)
115
+ for _ in range(args.repeat)
116
+ ]
117
+
118
+ scores_repeat = []
119
+
120
+ for f in futures:
121
+ result, latency, sampler = f.result()
122
+ scores_repeat.append(result.score)
123
+
124
+ mean_score = sum(scores_repeat) / len(scores_repeat)
125
+ scores_repeat = [f"{s:.3f}" for s in scores_repeat]
126
+ print("=" * 20)
127
+ print(f"Repeat: {args.repeat}, mean: {mean_score:.3f}")
128
+ print(f"Scores: {scores_repeat}")
129
+ print("=" * 20)
130
+
131
+ executor.shutdown()
78
132
 
79
133
  # Dump reports
80
134
  metrics = result.metrics | {"score": result.score}
@@ -94,9 +148,13 @@ def run_eval(args):
94
148
  print(f"Total latency: {latency:.3f} s")
95
149
  print(f"Score: {metrics['score']:.3f}")
96
150
 
151
+ if getattr(args, "return_latency", False):
152
+ return metrics, latency
97
153
  return metrics
98
154
 
99
155
 
156
+ THINKING_MODE_CHOICES = ["deepseek-r1", "deepseek-v3", "qwen3"]
157
+
100
158
  if __name__ == "__main__":
101
159
  parser = argparse.ArgumentParser()
102
160
  parser.add_argument(
@@ -118,12 +176,22 @@ if __name__ == "__main__":
118
176
  type=str,
119
177
  help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
120
178
  )
179
+ parser.add_argument(
180
+ "--repeat", type=int, default=1, help="repeat the evaluation n times"
181
+ )
121
182
  parser.add_argument("--eval-name", type=str, default="mmlu")
122
183
  parser.add_argument("--num-examples", type=int)
123
184
  parser.add_argument("--num-threads", type=int, default=512)
124
185
  parser.add_argument("--max-tokens", type=int, default=2048)
125
186
  parser.add_argument("--temperature", type=float, default=0.0)
126
187
  parser.add_argument("--reasoning-effort", type=str)
188
+ parser.add_argument(
189
+ "--thinking-mode",
190
+ default=None,
191
+ type=str,
192
+ choices=THINKING_MODE_CHOICES,
193
+ help="Enable thinking mode in Deepseek R1, V3.1/3.2, or Qwen3",
194
+ )
127
195
  args = parser.parse_args()
128
196
 
129
197
  run_eval(args)
sglang/test/runners.py CHANGED
@@ -30,8 +30,8 @@ from transformers import (
30
30
  )
31
31
 
32
32
  from sglang.srt.entrypoints.engine import Engine
33
- from sglang.srt.hf_transformers_utils import get_tokenizer
34
33
  from sglang.srt.utils import load_image
34
+ from sglang.srt.utils.hf_transformers_utils import get_tokenizer
35
35
  from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
36
36
 
37
37
  DEFAULT_PROMPTS = [
@@ -505,6 +505,7 @@ class SRTRunner:
505
505
  mem_fraction_static: float = 0.65,
506
506
  trust_remote_code: bool = False,
507
507
  speculative_draft_model_path: Optional[str] = None,
508
+ speculative_draft_model_revision: Optional[str] = None,
508
509
  speculative_algorithm: Optional[str] = None,
509
510
  speculative_num_steps: Optional[int] = None,
510
511
  speculative_eagle_topk: Optional[int] = None,
@@ -526,6 +527,9 @@ class SRTRunner:
526
527
  spec_kwargs = {}
527
528
  if speculative_draft_model_path:
528
529
  spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
530
+ spec_kwargs["speculative_draft_model_revision"] = (
531
+ speculative_draft_model_revision
532
+ )
529
533
  spec_kwargs["speculative_algorithm"] = speculative_algorithm
530
534
  spec_kwargs["speculative_num_steps"] = speculative_num_steps
531
535
  spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
@@ -93,6 +93,7 @@ class ChatCompletionSampler(SamplerBase):
93
93
  temperature: float = 0.0,
94
94
  reasoning_effort: Optional[str] = None,
95
95
  max_tokens: int = 2048,
96
+ extra_body: Optional[Dict[str, Any]] = None,
96
97
  ):
97
98
  self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
98
99
 
@@ -104,9 +105,10 @@ class ChatCompletionSampler(SamplerBase):
104
105
  self.temperature = temperature
105
106
  self.max_tokens = max_tokens
106
107
  self.reasoning_effort = reasoning_effort
108
+ self.extra_body = extra_body
107
109
  self.image_format = "url"
108
110
  print(
109
- f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
111
+ f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=} {self.extra_body=}"
110
112
  )
111
113
 
112
114
  def _handle_image(
@@ -136,7 +138,7 @@ class ChatCompletionSampler(SamplerBase):
136
138
  self._pack_message("system", self.system_message)
137
139
  ] + message_list
138
140
  trial = 0
139
- while True:
141
+ while trial < 6: # 126 seconds in total
140
142
  try:
141
143
  response = self.client.chat.completions.create(
142
144
  model=self.model,
@@ -144,6 +146,7 @@ class ChatCompletionSampler(SamplerBase):
144
146
  temperature=self.temperature,
145
147
  max_tokens=self.max_tokens,
146
148
  reasoning_effort=self.reasoning_effort,
149
+ extra_body=self.extra_body,
147
150
  )
148
151
  return response.choices[0].message.content
149
152
  # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU