sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408):
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/test/runners.py CHANGED
@@ -519,6 +519,7 @@ class SRTRunner:
519
519
  lora_target_modules: Optional[List[str]] = None,
520
520
  enable_lora: Optional[bool] = None,
521
521
  max_loaded_loras: Optional[int] = None,
522
+ lora_eviction_policy: str = "lru",
522
523
  ):
523
524
  self.model_type = model_type
524
525
  self.is_generation = model_type == "generation"
@@ -565,6 +566,7 @@ class SRTRunner:
565
566
  lora_target_modules=lora_target_modules,
566
567
  enable_lora=enable_lora,
567
568
  max_loaded_loras=max_loaded_loras,
569
+ lora_eviction_policy=lora_eviction_policy,
568
570
  **spec_kwargs,
569
571
  )
570
572
 
sglang/test/send_one.py CHANGED
@@ -3,6 +3,8 @@ Run one test prompt.
3
3
 
4
4
  Usage:
5
5
  python3 -m sglang.test.send_one
6
+ python3 -m sglang.test.send_one --profile --profile-steps 5
7
+ python3 -m sglang.test.send_one --profile --profile-by-stage
6
8
  """
7
9
 
8
10
  import argparse
@@ -10,6 +12,9 @@ import dataclasses
10
12
  import json
11
13
 
12
14
  import requests
15
+ import tabulate
16
+
17
+ from sglang.profiler import run_profile
13
18
 
14
19
 
15
20
  @dataclasses.dataclass
@@ -29,6 +34,9 @@ class BenchArgs:
29
34
  image: bool = False
30
35
  many_images: bool = False
31
36
  stream: bool = False
37
+ profile: bool = False
38
+ profile_steps: int = 3
39
+ profile_by_stage: bool = False
32
40
 
33
41
  @staticmethod
34
42
  def add_cli_args(parser: argparse.ArgumentParser):
@@ -51,6 +59,11 @@ class BenchArgs:
51
59
  parser.add_argument("--image", action="store_true")
52
60
  parser.add_argument("--many-images", action="store_true")
53
61
  parser.add_argument("--stream", action="store_true")
62
+ parser.add_argument("--profile", action="store_true")
63
+ parser.add_argument(
64
+ "--profile-steps", type=int, default=BenchArgs.profile_steps
65
+ )
66
+ parser.add_argument("--profile-by-stage", action="store_true")
54
67
 
55
68
  @classmethod
56
69
  def from_cli_args(cls, args: argparse.Namespace):
@@ -59,6 +72,8 @@ class BenchArgs:
59
72
 
60
73
 
61
74
  def send_one_prompt(args):
75
+ base_url = f"http://{args.host}:{args.port}"
76
+
62
77
  if args.image:
63
78
  args.prompt = (
64
79
  "Human: Describe this image in a very short sentence.\n\nAssistant:"
@@ -108,19 +123,35 @@ def send_one_prompt(args):
108
123
  "stream": args.stream,
109
124
  }
110
125
 
126
+ # Run profiler if requested
127
+ if args.profile:
128
+ print(f"Running profiler with {args.profile_steps} steps...")
129
+ run_profile(
130
+ base_url,
131
+ args.profile_steps,
132
+ ["CPU", "GPU"],
133
+ None,
134
+ None,
135
+ args.profile_by_stage,
136
+ )
137
+
111
138
  response = requests.post(
112
- f"http://{args.host}:{args.port}/generate",
139
+ f"{base_url}/generate",
113
140
  json=json_data,
114
141
  stream=args.stream,
115
142
  )
116
143
 
117
144
  if args.stream:
145
+ last_len = 0
118
146
  for chunk in response.iter_lines(decode_unicode=False):
119
147
  chunk = chunk.decode("utf-8")
120
148
  if chunk and chunk.startswith("data:"):
121
149
  if chunk == "data: [DONE]":
122
150
  break
123
151
  ret = json.loads(chunk[5:].strip("\n"))
152
+ chunk_str = ret["text"][last_len:]
153
+ last_len = len(ret["text"])
154
+ print(chunk_str, end="", flush=True)
124
155
  else:
125
156
  ret = response.json()
126
157
 
@@ -131,21 +162,25 @@ def send_one_prompt(args):
131
162
  print(ret)
132
163
  return 0, 0
133
164
 
134
- latency = ret["meta_info"]["e2e_latency"]
135
-
136
- if "spec_verify_ct" in ret["meta_info"]:
165
+ if "spec_verify_ct" in ret["meta_info"] and ret["meta_info"]["spec_verify_ct"] > 0:
137
166
  acc_length = (
138
167
  ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
139
168
  )
140
169
  else:
141
170
  acc_length = 1.0
142
171
 
172
+ latency = ret["meta_info"]["e2e_latency"]
143
173
  speed = ret["meta_info"]["completion_tokens"] / latency
174
+ tokens = ret["meta_info"]["completion_tokens"]
175
+
176
+ if not args.stream:
177
+ print(ret["text"])
144
178
 
145
- print(ret["text"])
146
179
  print()
147
- print(f"{acc_length=:.2f}")
148
- print(f"{speed=:.2f} token/s")
180
+ headers = ["Latency (s)", "Tokens", "Acc Length", "Speed (token/s)"]
181
+ rows = [[f"{latency:.3f}", f"{tokens}", f"{acc_length:.3f}", f"{speed:.2f}"]]
182
+ msg = tabulate.tabulate(rows, headers=headers, tablefmt="pretty")
183
+ print(msg)
149
184
 
150
185
  return acc_length, speed
151
186
 
@@ -290,6 +290,9 @@ def aggregate_results(
290
290
  htmls = []
291
291
  convos = []
292
292
  for single_eval_result in single_eval_results:
293
+ # Skip None results
294
+ if single_eval_result is None:
295
+ continue
293
296
  for name, value in single_eval_result.metrics.items():
294
297
  name2values[name].append(value)
295
298
  if single_eval_result.score is not None:
@@ -18,7 +18,6 @@ from sglang.test.simple_eval_common import (
18
18
  HTML_JINJA,
19
19
  Eval,
20
20
  EvalResult,
21
- MessageList,
22
21
  SamplerBase,
23
22
  SingleEvalResult,
24
23
  format_multichoice_question,
@@ -11,8 +11,6 @@ import re
11
11
  from concurrent.futures import ThreadPoolExecutor, as_completed
12
12
  from typing import Dict, List, Optional
13
13
 
14
- import tqdm
15
-
16
14
  try:
17
15
  from human_eval.data import read_problems
18
16
  from human_eval.evaluation import estimate_pass_at_k
@@ -41,7 +39,6 @@ def evaluate_functional_correctness(
41
39
  Evaluates the functional correctness of generated samples, and writes
42
40
  results to f"{sample_file}_results.jsonl.gz"
43
41
  """
