sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -16,10 +16,15 @@ from sglang.srt.layers.attention.trtllm_mla_backend import (
16
16
  TRTLLMMLABackend,
17
17
  TRTLLMMLADecodeMetadata,
18
18
  )
19
- from sglang.srt.layers.attention.utils import TRITON_PAD_NUM_PAGE_PER_BLOCK
19
+ from sglang.srt.layers.attention.utils import get_num_page_per_block_flashmla
20
20
  from sglang.srt.layers.radix_attention import RadixAttention
21
21
  from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
22
22
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
+ from sglang.srt.server_args import (
24
+ ServerArgs,
25
+ get_global_server_args,
26
+ set_global_server_args_for_scheduler,
27
+ )
23
28
  from sglang.srt.utils import is_flashinfer_available
24
29
  from sglang.test.test_utils import CustomTestCase
25
30
 
@@ -104,15 +109,15 @@ TEST_CASES = {
104
109
  "page_size": 32,
105
110
  "description": "Single FP16 vs reference",
106
111
  },
107
- {
108
- "name": "single_fp8",
109
- "batch_size": 1,
110
- "max_seq_len": 64,
111
- "page_size": 64,
112
- "tolerance": 1e-1,
113
- "kv_cache_dtype": torch.float8_e4m3fn,
114
- "description": "Single FP8 vs reference",
115
- },
112
+ # {
113
+ # "name": "single_fp8",
114
+ # "batch_size": 1,
115
+ # "max_seq_len": 64,
116
+ # "page_size": 64,
117
+ # "tolerance": 1e-1,
118
+ # "kv_cache_dtype": torch.float8_e4m3fn,
119
+ # "description": "Single FP8 vs reference",
120
+ # },
116
121
  {
117
122
  "name": "batch_fp16",
118
123
  "batch_size": 32,
@@ -120,15 +125,15 @@ TEST_CASES = {
120
125
  "page_size": 32,
121
126
  "description": "Batch FP16 vs reference",
122
127
  },
123
- {
124
- "name": "batch_fp8",
125
- "batch_size": 32,
126
- "max_seq_len": 64,
127
- "page_size": 64,
128
- "tolerance": 1e-1,
129
- "kv_cache_dtype": torch.float8_e4m3fn,
130
- "description": "Batch FP8 vs reference",
131
- },
128
+ # {
129
+ # "name": "batch_fp8",
130
+ # "batch_size": 32,
131
+ # "max_seq_len": 64,
132
+ # "page_size": 64,
133
+ # "tolerance": 1e-1,
134
+ # "kv_cache_dtype": torch.float8_e4m3fn,
135
+ # "description": "Batch FP8 vs reference",
136
+ # },
132
137
  ],
