sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/configs/deepseekvl2.py

@@ -1,5 +1,4 @@
 import math
-import os
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
@@ -12,6 +11,8 @@ from transformers import (
     ProcessorMixin,
 )
 
+from sglang.srt.configs.deepseek_ocr import BASE_SIZE, IMAGE_SIZE, MAX_CROPS, MIN_CROPS
+
 
 def select_best_resolution(image_size, candidate_resolutions):
     # used for cropping
@@ -62,6 +63,7 @@ class DictOutput(object):
 class VLChatProcessorOutput(DictOutput):
     input_ids: torch.LongTensor
     target_ids: torch.LongTensor
+    images_crop: torch.LongTensor
     pixel_values: (
         torch.Tensor
     )  # rename from "images" to "pixel_values" for compatibility
@@ -105,6 +107,68 @@ class ImageTransform(object):
         return x
 
 
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(
+    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
+):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size,
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images, target_aspect_ratio
+
+
 class DeepseekVLV2Processor(ProcessorMixin):
     tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
     attributes = ["tokenizer"]
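Note: the crop grid chosen by the new dynamic_preprocess path depends only on the image's aspect ratio and the MIN_CROPS/MAX_CROPS bounds. The standalone sketch below reproduces that selection arithmetic for illustration; the MIN_CROPS/MAX_CROPS values are assumed (the real constants come from sglang.srt.configs.deepseek_ocr), and the area-based tie-break of find_closest_aspect_ratio is omitted.

MIN_CROPS, MAX_CROPS = 2, 6  # assumed values, for illustration only

def candidate_grids(min_num=MIN_CROPS, max_num=MAX_CROPS):
    # all (cols, rows) grids whose tile count lies in [min_num, max_num]
    return sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda g: g[0] * g[1],
    )

def best_grid(width, height):
    # like find_closest_aspect_ratio, minus the area tie-break
    aspect = width / height
    return min(candidate_grids(), key=lambda g: abs(aspect - g[0] / g[1]))

print(best_grid(1600, 800))  # -> (2, 1): two 640-px tiles side by side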
@@ -134,7 +198,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         self.image_std = image_std
         self.normalize = normalize
         self.downsample_ratio = downsample_ratio
-
+        self.base_size = BASE_SIZE
         self.image_transform = ImageTransform(
             mean=image_mean, std=image_std, normalize=normalize
         )
@@ -177,7 +241,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
             **kwargs,
         )
 
-    def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
+    def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
         """play the role of format_messages_v2 and get_images_info in the last version"""
         tokenized_data = []
         masked_tokenized_data = []  # labels
@@ -187,35 +251,34 @@ class DeepseekVLV2Processor(ProcessorMixin):
 
         image_index = 0
         image_token_cnt = messages.count(self.image_token)
-        tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
+        (
+            input_ids,
+            images,
+            images_crop,
+            seq_mask,
+            spatial_crop,
+            num_image_tokens,
+            image_shapes,
+        ) = self.tokenize_with_images(
             messages,
             pil_images[image_index : image_index + image_token_cnt],
             bos=True,
             eos=True,
             cropping=len(pil_images) <= 2,
-            max_req_input_len=max_req_input_len,
         )
 
         image_index = image_token_cnt
-        tokenized_data += tokenized_str
-        if self.mask_prompt:
-            masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
-        else:
-            masked_tokenized_data += tokenized_str
         images_list += images
         images_seq_mask += seq_mask
-        images_spatial_crop += spatial_crop
-
-        assert len(tokenized_data) == len(
-            images_seq_mask
-        ), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
+        images_spatial_crop = spatial_crop
 
         return (
-            tokenized_data,
+            input_ids,
             masked_tokenized_data,
             images_list,
             images_seq_mask,
             images_spatial_crop,
+            images_crop,
         )
 
     @property
@@ -252,6 +315,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         inference_mode: bool = True,
         system_prompt: str = "",
         max_req_input_len: int = -1,
+        cropping: bool = True,
         **kwargs,
     ):
         """
@@ -275,47 +339,22 @@ class DeepseekVLV2Processor(ProcessorMixin):
             - num_image_tokens (List[int]): the number of image tokens
         """
 
-        assert (
-            prompt is None or conversations is None
-        ), "prompt and conversations cannot be used at the same time."
-
+        prompt = conversations or prompt
         (
-            tokenized_str,
+            input_ids,
             masked_tokenized_str,
             images_list,
             images_seq_mask,
             images_spatial_crop,
-        ) = self.format_messages_v2(conversations, images, max_req_input_len)
+            images_crop,
+        ) = self.format_messages_v2(prompt, images, max_req_input_len)
 
-        assert (
-            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
-        ), (
-            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
-            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
-        )
-
-        input_ids = torch.LongTensor(tokenized_str)
         target_ids = torch.LongTensor(masked_tokenized_str)
-        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
-
-        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
-        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
-            self.ignore_id
-        )
-        input_ids[input_ids < 0] = self.pad_id
-
-        if inference_mode:
-            assert input_ids[-1] == self.eos_id
-            input_ids = input_ids[:-1]
-            target_ids = target_ids[:-1]
-            images_seq_mask = images_seq_mask[:-1]
 
         if len(images_list) == 0:
             images = torch.zeros((1, 3, self.image_size, self.image_size))
-            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
         else:
             images = torch.stack(images_list, dim=0)
-            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
 
         images_spatial_crop = torch.stack(
             [images_spatial_crop], dim=0
@@ -324,6 +363,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         prepare = VLChatProcessorOutput(
             input_ids=input_ids,
             target_ids=target_ids,
+            images_crop=images_crop,
             pixel_values=images,
             images_seq_mask=images_seq_mask,
             images_spatial_crop=images_spatial_crop,
@@ -341,10 +381,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
         inference_mode: bool = True,
         system_prompt: str = "",
         max_req_input_len: int = -1,
+        text: list[str] = None,
         **kwargs,
     ):
+        assert text is None or isinstance(text, list)
+        if text is not None:
+            text = text[0]
         prepare = self.process_one(
-            prompt=prompt,
+            prompt=prompt or text,
             conversations=conversations,
             images=images,
             apply_sft_format=apply_sft_format,
@@ -369,85 +413,83 @@ class DeepseekVLV2Processor(ProcessorMixin):
         bos: bool = True,
         eos: bool = True,
         cropping: bool = True,
-        max_req_input_len: int = -1,
     ):
         """Tokenize text with <image> tags."""
-        images_list, images_seq_mask, images_spatial_crop = [], [], []
+
+        conversation = conversation
+        assert conversation.count(self.image_token) == len(images)
         text_splits = conversation.split(self.image_token)
+        images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
+            [],
+            [],
+            [],
+            [],
+        )
+        image_shapes = []
+        num_image_tokens = []
         tokenized_str = []
         for text_sep, image in zip(text_splits, images):
             """encode text_sep"""
             tokenized_sep = self.encode(text_sep, bos=False, eos=False)
+
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)
 
-            """select best resolution for anyres"""
-            if cropping:
-                best_width, best_height = select_best_resolution(
-                    image.size, self.candidate_resolutions
-                )
+            image_shapes.append(image.size)
+
+            if image.size[0] <= 640 and image.size[1] <= 640:
+                crop_ratio = [1, 1]
             else:
-                best_width, best_height = self.image_size, self.image_size
-            # print(image.size, (best_width, best_height))  # check the select_best_resolutions func
+                if cropping:
+                    images_crop_raw, crop_ratio = dynamic_preprocess(
+                        image, image_size=IMAGE_SIZE
+                    )
+                else:
+                    crop_ratio = [1, 1]
 
             """process the global view"""
+            if self.image_size <= 640 and not cropping:
+                image = image.resize((self.image_size, self.image_size))
+
             global_view = ImageOps.pad(
                 image,
-                (self.image_size, self.image_size),
+                (self.base_size, self.base_size),
                 color=tuple(int(x * 255) for x in self.image_transform.mean),
             )
             images_list.append(self.image_transform(global_view))
 
-            """process the local views"""
-            local_view = ImageOps.pad(
-                image,
-                (best_width, best_height),
-                color=tuple(int(x * 255) for x in self.image_transform.mean),
-            )
-            for i in range(0, best_height, self.image_size):
-                for j in range(0, best_width, self.image_size):
-                    images_list.append(
-                        self.image_transform(
-                            local_view.crop(
-                                (j, i, j + self.image_size, i + self.image_size)
-                            )
-                        )
-                    )
-
-            """record height / width crop num"""
-            num_width_tiles, num_height_tiles = (
-                best_width // self.image_size,
-                best_height // self.image_size,
-            )
+            num_width_tiles, num_height_tiles = crop_ratio
             images_spatial_crop.append([num_width_tiles, num_height_tiles])
 
+            if num_width_tiles > 1 or num_height_tiles > 1:
+                for i in range(len(images_crop_raw)):
+                    images_crop_list.append(self.image_transform(images_crop_raw[i]))
+
             """add image tokens"""
-            h = w = math.ceil(
+            num_queries = math.ceil(
                 (self.image_size // self.patch_size) / self.downsample_ratio
             )
-            # global views tokens h * (w + 1), 1 is for line separator
-            tokenized_image = [self.image_token_id] * h * (w + 1)
-            # add a separator between global and local views
-            tokenized_image += [self.image_token_id]
-            # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
-            tokenized_image += (
-                [self.image_token_id]
-                * (num_height_tiles * h)
-                * (num_width_tiles * w + 1)
+            num_queries_base = math.ceil(
                (self.base_size // self.patch_size) / self.downsample_ratio
             )
 
+            tokenized_image = (
+                [self.image_token_id] * num_queries_base + [self.image_token_id]
+            ) * num_queries_base
+            tokenized_image += [self.image_token_id]
+            if num_width_tiles > 1 or num_height_tiles > 1:
+                tokenized_image += (
+                    [self.image_token_id] * (num_queries * num_width_tiles)
+                    + [self.image_token_id]
+                ) * (num_queries * num_height_tiles)
             tokenized_str += tokenized_image
+
             images_seq_mask += [True] * len(tokenized_image)
-            # print(width_crop_num, height_crop_num, len(tokenized_image))  # test the correctness of the number of image-related tokens
+            num_image_tokens.append(len(tokenized_image))
 
         """process the last text split"""
         tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
-        # deal with video, limit with request len
-        if max_req_input_len > -1:
-            if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
-                rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
-                tokenized_str = tokenized_str[:rest]
-                images_seq_mask = images_seq_mask[:rest]
+
         tokenized_str += tokenized_sep
         images_seq_mask += [False] * len(tokenized_sep)
 
@@ -463,7 +505,64 @@ class DeepseekVLV2Processor(ProcessorMixin):
             images_seq_mask
         ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
 
-        return tokenized_str, images_list, images_seq_mask, images_spatial_crop
+        masked_tokenized_str = []
+        for token_index in tokenized_str:
+            if token_index != self.image_token_id:
+                masked_tokenized_str.append(token_index)
+            else:
+                masked_tokenized_str.append(self.ignore_id)
+
+        assert (
+            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
+        ), (
+            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
+        )
+        input_ids = torch.LongTensor(tokenized_str)
+        target_ids = torch.LongTensor(masked_tokenized_str)
+        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
+        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
+            self.ignore_id
+        )
+        input_ids[input_ids < 0] = self.pad_id
+
+        inference_mode = True
+
+        if inference_mode:
+            # Remove the ending eos token
+            assert input_ids[-1] == self.eos_id
+            input_ids = input_ids[:-1]
+            target_ids = target_ids[:-1]
+            images_seq_mask = images_seq_mask[:-1]
+
+        if len(images_list) == 0:
+            pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
+            images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
+            images_crop = torch.zeros(
+                (1, 3, self.image_size, self.image_size)
+            ).unsqueeze(0)
+        else:
+            pixel_values = torch.stack(images_list, dim=0)
+            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+            if images_crop_list:
+                images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
+            else:
+                images_crop = torch.zeros(
+                    (1, 3, self.image_size, self.image_size)
+                ).unsqueeze(0)
+
+        input_ids = input_ids.unsqueeze(0)
+        return (
+            input_ids,
+            pixel_values,
+            images_crop,
+            images_seq_mask,
+            images_spatial_crop,
+            num_image_tokens,
+            image_shapes,
+        )
 
 
 class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
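Note: the new image-token accounting in tokenize_with_images can be checked with a few lines of arithmetic. The global view contributes num_queries_base rows of num_queries_base tokens plus one separator per row, then one separator token; a cropped image additionally contributes num_queries * num_height_tiles rows of num_queries * num_width_tiles tokens plus one separator per row. A standalone sketch, with all sizes assumed for illustration (the real values come from the processor config):

import math

def image_token_count(
    base_size, image_size, patch_size, downsample_ratio,
    num_width_tiles, num_height_tiles,
):
    # mirrors the tokenized_image construction in tokenize_with_images above
    num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
    num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
    count = (num_queries_base + 1) * num_queries_base  # global rows + row separators
    count += 1  # separator between global and local views
    if num_width_tiles > 1 or num_height_tiles > 1:
        count += (num_queries * num_width_tiles + 1) * (num_queries * num_height_tiles)
    return count

# assumed sizes: base_size=1024, image_size=640, patch_size=16,
# downsample_ratio=4, and a 2x1 crop grid
print(image_token_count(1024, 640, 16, 4, 2, 1))  # -> 483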
@@ -548,7 +647,6 @@ class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
 
 
 class DeepseekV2Config(PretrainedConfig):
-
     model_type = "deepseek_v2"
     keys_to_ignore_at_inference = ["past_key_values"]
 
sglang/srt/configs/dots_vlm.py

@@ -1,10 +1,5 @@
-from typing import Any, List, Optional, Union
-
-from transformers import AutoProcessor, LlamaTokenizerFast, PretrainedConfig
-from transformers.feature_extraction_utils import BatchFeature
-from transformers.image_utils import ImageInput
-from transformers.processing_utils import ProcessingKwargs, Unpack
-from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers import AutoProcessor, PretrainedConfig
+from transformers.processing_utils import ProcessingKwargs
 
 try:
     from transformers import Qwen2_5_VLProcessor
sglang/srt/configs/falcon_h1.py

@@ -14,21 +14,12 @@
 # limitations under the License.
 """Falcon-H1 model configuration"""
 
-import enum
-import os
 
-import numpy as np
-import torch
 from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging
 
-from sglang.srt.distributed.utils import divide
-from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_size,
-    get_tensor_model_parallel_world_size,
-)
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
+from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size
 
 logger = logging.get_logger(__name__)
 
@@ -214,7 +205,7 @@ class FalconH1Config(PretrainedConfig):
         self.rope_scaling = None
         self.rope_scaling = rope_scaling
         self.projectors_bias = projectors_bias
-        mamba_intermediate = (
+        self.mamba_intermediate = mamba_intermediate = (
             mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
         )
 
@@ -294,18 +285,6 @@ class FalconH1Config(PretrainedConfig):
     def layers_block_type(self):
         return ["falcon_h1" for i in range(self.num_hidden_layers)]
 
-    @property
-    def mamba_cache_per_req(self):
-        conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
-            self.hybrid_gdn_params
-        )
-        mamba_layers_len = len(mamba_layers)
-
-        return (
-            int(np.prod(conv_state_shape)) * conv_dtype.itemsize
-            + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
-        ) * mamba_layers_len
-
     @property
     def full_attention_layer_ids(self):
         # For Falcon-H1, we do have attention on all layers
@@ -317,44 +296,14 @@ class FalconH1Config(PretrainedConfig):
         return range(self.num_hidden_layers)
 
     @property
-    def hybrid_gdn_params(self):
-        world_size = get_tensor_model_parallel_world_size()
-
-        n_groups = self.mamba_n_groups
-        if self.mamba_n_groups % world_size != 0:
-            # - for TP we shard conv_dim by sharding on n_groups,
-            # - but if n_groups cannot divide tp_size, we need to
-            #   extend some extra groups
-            extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
-                self.mamba_n_groups, world_size
-            )
-            n_groups += extra_groups
-
-        conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
-
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.mamba_d_conv - 1,
-        )
-
-        # we TP-ize on the heads dimension
-        temporal_state_shape = (
-            self.mamba_d_state,
-            self.mamba_d_head,
-            divide(self.mamba_n_heads, world_size),
-        )
-        conv_dtype = torch.bfloat16
-        dtype_map = {
-            "float32": torch.float32,
-            "bfloat16": torch.bfloat16,
-        }
-        ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
-        mamba_layers = self.linear_layer_ids
-
-        return (
-            conv_state_shape,
-            temporal_state_shape,
-            conv_dtype,
-            ssm_dtype,
-            mamba_layers,
+    def mamba2_cache_params(self):
+        shape = Mamba2StateShape.create(
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            intermediate_size=self.mamba_intermediate,
+            n_groups=self.mamba_n_groups,
+            num_heads=self.mamba_n_heads,
+            head_dim=self.mamba_d_head,
+            state_size=self.mamba_d_state,
+            conv_kernel=self.mamba_d_conv,
         )
+        return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
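Note: the shape arithmetic that hybrid_gdn_params used to compute inline now lives behind Mamba2StateShape.create / Mamba2CacheParams in sglang.srt.configs.mamba_utils. A minimal sketch of that arithmetic, assuming the group-padding behavior of the old extra_groups_for_head_shards helper (names and fields here are illustrative, not the real API):

from dataclasses import dataclass

@dataclass
class StateShapeSketch:
    conv: tuple      # (conv_dim // tp, d_conv - 1), sharded over TP ranks
    temporal: tuple  # (d_state, d_head, n_heads // tp)

def mamba2_state_shape(
    tp_world_size, intermediate_size, n_groups,
    num_heads, head_dim, state_size, conv_kernel,
):
    # pad n_groups up to a multiple of tp_world_size so conv_dim shards
    # evenly, as the removed hybrid_gdn_params did via extra group shards
    if n_groups % tp_world_size != 0:
        n_groups += tp_world_size - (n_groups % tp_world_size)
    conv_dim = intermediate_size + 2 * n_groups * state_size
    return StateShapeSketch(
        conv=(conv_dim // tp_world_size, conv_kernel - 1),
        temporal=(state_size, head_dim, num_heads // tp_world_size),
    )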
sglang/srt/configs/load_config.py

@@ -1,10 +1,12 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
 import enum
-import json
 import logging
 from dataclasses import dataclass, field
 from typing import List, Optional, Union
 
+import orjson
+
+from sglang.srt.configs.modelopt_config import ModelOptConfig
 from sglang.srt.utils import is_hip
 
 logger = logging.getLogger(__name__)
@@ -50,6 +52,11 @@ class LoadConfig:
         decryption_key_file: If set, decrypts the output files with a password read
             from this file (after PBKDF2).
         decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
+
+        # ModelOpt-specific loading options
+        modelopt_checkpoint_restore_path: Optional[str] = None
+        modelopt_checkpoint_save_path: Optional[str] = None
+        modelopt_export_path: Optional[str] = None
     """
 
     load_format: Union[str, LoadFormat] = LoadFormat.AUTO
@@ -63,10 +70,18 @@ class LoadConfig:
     remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
     remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
 
+    # ModelOpt-specific loading options
+    modelopt_checkpoint_restore_path: Optional[str] = None
+    modelopt_checkpoint_save_path: Optional[str] = None
+    modelopt_export_path: Optional[str] = None
+
+    # ModelOpt configuration object
+    modelopt_config: Optional[ModelOptConfig] = None
+
     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
         if isinstance(model_loader_extra_config, str):
-            self.model_loader_extra_config = json.loads(model_loader_extra_config)
+            self.model_loader_extra_config = orjson.loads(model_loader_extra_config)
         self._verify_load_format()
 
         if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
@@ -77,6 +92,14 @@ class LoadConfig:
         else:
             self.ignore_patterns = ["original/**/*"]
 
+        # Create ModelOptConfig if not provided
+        if self.modelopt_config is None:
+            self.modelopt_config = ModelOptConfig(
+                checkpoint_restore_path=self.modelopt_checkpoint_restore_path,
+                checkpoint_save_path=self.modelopt_checkpoint_save_path,
+                export_path=self.modelopt_export_path,
+            )
+
     def _verify_load_format(self) -> None:
         if not isinstance(self.load_format, str):
             return
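Note: the json.loads → orjson.loads swap in __post_init__ is behavior-preserving for this path, since orjson.loads accepts both str and bytes and yields the same Python objects. One practical difference to keep in mind elsewhere: orjson.dumps returns bytes, not str. A quick check (the config keys below are hypothetical):

import json
import orjson

raw = '{"device": "cpu", "num_threads": 8}'  # hypothetical extra-config payload
assert orjson.loads(raw) == json.loads(raw)        # identical parse result
assert isinstance(orjson.dumps({"a": 1}), bytes)   # bytes, unlike json.dumps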