44
- import copy
45
42
 
46
43
  # Check the generated samples against test suites.
47
44
  with ThreadPoolExecutor(max_workers=n_workers) as executor:
@@ -0,0 +1,344 @@
1
+ # Adapted from https://github.com/openai/simple-evals/
2
+
3
+ """
4
+ LongBench v2: Towards Deeper Understanding and Reasoning on Realistic Long-Context Multitasks
5
+ Yushi Bai, Shangqing Tu, Jiajie Zhang, Hao Peng, Xiaozhi Wang, Xin Lv, Shulin Cao, Jiazheng Xu, Lei Hou, Yuxiao Dong, Jie Tang, Juanzi Li
6
+ https://arxiv.org/abs/2412.15204
7
+ """
8
+
9
+ import csv
10
+ import json
11
+ import os
12
+ import re
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ from transformers import AutoTokenizer
16
+
17
+ from sglang.test import simple_eval_common as common
18
+ from sglang.test.simple_eval_common import (
19
+ ANSWER_PATTERN_MULTICHOICE,
20
+ HTML_JINJA,
21
+ Eval,
22
+ EvalResult,
23
+ SamplerBase,
24
+ SingleEvalResult,
25
+ )
26
+
27
# LongBench-v2 task categories.
# Only categories named here receive a per-category score metric during evaluation.
TASK_CATEGORIES = {
    "single_document_qa",
    "multi_document_qa",
    "long_in_context_learning",
    "long_dialogue_history",
    "code_repo_understanding",
    "long_structured_data",
}