133
138
  "page_size_consistency": [
134
139
  # Only 32 and 64 supported for now in flashinfer TRTLLM-GEN MLA kernel
@@ -213,13 +218,7 @@ class MockModelRunner:
213
218
  self.page_size = config["page_size"]
214
219
 
215
220
  # Server args stub - needed by attention backends
216
- self.server_args = type(
217
- "ServerArgs",
218
- (),
219
- {
220
- "enable_dp_attention": False, # Default value for testing
221
- },
222
- )
221
+ self.server_args = get_global_server_args()
223
222
 
224
223
  # Model-config stub with MLA attributes
225
224
  self.model_config = type(
@@ -320,6 +319,17 @@ def compare_outputs(trtllm_out, reference_out, tolerance=1e-2):
320
319
  class TestTRTLLMMLA(CustomTestCase):
321
320
  """Test suite for TRTLLM MLA backend with centralized configuration."""
322
321
 
322
+ @classmethod
323
+ def setUpClass(cls):
324
+ """Set up global server args for testing."""
325
+ server_args = ServerArgs(model_path="dummy")
326
+ server_args.enable_dp_attention = False
327
+ set_global_server_args_for_scheduler(server_args)
328
+
329
+ @classmethod
330
+ def tearDownClass(cls):
331
+ pass
332
+
323
333
  def _merge_config(self, test_case):
324
334
  """Merge test case with default configuration."""
325
335
  config = DEFAULT_CONFIG.copy()
@@ -841,25 +851,17 @@ class TestTRTLLMMLA(CustomTestCase):
841
851
  backend.init_forward_metadata(fb)
842
852
 
843
853
  # Verify metadata exists
844
- self.assertIsNotNone(backend.forward_metadata)
845
- self.assertIsInstance(backend.forward_metadata, TRTLLMMLADecodeMetadata)
854
+ self.assertIsNotNone(backend.forward_decode_metadata)
855
+ self.assertIsInstance(
856
+ backend.forward_decode_metadata, TRTLLMMLADecodeMetadata
857
+ )
846
858
 
847
859
  # Test metadata structure
848
- metadata = backend.forward_metadata
849
- self.assertIsNotNone(
850
- metadata.workspace, "Workspace should be allocated"
851
- )
860
+ metadata = backend.forward_decode_metadata
852
861
  self.assertIsNotNone(
853
862
  metadata.block_kv_indices, "Block KV indices should be created"
854
863
  )
855
864
 
856
- # Test workspace properties
857
- self.assertEqual(metadata.workspace.device.type, "cuda")
858
- self.assertEqual(metadata.workspace.dtype, torch.uint8)
859
- self.assertGreater(
860
- metadata.workspace.numel(), 0, "Workspace should have non-zero size"
861
- )
862
-
863
865
  # Test block KV indices properties
864
866
  self.assertEqual(metadata.block_kv_indices.device.type, "cuda")
865
867
  self.assertEqual(metadata.block_kv_indices.dtype, torch.int32)
@@ -915,9 +917,10 @@ class TestTRTLLMMLA(CustomTestCase):
915
917
 
916
918
  # Should satisfy TRT-LLM and Triton constraints
917
919
  trtllm_constraint = 128 // scenario["page_size"]
918
- constraint_lcm = math.lcm(
919
- trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK
920
+ triton_constraint = get_num_page_per_block_flashmla(
921
+ scenario["page_size"]
920
922
  )
923
+ constraint_lcm = math.lcm(trtllm_constraint, triton_constraint)
921
924
  self.assertEqual(
922
925
  calculated_blocks % constraint_lcm,
923
926
  0,
@@ -965,7 +968,7 @@ class TestTRTLLMMLA(CustomTestCase):
965
968
 
966
969
  # Initialize metadata
967
970
  backend.init_forward_metadata(fb)
968
- metadata = backend.forward_metadata
971
+ metadata = backend.forward_decode_metadata
969
972
 
970
973
  # Verify KV indices structure
971
974
  block_kv_indices = metadata.block_kv_indices
@@ -1016,7 +1019,6 @@ class TestTRTLLMMLA(CustomTestCase):
1016
1019
 
1017
1020
  # Verify CUDA graph buffers are allocated
1018
1021
  self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
1019
- self.assertIsNotNone(backend.decode_cuda_graph_workspace)
1020
1022
 
1021
1023
  # Test capture metadata
1022
1024
  seq_lens = torch.full(
@@ -1038,7 +1040,6 @@ class TestTRTLLMMLA(CustomTestCase):
1038
1040
  self.assertIn(batch_size, backend.decode_cuda_graph_metadata)
1039
1041
  capture_metadata = backend.decode_cuda_graph_metadata[batch_size]
1040
1042
 
1041
- self.assertIsNotNone(capture_metadata.workspace)
1042
1043
  self.assertIsNotNone(capture_metadata.block_kv_indices)
1043
1044
 
1044
1045
  # Test replay with different sequence lengths
@@ -1061,11 +1062,8 @@ class TestTRTLLMMLA(CustomTestCase):
1061
1062
  )
1062
1063
 
1063
1064
  # Verify replay updated the metadata
1064
- replay_metadata = backend.forward_metadata
1065
+ replay_metadata = backend.forward_decode_metadata
1065
1066
  self.assertIsNotNone(replay_metadata)
1066
- self.assertEqual(
1067
- replay_metadata.workspace.data_ptr(), capture_metadata.workspace.data_ptr()
1068
- )
1069
1067
 
1070
1068
  def test_metadata_consistency_across_calls(self):
1071
1069
  """Test metadata consistency across multiple forward calls."""
@@ -1083,7 +1081,7 @@ class TestTRTLLMMLA(CustomTestCase):
1083
1081
  config["batch_size"], seq_lens_1, backend, model_runner, config
1084
1082
  )
1085
1083
  backend.init_forward_metadata(fb_1)
1086
- metadata_1 = backend.forward_metadata
1084
+ metadata_1 = backend.forward_decode_metadata
1087
1085
 
1088
1086
  # Second call with same sequence lengths
1089
1087
  seq_lens_2 = torch.tensor([32, 48], device=config["device"])
@@ -1091,10 +1089,9 @@ class TestTRTLLMMLA(CustomTestCase):
1091
1089
  config["batch_size"], seq_lens_2, backend, model_runner, config
1092
1090
  )
1093
1091
  backend.init_forward_metadata(fb_2)
1094
- metadata_2 = backend.forward_metadata
1092
+ metadata_2 = backend.forward_decode_metadata
1095
1093
 
1096
1094
  # Metadata structure should be consistent
1097
- self.assertEqual(metadata_1.workspace.shape, metadata_2.workspace.shape)
1098
1095
  self.assertEqual(
1099
1096
  metadata_1.block_kv_indices.shape, metadata_2.block_kv_indices.shape
1100
1097
  )
@@ -1105,10 +1102,9 @@ class TestTRTLLMMLA(CustomTestCase):
1105
1102
  config["batch_size"], seq_lens_3, backend, model_runner, config
1106
1103
  )
1107
1104
  backend.init_forward_metadata(fb_3)
1108
- metadata_3 = backend.forward_metadata
1105
+ metadata_3 = backend.forward_decode_metadata
1109
1106
 
1110
1107
  # Should still have valid structure
1111
- self.assertIsNotNone(metadata_3.workspace)
1112
1108
  self.assertIsNotNone(metadata_3.block_kv_indices)
1113
1109
  self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])
