sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,14 @@
2
2
  Batch the same prompt in random batch sizes, and test if the results are consistent across different trials.
3
3
 
4
4
  Usage:
5
- python3 -m sglang.test.test_deterministic --n-trials <numer_of_trials> --test-mode <single|mixed|prefix> --profile
5
+ # Single mode: test determinism with varying batch sizes
6
+ python3 -m sglang.test.test_deterministic --n-trials 50 --test-mode single
7
+
8
+ # Prefix mode: test with shared prefixes
9
+ python3 -m sglang.test.test_deterministic --n-start 1 --n-trials 50 --test-mode prefix
10
+
11
+ # Radix Cache Consistency mode: test radix cache determinism (cached vs uncached prefill)
12
+ python3 -m sglang.test.test_deterministic --test-mode radix_cache
6
13
  """
7
14
 
8
15
  import argparse
@@ -10,7 +17,7 @@ import dataclasses
10
17
  import json
11
18
  import os
12
19
  import random
13
- from typing import List
20
+ from typing import Any, Dict, List, Optional
14
21
 
15
22
  import requests
16
23
 
@@ -39,12 +46,15 @@ class BenchArgs:
39
46
  profile_steps: int = 3
40
47
  profile_by_stage: bool = False
41
48
  test_mode: str = "single"
49
+ n_trials: int = 50
50
+ n_start: int = 1
42
51
 
43
52
  @staticmethod
44
53
  def add_cli_args(parser: argparse.ArgumentParser):
45
54
  parser.add_argument("--host", type=str, default=BenchArgs.host)
46
55
  parser.add_argument("--port", type=int, default=BenchArgs.port)
47
- parser.add_argument("--n-trials", type=int, default=50)
56
+ parser.add_argument("--n-trials", type=int, default=BenchArgs.n_trials)
57
+ parser.add_argument("--n-start", type=int, default=BenchArgs.n_start)
48
58
  parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
49
59
  parser.add_argument(
50
60
  "--sampling-seed", type=int, default=BenchArgs.sampling_seed
@@ -64,7 +74,12 @@ class BenchArgs:
64
74
  "--test-mode",
65
75
  type=str,
66
76
  default=BenchArgs.test_mode,
67
- choices=["single", "mixed", "prefix"],
77
+ choices=[
78
+ "single",
79
+ "prefix",
80
+ "radix_cache",
81
+ "p_vs_d",
82
+ ],
68
83
  )
69
84
  parser.add_argument("--profile", action="store_true")
70
85
  parser.add_argument(
@@ -80,26 +95,55 @@ class BenchArgs:
80
95
 
81
96
  def send_single(
82
97
  args,
83
- batch_size: int,
84
98
  profile: bool = False,
85
99
  profile_steps: int = 3,
86
100
  profile_by_stage: bool = False,
101
+ return_full_response: bool = False,
102
+ input_ids: List[int] = None,
103
+ prompt: List[str] = None,
104
+ max_new_tokens: int = None,
105
+ extra_params: Optional[Dict[str, Any]] = None,
106
+ pick_first_result: bool = True,
87
107
  ):
88
-
89
108
  base_url = f"http://{args.host}:{args.port}"
90
- prompt = [PROMPT_1] * batch_size
91
109
 
92
- json_data = {
93
- "text": prompt,
94
- "sampling_params": {
95
- "temperature": args.temperature,
96
- "max_new_tokens": args.max_new_tokens,
97
- "frequency_penalty": args.frequency_penalty,
98
- "presence_penalty": args.presence_penalty,
99
- },
100
- "return_logprob": args.return_logprob,
101
- "stream": args.stream,
102
- }
110
+ # Use input_ids if provided, otherwise use text prompts
111
+ if input_ids is not None:
112
+ assert prompt is None
113
+ json_data = {
114
+ "input_ids": input_ids,
115
+ "sampling_params": {
116
+ "temperature": args.temperature,
117
+ "max_new_tokens": (
118
+ max_new_tokens
119
+ if max_new_tokens is not None
120
+ else args.max_new_tokens
121
+ ),
122
+ "frequency_penalty": args.frequency_penalty,
123
+ "presence_penalty": args.presence_penalty,
124
+ },
125
+ "return_logprob": args.return_logprob,
126
+ "stream": args.stream,
127
+ **(extra_params or {}),
128
+ }
129
+ else:
130
+ assert input_ids is None
131
+ json_data = {
132
+ "text": prompt,
133
+ "sampling_params": {
134
+ "temperature": args.temperature,
135
+ "max_new_tokens": (
136
+ max_new_tokens
137
+ if max_new_tokens is not None
138
+ else args.max_new_tokens
139
+ ),
140
+ "frequency_penalty": args.frequency_penalty,
141
+ "presence_penalty": args.presence_penalty,
142
+ },
143
+ "return_logprob": args.return_logprob,
144
+ "stream": args.stream,
145
+ **(extra_params or {}),
146
+ }
103
147
 
104
148
  if args.sampling_seed is not None:
105
149
  # sglang server cannot parse None value for sampling_seed
@@ -116,6 +160,11 @@ def send_single(
116
160
  stream=args.stream,
117
161
  )
118
162
 
163
+ if response.status_code != 200:
164
+ ret = response.json()
165
+ print(f"Error: {ret}")
166
+ return None
167
+
119
168
  if args.stream:
120
169
  for chunk in response.iter_lines(decode_unicode=False):
121
170
  chunk = chunk.decode("utf-8")
@@ -125,24 +174,30 @@ def send_single(
125
174
  ret = json.loads(chunk[5:].strip("\n"))
126
175
  else:
127
176
  ret = response.json()
128
- ret = ret[0]
129
177
 
130
- if response.status_code != 200:
131
- print(ret)
132
- return -1
178
+ if pick_first_result:
179
+ ret = ret[0] if isinstance(ret, list) else ret
133
180
 
134
- return ret["text"]
181
+ if return_full_response:
182
+ return ret
183
+ else:
184
+ return ret["text"]
135
185
 
136
186
 
137
- def send_mixed(args, batch_size: int):
138
- num_long_prompt = 0 if batch_size <= 10 else random.randint(1, 10)
139
- num_prompt_1 = random.randint(1, batch_size - num_long_prompt)
140
- num_prompt_2 = batch_size - num_prompt_1 - num_long_prompt
187
+ def send_prefix(
188
+ args, batch_size: int, prompts: List[str], return_full_response: bool = False
189
+ ):
190
+ requests.post(f"http://{args.host}:{args.port}/flush_cache")
191
+
192
+ batch_data = []
193
+ sampled_indices = []
194
+ for _ in range(batch_size):
195
+ sampled_index = random.randint(0, len(prompts) - 1)
196
+ sampled_indices.append(sampled_index)
197
+ batch_data.append(prompts[sampled_index])
141
198
 
142
199
  json_data = {
143
- "text": [PROMPT_1] * num_prompt_1
144
- + [PROMPT_2] * num_prompt_2
145
- + [LONG_PROMPT] * num_long_prompt,
200
+ "text": batch_data,
146
201
  "sampling_params": {
147
202
  "temperature": args.temperature,
148
203
  "max_new_tokens": args.max_new_tokens,
@@ -166,103 +221,171 @@ def send_mixed(args, batch_size: int):
166
221
  print(ret)
167
222
  return -1, -1, -1
168
223
 
169
- prompt_1_ret = [ret[i]["text"] for i in range(num_prompt_1)]
170
- prompt_2_ret = [
171
- ret[i]["text"] for i in range(num_prompt_1, num_prompt_1 + num_prompt_2)
172
- ]
173
- long_prompt_ret = [
174
- ret[i]["text"]
175
- for i in range(
176
- num_prompt_1 + num_prompt_2, num_prompt_1 + num_prompt_2 + num_long_prompt
177
- )
178
- ]
224
+ if return_full_response:
225
+ # Return full responses grouped by prompt index
226
+ ret_dict = {i: [] for i in range(len(prompts))}
227
+ for i in range(batch_size):
228
+ ret_dict[sampled_indices[i]].append(ret[i])
229
+ return ret_dict
230
+ else:
231
+ # Return only text grouped by prompt index
232
+ ret_dict = {i: [] for i in range(len(prompts))}
233
+ for i in range(batch_size):
234
+ ret_dict[sampled_indices[i]].append(ret[i]["text"])
235
+ return ret_dict
236
+
237
+
238
+ def compare_logprobs(logprobs1, logprobs2, tolerance=0):
239
+ """Compare two logprobs sequences with a tolerance."""
240
+ if len(logprobs1) != len(logprobs2):
241
+ return False, f"Length mismatch: {len(logprobs1)} vs {len(logprobs2)}"
242
+
243
+ for i, (lp1, lp2) in enumerate(zip(logprobs1, logprobs2)):
244
+ # Each element is [logprob, token_id]
245
+ if lp1[1] != lp2[1]:
246
+ return False, f"Token ID mismatch at position {i}: {lp1[1]} vs {lp2[1]}"
247
+ if abs(lp1[0] - lp2[0]) > tolerance:
248
+ return (
249
+ False,
250
+ f"Logprob mismatch at position {i}: {lp1[0]} vs {lp2[0]} (diff: {abs(lp1[0] - lp2[0])})",
251
+ )
252
+
253
+ return True, "Logprobs match"
179
254
 
180
- return prompt_1_ret, prompt_2_ret, long_prompt_ret
181
255
 
256
+ def _test_mode_p_vs_d(args, batch_size):
257
+ print()
258
+ print(f"Execute: test p_vs_d {batch_size=}")
182
259
 
183
- def send_prefix(args, batch_size: int, prompts: List[str]):
260
+ random.seed(42)
261
+ args.return_logprob = True
262
+ query_extra_params = {
263
+ "logprob_start_len": 0,
264
+ "return_text_in_logprobs": True,
265
+ }
266
+
267
+ def _create_prompts():
268
+ ans = [PROMPT_1, PROMPT_2]
269
+ for i in range(batch_size - len(ans)):
270
+ end = random.randrange(1, 4096)
271
+ if random.random() < 0.5:
272
+ begin = 0
273
+ else:
274
+ begin = random.randrange(0, end)
275
+ ans.append(LONG_PROMPT[begin:end])
276
+ return ans[:batch_size]
277
+
278
+ # warmup + flush
279
+ send_single(args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True)
184
280
  requests.post(f"http://{args.host}:{args.port}/flush_cache")
185
281
 
186
- batch_data = []
187
- sampled_indices = []
188
- for _ in range(batch_size):
189
- sampled_index = random.randint(0, len(prompts) - 1)
190
- sampled_indices.append(sampled_index)
191
- batch_data.append(prompts[sampled_index])
282
+ prompts = _create_prompts()
192
283
 
193
- json_data = {
194
- "text": batch_data,
195
- "sampling_params": {
196
- "temperature": args.temperature,
197
- "max_new_tokens": args.max_new_tokens,
198
- "frequency_penalty": args.frequency_penalty,
199
- "presence_penalty": args.presence_penalty,
200
- },
201
- "return_logprob": args.return_logprob,
202
- "stream": args.stream,
203
- }
284
+ resp_a = send_single(
285
+ args,
286
+ prompt=prompts,
287
+ max_new_tokens=args.max_new_tokens,
288
+ return_full_response=True,
289
+ pick_first_result=False,
290
+ extra_params=query_extra_params,
291
+ )
292
+ info_a = _extract_ids_and_logprobs(resp_a)
204
293
 
205
- if args.sampling_seed is not None:
206
- json_data["sampling_params"]["sampling_seed"] = args.sampling_seed
294
+ requests.post(f"http://{args.host}:{args.port}/flush_cache")
207
295
 
208
- response = requests.post(
209
- f"http://{args.host}:{args.port}/generate",
210
- json=json_data,
211
- stream=args.stream,
296
+ resp_b = send_single(
297
+ args,
298
+ input_ids=[x["io"].token_ids for x in info_a],
299
+ max_new_tokens=1,
300
+ return_full_response=True,
301
+ pick_first_result=False,
302
+ extra_params=query_extra_params,
212
303
  )
213
- ret = response.json()
214
- if response.status_code != 200:
215
- print(ret)
216
- return -1, -1, -1
304
+ info_b = _extract_ids_and_logprobs(resp_b)
217
305
 
218
- ret_dict = {i: [] for i in range(len(prompts))}
219
- for i in range(batch_size):
220
- ret_dict[sampled_indices[i]].append(ret[i]["text"])
306
+ ans = []
307
+ for i, (info_a_item, info_b_item) in enumerate(zip(info_a, info_b, strict=True)):
308
+ print(f"Compare sequence {i} in batch...")
309
+ correct = TokenIdsAndLogprobs.compare(info_a_item["io"], info_b_item["input"])
310
+ ans.append(int(correct))
221
311
 
222
- return ret_dict
312
+ return ans
223
313
 
224
314
 
225
- def test_deterministic(args):
226
- # First do some warmups
227
- for i in range(3):
228
- send_single(args, 16, args.profile)
315
+ @dataclasses.dataclass
316
+ class TokenIdsAndLogprobs:
317
+ token_ids: List[int]
318
+ logprobs: List[float]
319
+
320
+ def __add__(self, other):
321
+ return TokenIdsAndLogprobs(
322
+ token_ids=self.token_ids + other.token_ids,
323
+ logprobs=self.logprobs + other.logprobs,
324
+ )
325
+
326
+ @classmethod
327
+ def compare(cls, a: "TokenIdsAndLogprobs", b: "TokenIdsAndLogprobs"):
328
+ assert len(a.token_ids) == len(b.token_ids)
329
+ token_match = a.token_ids == b.token_ids
330
+ logprobs_match = a.logprobs == b.logprobs
331
+
332
+ if token_match:
333
+ print(f"Token match: {a.token_ids}")
334
+ else:
335
+ print(f"❗Token mismatch: {a.token_ids=} {b.token_ids=}")
336
+
337
+ if logprobs_match:
338
+ print(f"Logprobs match:", a.logprobs)
339
+ else:
340
+ print(f"❗Logprobs mismatch")
341
+ print(
342
+ " A: ",
343
+ [f"{x:.10f}" if x is not None else "None" for x in a.logprobs],
344
+ )
345
+ print(
346
+ " B: ",
347
+ [f"{x:.10f}" if x is not None else "None" for x in b.logprobs],
348
+ )
349
+ diff = [
350
+ abs(x - y) if x is not None else float("nan")
351
+ for x, y in zip(a.logprobs, b.logprobs)
352
+ ]
353
+ print(" Diff:", [f"{x:.10e}" for x in diff])
354
+
355
+ return token_match and logprobs_match
229
356
 
357
+
358
+ def _extract_ids_and_logprobs(responses):
359
+ def _extract_part(response, name):
360
+ token_ids, logprobs = [], []
361
+ for item in response["meta_info"][name]:
362
+ logprob, token_id, text = item
363
+ token_ids.append(token_id)
364
+ logprobs.append(logprob)
365
+ return TokenIdsAndLogprobs(token_ids=token_ids, logprobs=logprobs)
366
+
367
+ def _extract_one_response(response):
368
+ input = _extract_part(response, "input_token_logprobs")
369
+ output = _extract_part(response, "output_token_logprobs")
370
+ return dict(input=input, output=output, io=input + output)
371
+
372
+ if not isinstance(responses, list):
373
+ responses = [responses]
374
+ return [_extract_one_response(x) for x in responses]
375
+
376
+
377
+ def test_deterministic(args):
230
378
  if args.test_mode == "single":
231
379
  # In single mode, we test the deterministic behavior by sending the same prompt in batch sizes ranging from 1 to n_trials.
232
380
  texts = []
233
381
  for i in range(1, args.n_trials + 1):
234
382
  batch_size = i
235
- text = send_single(args, batch_size, args.profile)
383
+ text = send_single(args, args.profile, prompt=[PROMPT_1] * batch_size)
236
384
  text = text.replace("\n", " ")
237
385
  print(f"Trial {i} with batch size {batch_size}: {text}")
238
386
  texts.append(text)
239
-
240
387
  print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
241
- elif args.test_mode == "mixed":
242
- # In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials.
243
- output_prompt_1 = []
244
- output_prompt_2 = []
245
- output_long_prompt = []
246
- for i in range(1, args.n_trials + 1):
247
- batch_size = i
248
- ret_prompt_1, ret_prompt_2, ret_long_prompt = send_mixed(args, batch_size)
249
- output_prompt_1.extend(ret_prompt_1)
250
- output_prompt_2.extend(ret_prompt_2)
251
- output_long_prompt.extend(ret_long_prompt)
252
-
253
- print(
254
- f"Testing Trial {i} with batch size {batch_size}, number of prompt 1: {len(ret_prompt_1)}, number of prompt 2: {len(ret_prompt_2)}, number of long prompt: {len(ret_long_prompt)}"
255
- )
256
-
257
- print(
258
- f"Prompt 1: total samples: {len(output_prompt_1)}, Unique samples: {len(set(output_prompt_1))}"
259
- )
260
- print(
261
- f"Prompt 2: total samples: {len(output_prompt_2)}, Unique samples: {len(set(output_prompt_2))}"
262
- )
263
- print(
264
- f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}"
265
- )
388
+ return [len(set(texts))]
266
389
 
267
390
  elif args.test_mode == "prefix":
268
391
  # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
@@ -270,21 +393,251 @@ def test_deterministic(args):
270
393
  num_prompts = len(len_prefix)
271
394
  outputs = {i: [] for i in range(4)}
272
395
  prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
273
- for i in range(1, args.n_trials + 1):
396
+
397
+ # If return_logprob is enabled, store full responses for comparison
398
+ if args.return_logprob:
399
+ full_responses = {i: [] for i in range(4)}
400
+
401
+ for i in range(args.n_start, args.n_start + args.n_trials):
274
402
  batch_size = i
275
- ret_dict = send_prefix(args, batch_size, prompts)
403
+ ret_dict = send_prefix(
404
+ args, batch_size, prompts, return_full_response=args.return_logprob
405
+ )
276
406
  msg = f"Testing Trial {i} with batch size {batch_size},"
277
407
  for i in range(num_prompts):
278
408
  msg += f" # prefix length {len_prefix[i]}: {len(ret_dict[i])},"
279
409
  print(msg)
280
410
  for i in range(num_prompts):
281
- outputs[i].extend(ret_dict[i])
411
+ if args.return_logprob:
412
+ # Store full response for logprob comparison
413
+ full_responses[i].extend(ret_dict[i])
414
+ # Extract text for determinism check
415
+ outputs[i].extend([resp["text"] for resp in ret_dict[i]])
416
+ else:
417
+ outputs[i].extend(ret_dict[i])
282
418
 
283
419
  for i in range(num_prompts):
284
420
  print(
285
421
  f"Prompt {i} with prefix length {len_prefix[i]}: total samples: {len(outputs[i])}, Unique samples: {len(set(outputs[i]))}"
286
422
  )
287
423
 
424
+ results = []
425
+ for i in range(num_prompts):
426
+ results.append(len(set(outputs[i])))
427
+
428
+ # If logprobs are enabled, compare them across different batch sizes
429
+ if args.return_logprob:
430
+ print(f"\n{'='*60}")
431
+ print("Logprobs Comparison Across Batch Sizes")
432
+ print("=" * 60)
433
+
434
+ logprob_results = []
435
+ for prompt_idx in range(num_prompts):
436
+ print(
437
+ f"\nPrompt {prompt_idx} (prefix length {len_prefix[prompt_idx]}):"
438
+ )
439
+ responses = full_responses[prompt_idx]
440
+
441
+ if len(responses) < 2:
442
+ continue
443
+
444
+ # Compare all responses against the first one
445
+ reference = responses[0]
446
+ all_match = True
447
+ mismatches = []
448
+
449
+ for j, resp in enumerate(responses[1:], start=1):
450
+ ref_logprobs = reference["meta_info"]["output_token_logprobs"]
451
+ resp_logprobs = resp["meta_info"]["output_token_logprobs"]
452
+
453
+ match, msg = compare_logprobs(ref_logprobs, resp_logprobs)
454
+
455
+ if not match:
456
+ print(f" ✗ Sample {j+1}: {msg}")
457
+ mismatches.append((j + 1, msg))
458
+ all_match = False
459
+
460
+ if all_match:
461
+ print(f" ✓ All {len(responses)} samples have identical logprobs")
462
+ logprob_results.append(1)
463
+ else:
464
+ print(
465
+ f" ✗ Found {len(mismatches)} mismatches out of {len(responses)} samples"
466
+ )
467
+ logprob_results.append(0)
468
+
469
+ print(f"\n{'='*60}")
470
+ if all(r == 1 for r in logprob_results):
471
+ print("✓✓✓ Logprobs are identical across all batch sizes! ✓✓✓")
472
+ else:
473
+ print("✗✗✗ Some logprobs differ across batch sizes! ✗✗✗")
474
+
475
+ return results
476
+
477
+ elif args.test_mode == "radix_cache":
478
+ # Radix mode requires logprobs to compare results
479
+ args.return_logprob = True
480
+
481
+ print("\n=== Prefill Cache Consistency Test ===")
482
+ print(
483
+ "This test verifies prefill request produces consistent logprobs w/ and w/o cache.\n"
484
+ )
485
+
486
+ # We noticed that we cannot call flush cache before any request, otherwise it will hang.
487
+ warmup_response = send_single(
488
+ args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True
489
+ )
490
+
491
+ # Flush cache first to make sure there is no cache hit from previous tests
492
+ flush_response = requests.post(f"http://{args.host}:{args.port}/flush_cache")
493
+
494
+ print(f"Step 1: Generating random 64 token IDs...")
495
+ # Use a reasonable token ID range (e.g., 1-50000 for most tokenizers)
496
+ # Avoid special tokens like 0 (padding), 1 (BOS), 2 (EOS)
497
+ # set seed for random.randint
498
+ random.seed(42)
499
+ initial_token_ids = [random.randint(100, 50000) for _ in range(64)]
500
+
501
+ print(f"✓ Using {len(initial_token_ids)} initial tokens")
502
+ print(f" Initial token IDs: {initial_token_ids}")
503
+
504
+ print(
505
+ f"\nStep 2: Generating 2 tokens from {len(initial_token_ids)} token prefix..."
506
+ )
507
+ first_response = send_single(
508
+ args,
509
+ input_ids=initial_token_ids,
510
+ max_new_tokens=100,
511
+ return_full_response=True,
512
+ )
513
+ first_output_text = first_response["text"]
514
+ first_output_token_ids = first_response["output_ids"]
515
+ first_output_logprobs = first_response["meta_info"]["output_token_logprobs"]
516
+
517
+ expected_token_id = first_output_token_ids[-1]
518
+ expected_logprob = first_output_logprobs[-1][0]
519
+
520
+ print(f"✓ Generated {len(first_output_token_ids)} tokens")
521
+ print(f' Output text: "{first_output_text}"')
522
+
523
+ print(
524
+ f"\nStep 3: Generating with radix cache (164 tokens prefill, should hit > 128 tokens cache, based on page size)..."
525
+ )
526
+ prefix_token_ids = initial_token_ids + first_output_token_ids[:-1]
527
+ print(
528
+ f" Prefix: {len(initial_token_ids)} initial + 64 generated = {len(prefix_token_ids)} tokens"
529
+ )
530
+ print(f"Using Prompt: {prefix_token_ids}")
531
+ cached_response = send_single(
532
+ args,
533
+ input_ids=prefix_token_ids,
534
+ max_new_tokens=1,
535
+ return_full_response=True,
536
+ )
537
+ cached_logprobs = cached_response["meta_info"]["output_token_logprobs"]
538
+ cached_token_data = cached_logprobs[0]
539
+ cached_logprob = cached_token_data[0]
540
+ cached_token_id = cached_token_data[1]
541
+
542
+ print(f"✓ Generated with cache:")
543
+ print(f" Token ID: {cached_token_id}")
544
+ print(f" Logprob: {cached_logprob:.10f}")
545
+
546
+ print(f"\nStep 4: Flushing cache...")
547
+ flush_response = requests.post(f"http://{args.host}:{args.port}/flush_cache")
548
+
549
+ print(
550
+ f"\nStep 5: Generating without cache (same 164 tokens prefill, no cache)..."
551
+ )
552
+ print(f"Using Prompt: {prefix_token_ids}")
553
+
554
+ uncached_response = send_single(
555
+ args,
556
+ input_ids=prefix_token_ids,
557
+ max_new_tokens=1,
558
+ return_full_response=True,
559
+ )
560
+
561
+ uncached_logprobs = uncached_response["meta_info"]["output_token_logprobs"]
562
+ uncached_token_data = uncached_logprobs[0]
563
+ uncached_logprob = uncached_token_data[0]
564
+ uncached_token_id = uncached_token_data[1]
565
+
566
+ print(f"✓ Generated without cache:")
567
+ print(f" Token ID: {uncached_token_id}")
568
+ print(f" Logprob: {uncached_logprob:.10f}")
569
+
570
+ # Step 6: Compare results
571
+ print(f"\n{'='*60}")
572
+ print("Comparison 1: Decode (Request 1) vs Prefill with Cache (Request 2)")
573
+ print("=" * 60)
574
+
575
+ # Compare first request (decode) vs second request (prefill with cache)
576
+ # We expect them to be different (different kernels)
577
+ decode_vs_prefill_token_match = expected_token_id == cached_token_id
578
+ decode_vs_prefill_logprob_match = expected_logprob == cached_logprob
579
+
580
+ print(
581
+ f" Decode token (Request 1): ID={expected_token_id}, logprob={expected_logprob:.10f}"
582
+ )
583
+ print(
584
+ f" Prefill w/ cache token (Request 2): ID={cached_token_id}, logprob={cached_logprob:.10f}"
585
+ )
586
+ print(
587
+ f" Token ID match: {'✓ YES' if decode_vs_prefill_token_match else '✗ NO'}"
588
+ )
589
+ print(
590
+ f" Logprob match: {'✓ YES' if decode_vs_prefill_logprob_match else '✗ NO'}"
591
+ )
592
+ if not decode_vs_prefill_logprob_match:
593
+ diff = abs(expected_logprob - cached_logprob)
594
+ print(f" Logprob difference: {diff:.10e}")
595
+ print(f" Note: We expect these to be DIFFERENT (decode vs prefill kernels)")
596
+
597
+ print(f"\n{'='*60}")
598
+ print(
599
+ "Comparison 2: Cached Prefill (Request 2) vs Uncached Prefill (Request 3)"
600
+ )
601
+ print("=" * 60)
602
+
603
+ # Main test: compare cached vs uncached prefill (should be identical)
604
+ token_match = cached_token_id == uncached_token_id
605
+ logprob_match = cached_logprob == uncached_logprob
606
+
607
+ print(
608
+ f" Cached prefill token (Request 2): ID={cached_token_id}, logprob={cached_logprob:.10f}"
609
+ )
610
+ print(
611
+ f" Uncached prefill token (Request 3): ID={uncached_token_id}, logprob={uncached_logprob:.10f}"
612
+ )
613
+ print(f" Token ID match: {'✓ YES' if token_match else '✗ NO'}")
614
+ if not token_match:
615
+ print(f" Cached: {cached_token_id}")
616
+ print(f" Uncached: {uncached_token_id}")
617
+
618
+ print(f" Logprob match: {'✓ YES' if logprob_match else '✗ NO'}")
619
+ if not logprob_match:
620
+ print(f" Cached: {cached_logprob:.10f}")
621
+ print(f" Uncached: {uncached_logprob:.10f}")
622
+ diff = abs(cached_logprob - uncached_logprob)
623
+ print(f" Difference: {diff:.10e}")
624
+ print(f" Note: We expect these to be IDENTICAL (both prefill kernels)")
625
+
626
+ print(f"\n{'='*60}")
627
+ if token_match and logprob_match:
628
+ print("✓✓✓ TEST PASSED - Radix cache is consistent! ✓✓✓")
629
+ return [1]
630
+ else:
631
+ print("✗✗✗ TEST FAILED - Radix cache produces different results! ✗✗✗")
632
+ return [0]
633
+
634
+ elif args.test_mode == "p_vs_d":
635
+ # TODO also extract other modes to functions
636
+ ans = []
637
+ for i in range(1, args.n_trials + 1):
638
+ ans += _test_mode_p_vs_d(args, batch_size=i)
639
+ return ans
640
+
288
641
  else:
289
642
  raise ValueError(f"Invalid test mode: {args.test_mode}")
290
643
 
@@ -294,4 +647,7 @@ if __name__ == "__main__":
294
647
  BenchArgs.add_cli_args(parser)
295
648
  args = parser.parse_args()
296
649
 
650
+ if args.sampling_seed is None:
651
+ args.sampling_seed = 42
652
+
297
653
  test_deterministic(args)