# HuggingFace dataset identifier used when no data source is supplied.
DEFAULT_DATASET = "THUDM/LongBench-v2"
# Split loaded when the identifier carries no explicit ":<split>" suffix.
DEFAULT_DATASET_SPLIT = "train"
39
+
40
+
41
def format_longbench_v2_question(row: dict) -> str:
    """Format a LongBench-v2 question using the official template.

    The four options come from ``row["choices"]`` when that key is present
    (padded with empty strings if it holds fewer than four entries);
    otherwise from the ``A``..``D`` keys, each falling back to
    ``choice_A``..``choice_D``.

    Args:
        row: Example mapping. ``context``, ``question`` and all option fields
            are optional. Missing or ``None`` values render as empty text and
            non-string values are converted with ``str()`` before stripping.

    Returns:
        The prompt text. It starts with a newline and ends with the
        answer-format instruction (no trailing newline).
    """

    def as_text(value) -> str:
        # Rows read from local files may carry None (csv.DictReader pads short
        # rows with None) or numbers; calling .strip() on those would raise.
        if value is None:
            return ""
        return str(value).strip()

    context = as_text(row.get("context", ""))
    question = as_text(row.get("question", ""))

    # Handle both the alternative format (choices list) and the standard
    # format (A, B, C, D / choice_A .. choice_D).
    if "choices" in row:
        listed = row["choices"]
        if listed is None:
            listed = []
        options = [
            as_text(listed[idx]) if idx < len(listed) else "" for idx in range(4)
        ]
    else:
        options = [
            as_text(row.get(letter, row.get(f"choice_{letter}", "")))
            for letter in "ABCD"
        ]
    opt_a, opt_b, opt_c, opt_d = options

    # Official LongBench-v2 template, assembled line by line so that the
    # emitted text does not depend on source indentation.
    return (
        "\n"
        "Please read the following text and answer the question below.\n"
        "<text>\n"
        f"{context}\n"
        "</text>\n"
        "\n"
        f"What is the correct answer to this question: {question}\n"
        "Choices:\n"
        f"(A) {opt_a}\n"
        f"(B) {opt_b}\n"
        f"(C) {opt_c}\n"
        f"(D) {opt_d}\n"
        "\n"
        'Format your response as follows: "The correct answer is (insert answer here)".'
    )
76
+
77
+
78
def extract_longbench_v2_answer(response: str) -> Optional[str]:
    """Extract the chosen option letter from a model response.

    Markdown emphasis asterisks are removed first, then the patterns below
    are tried in order and the first hit wins:

    1. ``The correct answer is (X)``
    2. ``The correct answer is X``
    3. the shared ``ANSWER_PATTERN_MULTICHOICE`` pattern
    4. a generic ``answer is X`` / ``answer is (X)``

    Patterns 1, 2 and 4 are case-insensitive. For the bare-letter forms the
    letter must end at a word boundary, so that the first letter of an
    ordinary word ("... is definitely ...") is not mistaken for an option.

    Args:
        response: Raw model output text.

    Returns:
        The upper-cased letter ``"A"``..``"D"``, or ``None`` if nothing matched.
    """
    response = response.replace("*", "")

    # First try: "The correct answer is (A)"
    found = re.search(r"The correct answer is \(([A-D])\)", response, re.IGNORECASE)
    if found:
        return found.group(1).upper()

    # Second try: "The correct answer is A". The trailing \b is required
    # because IGNORECASE lets [A-D] match any a-d character, including the
    # start of words such as "definitely" or "clearly".
    found = re.search(r"The correct answer is ([A-D])\b", response, re.IGNORECASE)
    if found:
        return found.group(1).upper()

    # Fallback: Standard SGLang multichoice pattern
    found = re.search(ANSWER_PATTERN_MULTICHOICE, response)
    if found:
        return found.group(1).upper()

    # Generic fallback when model says "answer is A" (same word-boundary guard)
    found = re.search(r"answer\s+is\s*\(?([A-D])\b\)?", response, re.IGNORECASE)
    if found:
        return found.group(1).upper()

    return None