1114
1110
 
@@ -1263,6 +1259,178 @@ class TestTRTLLMMLA(CustomTestCase):
1263
1259
  f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
1264
1260
  )
1265
1261
 
1262
+ def test_draft_extend_padding_unpadding_kernels(self):
1263
+ """Test TRTLLM MLA Triton kernels: pad_draft_extend_query_kernel and unpad_draft_extend_output_kernel."""
1264
+
1265
+ # Import the kernels
1266
+ from sglang.srt.layers.attention.trtllm_mla_backend import (
1267
+ pad_draft_extend_query_kernel,
1268
+ unpad_draft_extend_output_kernel,
1269
+ )
1270
+
1271
+ def _create_test_data(
1272
+ self, batch_size, max_seq_len, num_heads, head_dim, dtype=torch.float32
1273
+ ):
1274
+ """Create test data for kernel testing."""
1275
+ device = torch.device("cuda")
1276
+
1277
+ # Create sequence lengths (varying lengths for each batch)
1278
+ seq_lens = torch.randint(
1279
+ 1, max_seq_len + 1, (batch_size,), device=device, dtype=torch.int32
1280
+ )
1281
+
1282
+ # Create cumulative sequence lengths
1283
+ cum_seq_lens = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
1284
+ cum_seq_lens[1:] = torch.cumsum(seq_lens, dim=0)
1285
+
1286
+ # Create input query tensor (flattened format)
1287
+ total_tokens = cum_seq_lens[-1].item()
1288
+ q_input = torch.randn(
1289
+ total_tokens, num_heads, head_dim, device=device, dtype=dtype
1290
+ )
1291
+
1292
+ # Create padded query tensor (batch format)
1293
+ padded_q = torch.zeros(
1294
+ batch_size, max_seq_len, num_heads, head_dim, device=device, dtype=dtype
1295
+ )
1296
+
1297
+ return q_input, padded_q, seq_lens, cum_seq_lens
1298
+
1299
+ def _create_test_output_data(
1300
+ self,
1301
+ batch_size,
1302
+ token_per_batch,
1303
+ tp_q_head_num,
1304
+ v_head_dim,
1305
+ dtype=torch.float32,
1306
+ ):
1307
+ """Create test data for unpad kernel testing."""
1308
+ device = torch.device("cuda")
1309
+
1310
+ # Create accept lengths (varying lengths for each batch)
1311
+ accept_lengths = torch.randint(
1312
+ 1, token_per_batch + 1, (batch_size,), device=device, dtype=torch.int32
1313
+ )
1314
+
1315
+ # Create cumulative accept lengths
1316
+ cum_accept_lengths = torch.zeros(
1317
+ batch_size + 1, device=device, dtype=torch.int32
1318
+ )
1319
+ cum_accept_lengths[1:] = torch.cumsum(accept_lengths, dim=0)
1320
+
1321
+ # Create raw output tensor (batch format)
1322
+ raw_out = torch.randn(
1323
+ batch_size,
1324
+ token_per_batch,
1325
+ tp_q_head_num,
1326
+ v_head_dim,
1327
+ device=device,
1328
+ dtype=dtype,
1329
+ )
1330
+
1331
+ # Create output tensor (flattened format)
1332
+ total_tokens = cum_accept_lengths[-1].item()
1333
+ output = torch.empty(
1334
+ total_tokens, tp_q_head_num, v_head_dim, device=device, dtype=dtype
1335
+ )
1336
+
1337
+ return raw_out, output, accept_lengths, cum_accept_lengths
1338
+
1339
+ # Test 1: pad_draft_extend_query_kernel basic functionality
1340
+ with self.subTest(test="pad_kernel_basic"):
1341
+ batch_size = 4
1342
+ max_seq_len = 8
1343
+ num_heads = 16
1344
+ head_dim = 64
1345
+
1346
+ q_input, padded_q, seq_lens, cum_seq_lens = _create_test_data(
1347
+ self, batch_size, max_seq_len, num_heads, head_dim
1348
+ )
1349
+
1350
+ # Launch kernel
1351
+ BLOCK_SIZE = 64
1352
+ grid = (batch_size * max_seq_len,)
1353
+
1354
+ pad_draft_extend_query_kernel[grid](
1355
+ q_ptr=q_input,
1356
+ padded_q_ptr=padded_q,
1357
+ seq_lens_q_ptr=seq_lens,
1358
+ cumsum_ptr=cum_seq_lens,
1359
+ batch_size=batch_size,
1360
+ max_seq_len=max_seq_len,
1361
+ num_heads=num_heads,
1362
+ head_dim=head_dim,
1363
+ BLOCK_SIZE=BLOCK_SIZE,
1364
+ )
1365
+
1366
+ # Verify the padding worked correctly
1367
+ for i in range(batch_size):
1368
+ seq_len = seq_lens[i].item()
1369
+
1370
+ # Check that valid positions are copied correctly
1371
+ for pos in range(seq_len):
1372
+ input_start = cum_seq_lens[i].item()
1373
+ input_pos = input_start + pos
1374
+
1375
+ # Compare input and output for valid positions
1376
+ input_data = q_input[input_pos]
1377
+ output_data = padded_q[i, pos]
1378
+
1379
+ torch.testing.assert_close(
1380
+ input_data, output_data, rtol=1e-5, atol=1e-6
1381
+ )
1382
+
1383
+ # Check that invalid positions are zero
1384
+ for pos in range(seq_len, max_seq_len):
1385
+ output_data = padded_q[i, pos]
1386
+ self.assertTrue(
1387
+ torch.allclose(output_data, torch.zeros_like(output_data)),
1388
+ f"Position {pos} in batch {i} should be zero",
1389
+ )
1390
+
1391
+ # Test 2: unpad_draft_extend_output_kernel basic functionality
1392
+ with self.subTest(test="unpad_kernel_basic"):
1393
+ batch_size = 4
1394
+ token_per_batch = 8
1395
+ tp_q_head_num = 16
1396
+ v_head_dim = 64
1397
+
1398
+ raw_out, output, accept_lengths, cum_accept_lengths = (
1399
+ _create_test_output_data(
1400
+ self, batch_size, token_per_batch, tp_q_head_num, v_head_dim
1401
+ )
1402
+ )
1403
+
1404
+ # Launch kernel
1405
+ BLOCK_SIZE = 64
1406
+ grid = (batch_size * token_per_batch,)
1407
+
1408
+ unpad_draft_extend_output_kernel[grid](
1409
+ raw_out_ptr=raw_out,
1410
+ output_ptr=output,
1411
+ accept_length_ptr=accept_lengths,
1412
+ cumsum_ptr=cum_accept_lengths,
1413
+ batch_size=batch_size,
1414
+ token_per_batch=token_per_batch,
1415
+ tp_q_head_num=tp_q_head_num,
1416
+ v_head_dim=v_head_dim,
1417
+ BLOCK_SIZE=BLOCK_SIZE,
1418
+ )
1419
+
1420
+ # Verify the unpadding worked correctly
1421
+ for i in range(batch_size):
1422
+ accept_len = accept_lengths[i].item()
1423
+ output_start = cum_accept_lengths[i].item()
1424
+
1425
+ # Check that valid positions are copied correctly
1426
+ for pos in range(accept_len):
1427
+ input_data = raw_out[i, pos]
1428
+ output_data = output[output_start + pos]
1429
+
1430
+ torch.testing.assert_close(
1431
+ input_data, output_data, rtol=1e-5, atol=1e-6
1432
+ )
1433
+
1266
1434
 