103
+
104
+
105
class LongBenchV2Eval(Eval):
    """
    Evaluation utility for LongBench-v2 dataset.

    LongBench-v2 is designed to assess the ability of LLMs to handle long-context problems
    requiring deep understanding and reasoning across real-world multitasks.

    Examples are loaded and normalized once in ``__init__``. Calling the
    instance with a sampler formats every example with the LongBench-v2
    prompt template, queries the sampler, extracts the predicted option
    letter and aggregates the per-example results.
    """

    # Option letters, indexed by integer answers (0 -> "A", ... 3 -> "D").
    _ANSWER_LETTERS = ("A", "B", "C", "D")

    def __init__(
        self,
        model: Optional[str] = None,
        data_source: str = DEFAULT_DATASET,
        num_examples: Optional[int] = None,
        num_threads: int = 1,
        n_repeats: int = 1,
        categories: Optional[List[str]] = None,
        max_context_length: Optional[int] = None,
        min_context_length: Optional[int] = None,
    ):
        """
        Initialize LongBench-v2 evaluation.

        Args:
            model: Name or path handed to ``AutoTokenizer.from_pretrained``. The
                tokenizer is only used by the context-length filter, so this may
                be None when neither length bound is set.
            data_source: HuggingFace dataset name (optionally ``name:split``) or
                local file path (CSV/JSON/JSONL)
            num_examples: Number of examples to evaluate (None for all)
            num_threads: Number of threads for parallel processing
            n_repeats: Number of times to repeat evaluation for error bars
            categories: List of task categories to include (None for all)
            max_context_length: Maximum formatted-prompt length in tokenizer tokens
            min_context_length: Minimum formatted-prompt length in tokenizer tokens

        Raises:
            ValueError: If a length bound is given without ``model``, or if no
                examples remain after filtering.
            AssertionError: If ``num_examples`` is set together with
                ``n_repeats != 1``.
        """
        if (min_context_length or max_context_length) and model is None:
            # Fail fast here instead of inside the worker threads.
            raise ValueError(
                "A model (tokenizer) is required when min_context_length or "
                "max_context_length is set"
            )

        # NOTE(review): trust_remote_code=True lets the tokenizer repository run
        # its own Python code on load; only point `model` at trusted sources.
        if model is not None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model, trust_remote_code=True
            )
        else:
            # No tokenizer needed: the length filter is disabled in this case.
            self.tokenizer = None
        self.min_context_length = min_context_length
        self.max_context_length = max_context_length

        # Load dataset based on data source type
        examples = self._load_dataset(data_source)

        # Apply filtering
        if categories:
            examples = [ex for ex in examples if ex.get("category") in categories]

        # Sample examples if specified
        if num_examples:
            assert n_repeats == 1, "n_repeats only supported when not sampling examples"
            examples = examples[:num_examples]

        # Repeat examples for multiple runs
        examples = examples * n_repeats

        if not examples:
            raise ValueError(
                "No examples available for LongBench-v2 evaluation after filtering"
            )

        self.examples = examples
        self.n_repeats = n_repeats
        self.num_threads = num_threads

        print(f"Loaded {len(self.examples)} examples from LongBench-v2")
        if categories:
            print(f"Filtered to categories: {categories}")
        if min_context_length or max_context_length:
            # The filter itself is applied per example in __call__, and it
            # counts tokenizer tokens, not characters.
            print(
                f"Context length filter: {min_context_length}-{max_context_length} tokens"
            )

    def _load_dataset(self, data_source: str) -> List[Dict[str, Any]]:
        """Load dataset from HuggingFace hub or local files.

        An existing filesystem path is read as a local file; anything else is
        treated as a HuggingFace dataset identifier. An empty ``data_source``
        falls back to ``DEFAULT_DATASET``. Every row is passed through
        ``_normalize_example`` before being returned.
        """

        if not data_source:
            data_source = DEFAULT_DATASET

        if os.path.exists(data_source):
            raw_examples = self._load_local_file(data_source)
        else:
            raw_examples = self._load_hf_dataset(data_source)

        return [self._normalize_example(example) for example in raw_examples]

    def _load_local_file(self, path: str) -> List[Dict[str, Any]]:
        """Load examples from a local CSV/JSON/JSONL file.

        The format is chosen from the file suffix. Files with any other suffix
        are tried as JSON first and as CSV if JSON decoding fails. A top-level
        JSON object is unwrapped through its ``"data"`` key.

        Raises:
            ValueError: If the loaded content is not a list.
        """

        suffix = os.path.splitext(path)[1].lower()
        if suffix in {".json", ".jsonl"}:
            with open(path, "r", encoding="utf-8") as fh:
                if suffix == ".jsonl":
                    # One JSON document per non-blank line
                    data = [json.loads(line) for line in fh if line.strip()]
                else:
                    data = json.load(fh)
        elif suffix == ".csv":
            with open(path, "r", encoding="utf-8") as fh:
                data = list(csv.DictReader(fh))
        else:
            # Try JSON, then CSV as fallback
            try:
                with open(path, "r", encoding="utf-8") as fh:
                    data = json.load(fh)
            except json.JSONDecodeError:
                with open(path, "r", encoding="utf-8") as fh:
                    data = list(csv.DictReader(fh))

        if isinstance(data, dict):
            data = data.get("data", [])

        if not isinstance(data, list):
            raise ValueError("Expected list of examples from local file")

        return data

    def _load_hf_dataset(self, identifier: str) -> List[Dict[str, Any]]:
        """Load the dataset from HuggingFace Hub.

        ``identifier`` is ``"<dataset>"`` or ``"<dataset>:<split>"``; without a
        split, ``DEFAULT_DATASET_SPLIT`` is used.

        Raises:
            ImportError: If the ``datasets`` package is not installed.
        """

        dataset_name, _, split = identifier.partition(":")
        if ":" not in identifier:
            split = DEFAULT_DATASET_SPLIT

        try:
            from datasets import load_dataset  # type: ignore
        except ImportError as exc:
            raise ImportError(
                "Please install the 'datasets' package to load LongBench-v2 from HuggingFace: pip install datasets"
            ) from exc

        dataset = load_dataset(dataset_name, split=split)
        return [dict(row) for row in dataset]

    @classmethod
    def _normalize_answer(cls, answer: Any) -> Any:
        """Map an answer to an upper-case letter where possible.

        Strings are stripped and upper-cased, integers 0..3 become "A".."D",
        and any other value is returned unchanged.
        """
        if isinstance(answer, str):
            return answer.strip().upper()
        if isinstance(answer, int) and 0 <= answer < 4:
            return cls._ANSWER_LETTERS[answer]
        return answer

    def _normalize_example(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure each example exposes the expected keys.

        Works on a shallow copy: mirrors ``choice_X`` into ``X`` and ``domain``
        into ``category`` when the latter is absent, and canonicalizes
        ``answer`` if present.
        """

        normalized = dict(example)

        for letter in self._ANSWER_LETTERS:
            choice_key = f"choice_{letter}"
            if letter not in normalized and choice_key in normalized:
                normalized[letter] = normalized[choice_key]

        if "category" not in normalized and "domain" in normalized:
            normalized["category"] = normalized["domain"]

        if "answer" in normalized:
            normalized["answer"] = self._normalize_answer(normalized["answer"])

        return normalized

    def _check_context_length(
        self,
        formatted_question: str,
        tokenizer: AutoTokenizer,
        min_length: Optional[int],
        max_length: Optional[int],
    ) -> bool:
        """Check a prompt's length, measured in tokenizer tokens, against bounds.

        Args:
            formatted_question: Full prompt text to measure.
            tokenizer: Object whose ``encode`` returns the token ids.
            min_length: Inclusive lower bound in tokens, or None for no bound.
            max_length: Inclusive upper bound in tokens, or None for no bound.

        Returns:
            True when the token count lies within both bounds.
        """
        token_count = len(tokenizer.encode(formatted_question))

        if min_length is not None and token_count < min_length:
            return False
        if max_length is not None and token_count > max_length:
            return False

        return True

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Run the evaluation.

        Each example is processed by ``fn`` via ``common.map_with_progress``
        and the results are combined with ``common.aggregate_results``.
        """

        def fn(row: dict):
            # Format the question using official template
            formatted_question = format_longbench_v2_question(row)

            if self.min_context_length or self.max_context_length:
                within_bounds = self._check_context_length(
                    formatted_question,
                    self.tokenizer,
                    self.min_context_length,
                    self.max_context_length,
                )
                if not within_bounds:
                    # Skip this example; the None result is expected to be
                    # ignored by the aggregation step.
                    return None

            prompt_messages = [
                sampler._pack_message(content=formatted_question, role="user")
            ]

            # Get model response (a failed sample is scored as an empty reply)
            response_text = sampler(prompt_messages)
            if response_text is None:
                response_text = ""

            # Extract answer using official method
            extracted_answer = extract_longbench_v2_answer(response_text)

            # Get correct answer
            correct_answer = self._normalize_answer(row.get("answer", ""))

            # Calculate score
            score = 1.0 if extracted_answer == correct_answer else 0.0

            assistant_message = dict(content=response_text, role="assistant")

            # Generate HTML report
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=assistant_message,
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )

            # Build conversation
            convo = prompt_messages + [dict(content=response_text, role="assistant")]

            # Prepare metrics
            metrics = {"chars": len(response_text)}

            # Add category-specific metrics.
            # NOTE(review): only labels listed in TASK_CATEGORIES are recorded;
            # confirm the dataset's category/domain values use those names.
            category = row.get("category", row.get("domain", "unknown"))
            if category in TASK_CATEGORIES:
                metrics[category] = score

            difficulty = row.get("difficulty")
            if isinstance(difficulty, str) and difficulty:
                metrics[f"difficulty_{difficulty.lower()}"] = score

            return SingleEvalResult(
                html=html,
                score=score,
                convo=convo,
                metrics=metrics,
            )

        # Run evaluation with progress tracking
        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)
@@ -1,5 +1,4 @@
1
1
  import itertools
2
- import os
3
2
  import unittest
4
3
 
5
4
  import torch
@@ -577,7 +576,7 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
577
576
  if not torch.cuda.is_available():
578
577
  raise unittest.SkipTest("CUDA is not available")
579
578
  try:
580
- import deep_gemm
579
+ import deep_gemm # noqa: F401
581
580
  except ImportError:
582
581
  raise unittest.SkipTest("DeepGEMM is not available")
583
582
  torch.set_default_device("cuda")
@@ -1,5 +1,4 @@
1
1
  import itertools
2
- import os
3
2
  import unittest
4
3
  from typing import List, Tuple
5
4
 
@@ -1,5 +1,4 @@
1
1
  import argparse
2
- import time
3
2
 
4
3
  import torch
5
4
  import triton # Added import
@@ -34,7 +33,7 @@ def get_model_config(tp_size: int):
34
33
  "topk": topk,
35
34
  "hidden_size": config.hidden_size,
36
35
  "shard_intermediate_size": shard_intermediate_size,
37
- "dtype": config.torch_dtype,
36
+ "dtype": config.dtype,
38
37
  "block_shape": config.quantization_config["weight_block_size"],
39
38
  }
40
39
 
@@ -1,6 +1,6 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
 