1267
1435
  if __name__ == "__main__":
1268
1436
  unittest.main()
@@ -1,16 +1,14 @@
1
1
  import argparse
2
2
  import ast
3
3
  import asyncio
4
- import json
5
4
  import re
6
5
  import time
6
+ from typing import Optional
7
7
 
8
8
  import numpy as np
9
9
 
10
10
  import sglang as sgl
11
- from sglang.lang.api import set_default_backend
12
- from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
13
- from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
11
+ from sglang.utils import download_and_cache_file, read_jsonl
14
12
 
15
13
  INVALID = -9999999
16
14
 
@@ -0,0 +1,157 @@
1
+ import json
2
+
3
+ import requests
4
+
5
+ MANY_NEW_TOKENS_PROMPT = """
6
+ Please write an extremely detailed and vivid fantasy story, set in a world full of intricate magic systems, political intrigue, and complex characters.
7
+ Ensure that you thoroughly describe every scene, character's motivations, and the environment. Include long, engaging dialogues and elaborate on the inner thoughts of the characters.
8
+ Each section should be as comprehensive as possible to create a rich and immersive experience for the reader.
9
+ The story should span multiple events, challenges, and character developments over time. Aim to make the story at least 3,000 words long.
10
+ """
11
+
12
+
13
+ class MatchedStopMixin:
14
+ def _run_completions_generation(
15
+ self,
16
+ prompt=MANY_NEW_TOKENS_PROMPT,
17
+ max_tokens=1,
18
+ stop=None,
19
+ stop_regex=None,
20
+ finish_reason=None,
21
+ matched_stop=None,
22
+ ):
23
+ payload = {
24
+ "prompt": prompt,
25
+ "model": self.model,
26
+ "temperature": 0,
27
+ "top_p": 1,
28
+ "max_tokens": max_tokens,
29
+ }
30
+
31
+ if stop is not None:
32
+ payload["stop"] = stop
33
+
34
+ if stop_regex is not None:
35
+ payload["stop_regex"] = stop_regex
36
+
37
+ response_completions = requests.post(
38
+ self.base_url + "/v1/completions",
39
+ json=payload,
40
+ )
41
+ res = response_completions.json()
42
+ print(json.dumps(res))
43
+ print("=" * 100)
44
+
45
+ if not isinstance(matched_stop, list):
46
+ matched_stop = [matched_stop]
47
+
48
+ assert (
49
+ res["choices"][0]["finish_reason"] == finish_reason
50
+ ), f"Expected finish_reason: {finish_reason}, but got: {res['choices'][0]['finish_reason']}"
51
+ assert (
52
+ res["choices"][0]["matched_stop"] in matched_stop
53
+ ), f"Expected matched_stop: {matched_stop}, but got: {res['choices'][0]['matched_stop']}"
54
+
55
+ def _run_chat_completions_generation(
56
+ self,
57
+ prompt=MANY_NEW_TOKENS_PROMPT,
58
+ max_tokens=1,
59
+ stop=None,
60
+ stop_regex=None,
61
+ finish_reason=None,
62
+ matched_stop=None,
63
+ ):
64
+ chat_payload = {
65
+ "model": self.model,
66
+ "messages": [
67
+ {"role": "system", "content": "You are a helpful AI assistant"},
68
+ {"role": "user", "content": prompt},
69
+ ],
70
+ "temperature": 0,
71
+ "top_p": 1,
72
+ "max_tokens": max_tokens,
73
+ }
74
+
75
+ if stop is not None:
76
+ chat_payload["stop"] = stop
77
+
78
+ if stop_regex is not None:
79
+ chat_payload["stop_regex"] = stop_regex
80
+
81
+ response_chat = requests.post(
82
+ self.base_url + "/v1/chat/completions",
83
+ json=chat_payload,
84
+ )
85
+ res = response_chat.json()
86
+ print(json.dumps(res))
87
+ print("=" * 100)
88
+
89
+ if not isinstance(matched_stop, list):
90
+ matched_stop = [matched_stop]
91
+
92
+ assert (
93
+ res["choices"][0]["finish_reason"] == finish_reason
94
+ ), f"Expected finish_reason: {finish_reason}, but got: {res['choices'][0]['finish_reason']}"
95
+ assert (
96
+ res["choices"][0]["matched_stop"] in matched_stop
97
+ ), f"Expected matched_stop: {matched_stop}, but got: {res['choices'][0]['matched_stop']}"
98
+
99
+ def test_finish_stop_str(self):
100
+ self._run_completions_generation(
101
+ max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
102
+ )
103
+ self._run_chat_completions_generation(
104
+ max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
105
+ )
106
+
107
+ def test_finish_stop_regex_str(self):
108
+ STOP_REGEX_STR = r"and|or"
109
+ self._run_completions_generation(
110
+ max_tokens=1000,
111
+ stop_regex=STOP_REGEX_STR,
112
+ finish_reason="stop",
113
+ matched_stop=STOP_REGEX_STR,
114
+ )
115
+ self._run_chat_completions_generation(
116
+ max_tokens=1000,
117
+ stop_regex=STOP_REGEX_STR,
118
+ finish_reason="stop",
119
+ matched_stop=STOP_REGEX_STR,
120
+ )
121
+
122
+ # Match a complete sentence
123
+ STOP_REGEX_STR_SENTENCE = r"[.!?]\s*$"
124
+ self._run_chat_completions_generation(
125
+ max_tokens=1000,
126
+ stop_regex=STOP_REGEX_STR_SENTENCE,
127
+ finish_reason="stop",
128
+ matched_stop=STOP_REGEX_STR_SENTENCE,
129
+ )
130
+
131
+ def test_finish_stop_eos(self):
132
+ llama_format_prompt = """\
133
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|>
134
+ You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
135
+ What is 2 + 2?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
136
+ """
137
+ eos_token_ids = [128000, 128009, 2]
138
+ self._run_completions_generation(
139
+ prompt=llama_format_prompt,
140
+ max_tokens=1000,
141
+ finish_reason="stop",
142
+ matched_stop=eos_token_ids,
143
+ )
144
+ self._run_chat_completions_generation(
145
+ prompt="What is 2 + 2?",
146
+ max_tokens=1000,
147
+ finish_reason="stop",
148
+ matched_stop=eos_token_ids,
149
+ )
150
+
151
+ def test_finish_length(self):
152
+ self._run_completions_generation(
153
+ max_tokens=5, finish_reason="length", matched_stop=None
154
+ )
155
+ self._run_chat_completions_generation(
156
+ max_tokens=5, finish_reason="length", matched_stop=None
157
+ )
@@ -0,0 +1 @@
1
+ """LongBench-v2 auxiliary utilities and validation scripts."""