3
- from typing import Literal, Optional
3
+ from typing import Optional
4
4
 
5
5
  import pytest
6
6
  import torch
@@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
120
120
  )
121
121
  topk_weights, topk_ids, _ = topk_output
122
122
  expert_map = torch.arange(E, dtype=torch.int32, device=device)
123
- expert_map[local_e:] = E
123
+ expert_map[local_e:] = -1
124
124
 
125
125
  output = cutlass_moe(
126
126
  a,
@@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
138
138
  c_strides2,
139
139
  s_strides13,
140
140
  s_strides2,
141
- 0,
142
- local_e - 1,
143
- E,
141
+ local_e,
144
142
  a1_scale,
145
143
  a2_scale,
146
144
  expert_map,
@@ -178,7 +176,7 @@ def cutlass_moe(
178
176
  w1_scale: torch.Tensor,
179
177
  w2_scale: torch.Tensor,
180
178
  topk_weights: torch.Tensor,
181
- topk_ids_: torch.Tensor,
179
+ topk_ids: torch.Tensor,
182
180
  a_strides1: torch.Tensor,
183
181
  b_strides1: torch.Tensor,
184
182
  c_strides1: torch.Tensor,
@@ -187,40 +185,32 @@ def cutlass_moe(
187
185
  c_strides2: torch.Tensor,
188
186
  s_strides13: torch.Tensor,
189
187
  s_strides2: torch.Tensor,
190
- start_expert_id: int,
191
- end_expert_id: int,
192
- E: int,
188
+ num_local_experts: int,
193
189
  a1_scale: Optional[torch.Tensor] = None,
194
190
  a2_scale: Optional[torch.Tensor] = None,
195
191
  expert_map: Optional[torch.Tensor] = None,
196
192
  apply_router_weight_on_input: bool = False,
197
193
  ):
198
- local_topk_ids = topk_ids_
199
- local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
194
+ topk_ids = expert_map[topk_ids]
200
195
  device = a.device
201
196
 
202
- local_num_experts = end_expert_id - start_expert_id + 1
203
197
  expert_offsets = torch.empty(
204
- (local_num_experts + 1), dtype=torch.int32, device=device
198
+ (num_local_experts + 1), dtype=torch.int32, device=device
205
199
  )
206
200
  problem_sizes1 = torch.empty(
207
- (local_num_experts, 3), dtype=torch.int32, device=device
201
+ (num_local_experts, 3), dtype=torch.int32, device=device
208
202
  )
209
203
  problem_sizes2 = torch.empty(
210
- (local_num_experts, 3), dtype=torch.int32, device=device
204
+ (num_local_experts, 3), dtype=torch.int32, device=device
211
205
  )
212
206
  return cutlass_w4a8_moe(
213
- start_expert_id,
214
- end_expert_id,
215
- E,
216
207
  a,
217
208
  w1_q,
218
209
  w2_q,
219
210
  w1_scale,
220
211
  w2_scale,
221
212
  topk_weights,
222
- topk_ids_,
223
- local_topk_ids,
213
+ topk_ids,
224
214
  a_strides1,
225
215
  b_strides1,
226
216
  c_strides1,