sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,355 @@
1
+ import itertools
2
+ import math
3
+ from collections.abc import Iterable
4
+ from typing import Any
5
+
6
+ import einops
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch import Tensor
11
+ from transformers.configuration_utils import PretrainedConfig
12
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
13
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
14
+ from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
15
+
16
+ import sglang.srt.managers.mm_utils as mm_utils
17
+ import sglang.srt.model_loader.weight_utils as weight_utils
18
+ import sglang.srt.utils as utils
19
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
20
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
21
+ from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
22
+ from sglang.srt.managers.schedule_batch import (
23
+ Modality,
24
+ MultimodalDataItem,
25
+ MultimodalInputs,
26
+ )
27
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
28
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
29
+
# Per-token channel width of the vision features fed to NVILAMultiModalProjector;
# its 2x2 downsample block widens this to 4 * MM_HIDDEN_SIZE before the LayerNorm.
# NOTE(review): presumably the SigLIP hidden size (1152) times the 3 dynamic-S2
# scales concatenated along channels in get_image_feature — confirm against the
# checkpoint's vision config.
MM_HIDDEN_SIZE = 3456
31
+
32
+
class NVILAConfig(PretrainedConfig):
    """Hugging Face configuration for NVILA (Qwen2 text model + SigLIP vision tower).

    Args:
        text_config: Qwen2 text-model settings, given either as a mapping of
            ``Qwen2Config`` constructor kwargs or as an existing ``Qwen2Config``
            (used as-is). ``None`` builds a default ``Qwen2Config``.
        vision_config: SigLIP vision settings, given either as a mapping of
            ``SiglipVisionConfig`` kwargs or as an existing
            ``SiglipVisionConfig`` (used as-is). ``None`` builds the default.
        image_token_id: Id of the image token; stored as ``-1`` if omitted.
        video_token_id: Id of the video token; stored as ``-1`` if omitted.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "nvila"
    sub_configs = {
        "text_config": Qwen2Config,
        "vision_config": SiglipVisionConfig,
    }
    _auto_class = "AutoConfig"

    def __init__(
        self,
        *,
        text_config: dict[str, Any] | Qwen2Config | None = None,
        vision_config: dict[str, Any] | SiglipVisionConfig | None = None,
        image_token_id: int | None = None,
        video_token_id: int | None = None,
        **kwargs,
    ):
        # Accept an already-built sub-config object as well as a plain dict;
        # unpacking a config object with ``**`` would raise TypeError.
        if text_config is None:
            self.text_config = Qwen2Config()
        elif isinstance(text_config, Qwen2Config):
            self.text_config = text_config
        else:
            self.text_config = Qwen2Config(**text_config)

        if vision_config is None:
            self.vision_config = SiglipVisionConfig()
        elif isinstance(vision_config, SiglipVisionConfig):
            self.vision_config = vision_config
        else:
            self.vision_config = SiglipVisionConfig(**vision_config)

        self.image_token_id = image_token_id if image_token_id is not None else -1
        self.video_token_id = video_token_id if video_token_id is not None else -1

        super().__init__(**kwargs)
63
+
64
+
class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
    """2x2 space-to-depth downsampling over a square grid of tokens.

    ``x`` is ``(batch, seq_len, hidden)`` and its tokens are laid out row-major
    on a ``side x side`` grid with ``side = isqrt(seq_len)``. An odd ``side`` is
    zero-padded by one row (bottom) and one column (right). Each 2x2
    neighbourhood is then folded into the channel axis, producing
    ``(batch, (side // 2) ** 2, 4 * hidden)``.
    """

    def forward(self, x: Tensor) -> Tensor:
        batch, seq_len, hidden = x.shape
        side = math.isqrt(seq_len)
        grid = x.reshape(batch, side, side, hidden)

        if side % 2:
            # F.pad lists pads from the last dim backwards:
            # hidden (none), width (+1 right), height (+1 bottom).
            grid = F.pad(grid, (0, 0, 0, 1, 0, 1))
            side += 1

        half = side // 2
        blocks = grid.reshape(batch, half, 2, half, 2, hidden)
        # (b, row, dr, col, dc, h) -> (b, row, col, dr, dc, h)
        blocks = blocks.permute(0, 1, 3, 2, 4, 5).contiguous()
        return blocks.reshape(batch, -1, 4 * hidden)
85
+
86
+
class NVILAMultiModalProjector(nn.Module):
    """Project vision features into the text model's hidden space.

    Stages: 2x2 spatial downsample -> LayerNorm over the 4x-widened channels ->
    Linear -> GELU -> Linear to ``config.text_config.hidden_size``. The stages
    live in an ``nn.Sequential`` named ``layers``; parameter names
    (``layers.<idx>.*``) depend on this exact order.
    """

    def __init__(self, config: NVILAConfig):
        super().__init__()

        llm_width = config.text_config.hidden_size
        merged_width = 4 * MM_HIDDEN_SIZE

        stages = [
            NVILAMultiModalProjectorDownsampleBlock(),
            nn.LayerNorm(merged_width),
            nn.Linear(merged_width, llm_width),
            nn.GELU(),
            nn.Linear(llm_width, llm_width),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)
101
+
102
+
class NVILAForConditionalGeneration(nn.Module):
    """NVILA model: SigLIP vision tower + multimodal projector + Qwen2 LLM.

    Image and video items are both encoded by :meth:`get_image_feature`; the
    resulting embeddings are merged into the LLM inputs by
    ``mm_utils.general_mm_embed_routine``.
    """

    def __init__(
        self,
        config: NVILAConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.mm_projector = NVILAMultiModalProjector(config)
        # Only the language model receives the quantization config and prefix.
        self.llm = Qwen2ForCausalLM(
            config=config.text_config,
            quant_config=quant_config,
            prefix=utils.add_prefix("llm", prefix),
        )

    def forward(
        self,
        input_ids: Tensor,
        positions: Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        """Run the Qwen2 LLM with image/video embeddings merged into its inputs.

        Returns whatever ``general_mm_embed_routine`` returns for ``self.llm``;
        on the generation path (``get_embedding=False``) this must be a
        ``LogitsProcessorOutput``.
        """
        output = mm_utils.general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
                Modality.VIDEO: self.get_image_feature,
            },
            get_embedding=get_embedding,
            positions=positions,
        )

        # Only enforce the logits type on the generation path. With
        # get_embedding=True the LLM is asked for embeddings and presumably
        # returns a pooled-embedding object instead
        # (NOTE(review): confirm against Qwen2ForCausalLM.forward).
        if not get_embedding:
            assert isinstance(output, LogitsProcessorOutput)

        return output

    def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
        """Encode image/video items into LLM-space token embeddings.

        Items exposing ``block_sizes`` contribute one ``(rows, cols)`` tile grid
        per image; when no item exposes it, every row of the concatenated
        ``feature`` tensors is treated as a single-tile image.

        Returns:
            A 2-D tensor ``(total_tokens, text_hidden_size)`` with the tokens of
            all images concatenated in input order.
        """
        block_sizes = (
            list(
                itertools.chain.from_iterable(
                    x.block_sizes for x in mm_input if hasattr(x, "block_sizes")
                )
            )
            or None
        )
        # as_tensor avoids a redundant copy (and torch's copy-construct warning)
        # when ``feature`` is already a tensor; torch.cat copies anyway.
        pixel_values = torch.cat(
            [torch.as_tensor(x.feature) for x in mm_input], dim=0
        )

        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
            pixel_values.to(
                device=self.vision_tower.device, dtype=self.vision_tower.dtype
            ),
            output_hidden_states=True,
        )
        assert vision_tower_output.hidden_states is not None

        # Features from the second-to-last vision-tower layer.
        vision_features: Tensor = vision_tower_output.hidden_states[-2]

        vision_features_list, block_sizes = merge_features_for_dynamic_s2(
            vision_features,
            block_sizes=(
                block_sizes
                if block_sizes is not None
                else [None] * vision_features.shape[0]
            ),
            resize_output_to_scale_idx=-1,
            scales=[448, 896, 1344],
        )

        # Re-tile each merged multi-scale map so the projector sees tile-sized inputs.
        vision_features_list = [
            split_chessboard(x, block_size[0], block_size[1])
            for x, block_size in zip(vision_features_list, block_sizes)
        ]

        vision_features = torch.cat(
            [einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list]
        )

        vision_features = self.mm_projector(vision_features)

        # Undo the tiling: regroup projected tiles per image and stitch them back.
        vision_features_list = list(
            vision_features.split(
                [block_size[0] * block_size[1] for block_size in block_sizes], dim=0
            )
        )
        vision_features_list = [
            merge_chessboard(x, block_size[0], block_size[1])
            for x, block_size in zip(vision_features_list, block_sizes)
        ]

        # Images with different block sizes yield different token counts, so
        # concatenate along the token axis instead of stacking (torch.stack
        # would fail on mismatched lengths; results are identical otherwise).
        return torch.cat(
            [einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list],
            dim=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
        """Load checkpoint tensors into the model.

        Names starting with ``llm.`` are forwarded to ``self.llm.load_weights``
        with that prefix stripped. Any other name must match a parameter of this
        module (``KeyError`` otherwise) and is loaded with the parameter's
        ``weight_loader`` attribute, falling back to ``default_weight_loader``.
        """
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if name.startswith("llm."):
                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
            else:
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", weight_utils.default_weight_loader
                )
                weight_loader(param, loaded_weight)

    def pad_input_ids(
        self, input_ids: list[int], mm_inputs: MultimodalInputs
    ) -> list[int]:
        """Pad multimodal tokens in ``input_ids`` by delegating to
        ``MultiModalityDataPaddingPatternMultimodalTokens.pad_input_tokens``."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)
224
+
225
+
def merge_chessboard(x, num_split_h, num_split_w):
    """Stitch batches of tiles back into whole feature maps.

    Args:
        x: Either ``(B, N, C)`` token sequences, where ``N`` must be a perfect
            square and each sequence is laid out row-major on a
            ``sqrt(N) x sqrt(N)`` grid, or ``(B, C, h, w)`` feature maps.
        num_split_h: Number of tile rows.
        num_split_w: Number of tile columns.

    ``B`` must be a multiple of ``num_split_h * num_split_w``. Tiles are read in
    row-major tile order, each tile occupying a contiguous run of
    ``b = B // (num_split_h * num_split_w)`` batch entries.

    Returns:
        Tensor of shape ``(b, C, num_split_h * h, num_split_w * w)``.

    Raises:
        ValueError: if ``N`` is not a perfect square, or ``B`` is not divisible
            by ``num_split_h * num_split_w``.
    """
    batch = x.shape[0]
    if x.dim() == 3:
        _, num_tokens, channels = x.shape
        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            raise ValueError(f"Token count {num_tokens} is not a perfect square")
        # (b, h*w, c) -> (b, c, h, w), tokens ordered row-major over (h, w).
        x = x.reshape(batch, side, side, channels).permute(0, 3, 1, 2)

    num_tiles = num_split_h * num_split_w
    if batch % num_tiles != 0:
        raise ValueError(
            f"Batch size {batch} is not divisible by "
            f"num_split_h * num_split_w = {num_tiles}"
        )
    b = batch // num_tiles

    # Concatenate each tile row along width, then stack the rows along height.
    rows = [
        torch.cat(
            [
                x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b]
                for j in range(num_split_w)
            ],
            dim=-1,
        )
        for i in range(num_split_h)
    ]
    return torch.cat(rows, dim=-2)
257
+
258
+
def merge_features_for_dynamic_s2(
    image_features, block_sizes, *, scales, resize_output_to_scale_idx
):
    """Reassemble dynamic-S2 tile features into one multi-scale map per image.

    ``image_features`` holds the tile features of every image concatenated along
    dim 0. ``block_sizes`` has one entry per image:

    * ``None``: the image used a single tile. Its tokens are reshaped to a
      square map and repeated ``len(scales)`` times along the channel axis.
    * ``(rows, cols)``: the image was tiled at every scale. For each scale but
      the last, ``(scale // scales[0]) ** 2`` tiles form a square grid; the last
      scale uses ``rows * cols`` tiles. Each scale is stitched back together,
      area-resized (in float32) to the spatial size of scale
      ``resize_output_to_scale_idx``, and the scales are concatenated along
      channels.

    Returns:
        ``(features_per_image, new_block_sizes)``. Each feature map has batch
        size 1. ``new_block_sizes`` is the tile grid matching each output map.

    Raises:
        AssertionError: if the consumed tiles do not account for every entry of
            ``image_features``.
    """
    merged_per_image = []
    grid_per_image = []
    cursor = 0

    for grid in block_sizes:
        if grid is None:
            # Single-tile image: reuse the same map for every scale slot.
            tile = image_features[cursor : cursor + 1]
            tile = einops.rearrange(
                tile,
                "1 (h w) c -> 1 c h w",
                h=math.isqrt(tile.shape[1]),
            )
            merged_per_image.append(tile.repeat(1, len(scales), 1, 1))
            grid_per_image.append((1, 1))
            cursor += 1
            continue

        scale_maps = []
        # Intermediate scales are tiled on square grids relative to scales[0].
        for scale in scales[:-1]:
            side = scale // scales[0]
            count = side**2
            scale_maps.append(
                merge_chessboard(
                    image_features[cursor : cursor + count],
                    num_split_h=side,
                    num_split_w=side,
                )
            )
            cursor += count

        # The last scale uses this image's own (rows, cols) tile grid.
        count = grid[0] * grid[1]
        scale_maps.append(
            merge_chessboard(
                image_features[cursor : cursor + count],
                num_split_h=grid[0],
                num_split_w=grid[1],
            )
        )
        cursor += count

        # Bring every scale to the target spatial size, then join on channels.
        target_hw = scale_maps[resize_output_to_scale_idx].shape[-2:]
        resized = []
        for fmap in scale_maps:
            fmap_resized = F.interpolate(
                fmap.to(torch.float32), size=target_hw, mode="area"
            )
            resized.append(fmap_resized.to(fmap.dtype))
        merged_per_image.append(torch.cat(resized, dim=1))

        if resize_output_to_scale_idx in (len(scales) - 1, -1):
            grid_per_image.append(grid)
        else:
            side = scales[resize_output_to_scale_idx] // scales[0]
            grid_per_image.append((side, side))

    assert cursor == len(
        image_features
    ), f"The number of blocks ({cursor}) does not match length of image_features ({len(image_features)})!"

    return merged_per_image, grid_per_image
333
+
334
+
def split_chessboard(x, num_split_h, num_split_w):
    """Cut each feature map into a grid of tiles stacked along the batch axis.

    Args:
        x: Tensor of shape ``(B, C, H, W)``.
        num_split_h: Number of tile rows; must divide ``H``.
        num_split_w: Number of tile columns; must divide ``W``.

    Returns:
        Tensor of shape
        ``(num_split_h * num_split_w * B, C, H // num_split_h, W // num_split_w)``.
        Tiles are emitted in row-major tile order, each tile contributing a
        contiguous run of ``B`` entries. ``merge_chessboard`` with the same
        split counts reverses this.

    Raises:
        ValueError: if ``x`` is not 4-D, or ``H`` / ``W`` are not divisible by
            the split counts.
    """
    if x.dim() != 4:
        raise ValueError(
            f"Expected a 4-D (B, C, H, W) tensor, got shape {tuple(x.shape)}"
        )
    height, width = x.shape[-2:]
    if height % num_split_h != 0 or width % num_split_w != 0:
        raise ValueError(
            f"Spatial size ({height}, {width}) is not divisible by "
            f"splits ({num_split_h}, {num_split_w})"
        )
    tile_h, tile_w = height // num_split_h, width // num_split_w

    tiles = [
        x[:, :, i * tile_h : (i + 1) * tile_h, j * tile_w : (j + 1) * tile_w]
        for i in range(num_split_h)
        for j in range(num_split_w)
    ]
    return torch.cat(tiles, dim=0)
353
+
354
+
# Model classes exported by this module.
# NOTE(review): presumably collected by sglang's model registry to map
# checkpoint architectures to implementations — confirm.
EntryClass = [NVILAForConditionalGeneration]
@@ -0,0 +1,184 @@
1
+ import math
2
+ from collections.abc import Iterable
3
+ from typing import Any
4
+
5
+ import einops
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
12
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
13
+ from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
14
+
15
+ import sglang.srt.managers.mm_utils as mm_utils
16
+ import sglang.srt.model_loader.weight_utils as weight_utils
17
+ import sglang.srt.utils as utils
18
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
19
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
20
+ from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
21
+ from sglang.srt.managers.schedule_batch import (
22
+ Modality,
23
+ MultimodalDataItem,
24
+ MultimodalInputs,
25
+ )
26
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
27
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
28
+
29
+ MM_HIDDEN_SIZE = 1152
30
+
31
+
32
+ class NVILALiteConfig(PretrainedConfig):
33
+ model_type = "nvila_lite"
34
+ sub_configs = {
35
+ "text_config": Qwen2Config,
36
+ "vision_config": SiglipVisionConfig,
37
+ }
38
+ _auto_class = "AutoConfig"
39
+
40
+ def __init__(
41
+ self,
42
+ *,
43
+ text_config: dict[str, Any] | None = None,
44
+ vision_config: dict[str, Any] | None = None,
45
+ image_token_id: int | None = None,
46
+ video_token_id: int | None = None,
47
+ **kwargs,
48
+ ):
49
+ self.text_config = (
50
+ Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
51
+ )
52
+ self.vision_config = (
53
+ SiglipVisionConfig(**vision_config)
54
+ if vision_config is not None
55
+ else SiglipVisionConfig()
56
+ )
57
+
58
+ self.image_token_id = image_token_id if image_token_id is not None else -1
59
+ self.video_token_id = video_token_id if video_token_id is not None else -1
60
+
61
+ super().__init__(**kwargs)
62
+
63
+
64
+ class NVILALiteMultiModalProjectorDownsampleBlock(nn.Module):
65
+ def forward(self, x: Tensor) -> Tensor:
66
+ batch_size, sequence_length, hidden_size = x.shape
67
+
68
+ feat_size = math.isqrt(sequence_length)
69
+
70
+ features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
71
+
72
+ pad_after = (3 - feat_size % 3) % 3
73
+ if pad_after > 0:
74
+ features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
75
+ feat_size = feat_size + pad_after
76
+
77
+ features = features.reshape(
78
+ batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
79
+ )
80
+ features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
81
+ features = features.reshape(batch_size, -1, 9 * hidden_size)
82
+
83
+ return features
84
+
85
+
86
+ class NVILALiteMultiModalProjector(nn.Module):
87
+ def __init__(self, config: NVILALiteConfig):
88
+ super().__init__()
89
+
90
+ self.layers = nn.Sequential(
91
+ NVILALiteMultiModalProjectorDownsampleBlock(),
92
+ nn.LayerNorm(MM_HIDDEN_SIZE * 9),
93
+ nn.Linear(MM_HIDDEN_SIZE * 9, MM_HIDDEN_SIZE * 3),
94
+ nn.GELU(),
95
+ nn.LayerNorm(MM_HIDDEN_SIZE * 3),
96
+ nn.Linear(MM_HIDDEN_SIZE * 3, config.text_config.hidden_size),
97
+ nn.GELU(),
98
+ nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
99
+ )
100
+
101
+ def forward(self, x: Tensor) -> Tensor:
102
+ return self.layers(x)
103
+
104
+
105
+ class NVILALiteForConditionalGeneration(nn.Module):
106
+ def __init__(
107
+ self,
108
+ config: NVILALiteConfig,
109
+ quant_config: QuantizationConfig | None = None,
110
+ prefix: str = "",
111
+ ) -> None:
112
+ super().__init__()
113
+
114
+ self.config = config
115
+
116
+ self.vision_tower = SiglipVisionModel(config.vision_config)
117
+ self.mm_projector = NVILALiteMultiModalProjector(config)
118
+ self.llm = Qwen2ForCausalLM(
119
+ config=config.text_config,
120
+ quant_config=quant_config,
121
+ prefix=utils.add_prefix("llm", prefix),
122
+ )
123
+
124
+ def forward(
125
+ self,
126
+ input_ids: Tensor,
127
+ positions: Tensor,
128
+ forward_batch: ForwardBatch,
129
+ get_embedding: bool = False,
130
+ ) -> LogitsProcessorOutput:
131
+ output = mm_utils.general_mm_embed_routine(
132
+ input_ids=input_ids,
133
+ forward_batch=forward_batch,
134
+ language_model=self.llm,
135
+ data_embedding_funcs={
136
+ Modality.IMAGE: self.get_image_feature,
137
+ Modality.VIDEO: self.get_image_feature,
138
+ },
139
+ get_embedding=get_embedding,
140
+ positions=positions,
141
+ )
142
+
143
+ assert isinstance(output, LogitsProcessorOutput)
144
+
145
+ return output
146
+
147
+ def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
148
+ pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
149
+
150
+ vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
151
+ pixel_values,
152
+ output_hidden_states=True,
153
+ )
154
+ assert vision_tower_output.hidden_states is not None
155
+
156
+ vision_features = vision_tower_output.hidden_states[-2]
157
+
158
+ vision_features = self.mm_projector(vision_features)
159
+
160
+ vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
161
+
162
+ return vision_features
163
+
164
+ def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
165
+ params_dict = dict(self.named_parameters())
166
+
167
+ for name, loaded_weight in weights:
168
+ if name.startswith("llm."):
169
+ self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
170
+ else:
171
+ param = params_dict[name]
172
+ weight_loader = getattr(
173
+ param, "weight_loader", weight_utils.default_weight_loader
174
+ )
175
+ weight_loader(param, loaded_weight)
176
+
177
+ def pad_input_ids(
178
+ self, input_ids: list[int], mm_inputs: MultimodalInputs
179
+ ) -> list[int]:
180
+ pattern = MultiModalityDataPaddingPatternMultimodalTokens()
181
+ return pattern.pad_input_tokens(input_ids, mm_inputs)
182
+
183
+
184
+ EntryClass = [NVILALiteForConditionalGeneration]
@@ -48,6 +48,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
48
48
  from sglang.srt.utils import add_prefix, make_layers
49
49
 
50
50
 
51
+ # Aligned with HF's implementation, using sliding window inclusive with the last token
52
+ # SGLang assumes exclusive
53
+ def get_attention_sliding_window_size(config):
54
+ return config.sliding_window - 1 if hasattr(config, "sliding_window") else None
55
+
56
+
51
57
  class Olmo2Attention(nn.Module):
52
58
  """
53
59
  This is the attention block where the output is computed as
@@ -85,6 +91,8 @@ class Olmo2Attention(nn.Module):
85
91
  self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
86
92
 
87
93
  self.head_dim = self.hidden_size // self.total_num_heads
94
+ self.q_size = self.num_heads * self.head_dim
95
+ self.kv_size = self.num_kv_heads * self.head_dim
88
96
  self.max_position_embeddings = config.max_position_embeddings
89
97
  self.rope_theta = config.rope_theta
90
98
 
@@ -104,12 +112,26 @@ class Olmo2Attention(nn.Module):
104
112
  eps=self.config.rms_norm_eps,
105
113
  )
106
114
  self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
107
- # Rotary embeddings.
115
+
116
+ sliding_window = None
117
+ if (
118
+ layer_types := getattr(self.config, "layer_types", None)
119
+ ) is not None and layer_types[layer_id] == "sliding_attention":
120
+ sliding_window = get_attention_sliding_window_size(self.config)
121
+
122
+ # Rotary embeddings. Rope scaling is only applied on full attention
123
+ # layers.
124
+ self.rope_scaling = (
125
+ self.config.rope_scaling
126
+ if sliding_window is None
127
+ else {"rope_type": "default"}
128
+ )
108
129
  self.rotary_emb = get_rope(
109
130
  self.head_dim,
110
131
  rotary_dim=self.head_dim,
111
132
  max_position=self.max_position_embeddings,
112
133
  base=self.rope_theta,
134
+ rope_scaling=self.rope_scaling,
113
135
  )
114
136
  self.scaling = self.head_dim**-0.5
115
137
  self.attn = RadixAttention(
@@ -118,6 +140,7 @@ class Olmo2Attention(nn.Module):
118
140
  self.scaling,
119
141
  num_kv_heads=self.num_kv_heads,
120
142
  layer_id=layer_id,
143
+ sliding_window_size=sliding_window,
121
144
  quant_config=quant_config,
122
145
  prefix=add_prefix("attn", prefix),
123
146
  )
@@ -152,7 +175,7 @@ class Olmo2Attention(nn.Module):
152
175
  forward_batch: ForwardBatch,
153
176
  ) -> torch.Tensor:
154
177
  qkv, _ = self.qkv_proj(hidden_states)
155
- q, k, v = qkv.chunk(chunks=3, dim=-1)
178
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
156
179
  q, k = self._apply_qk_norm(q, k)
157
180
  q, k = self.rotary_emb(positions, q, k)
158
181
  attn_output = self.attn(q, k, v, forward_batch)
@@ -224,6 +247,7 @@ class Olmo2DecoderLayer(nn.Module):
224
247
  prefix: str = "",
225
248
  ):
226
249
  super().__init__()
250
+ self.layer_id = layer_id
227
251
  # Attention block.
228
252
  self.self_attn = Olmo2Attention(
229
253
  config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
@@ -280,8 +304,8 @@ class Olmo2Model(nn.Module):
280
304
  self.layers = make_layers(
281
305
  config.num_hidden_layers,
282
306
  lambda idx, prefix: Olmo2DecoderLayer(
283
- layer_id=idx,
284
307
  config=config,
308
+ layer_id=idx,
285
309
  quant_config=quant_config,
286
310
  prefix=prefix,
287
311
  ),
@@ -294,7 +318,7 @@ class Olmo2Model(nn.Module):
294
318
  input_ids: torch.Tensor,
295
319
  positions: torch.Tensor,
296
320
  forward_batch: ForwardBatch,
297
- input_embeds: torch.Tensor = None,
321
+ input_embeds: Optional[torch.Tensor] = None,
298
322
  ) -> torch.Tensor:
299
323
  """
300
324
  :param input_ids: A tensor of shape `(batch_size, seq_len)`.
@@ -351,6 +375,9 @@ class Olmo2ForCausalLM(nn.Module):
351
375
  )
352
376
  self.logits_processor = LogitsProcessor(config)
353
377
 
378
+ def get_attention_sliding_window_size(self):
379
+ return get_attention_sliding_window_size(self.config)
380
+
354
381
  def forward(
355
382
  self,
356
383
  input_ids: torch.Tensor,
sglang/srt/models/opt.py CHANGED
@@ -13,11 +13,11 @@
13
13
  # ==============================================================================
14
14
 
15
15
  """Inference-only OPT model compatible with HuggingFace weights."""
16
+ import logging
16
17
  from collections.abc import Iterable
17
18
  from typing import Optional, Union
18
19
 
19
20
  import torch
20
- import torch.nn.functional as F
21
21
  from torch import nn
22
22
  from transformers import OPTConfig
23
23
 
@@ -26,10 +26,8 @@ from sglang.srt.distributed import (
26
26
  get_tensor_model_parallel_rank,
27
27
  get_tensor_model_parallel_world_size,
28
28
  )
29
- from sglang.srt.layers.activation import get_act_fn
30
29
  from sglang.srt.layers.linear import (
31
30
  ColumnParallelLinear,
32
- MergedColumnParallelLinear,
33
31
  QKVParallelLinear,
34
32
  ReplicatedLinear,
35
33
  RowParallelLinear,
@@ -38,7 +36,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
38
36
  from sglang.srt.layers.pooler import Pooler, PoolingType
39
37
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
40
38
  from sglang.srt.layers.radix_attention import RadixAttention
41
- from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
39
+ from sglang.srt.layers.utils import get_layer_id
42
40
  from sglang.srt.layers.vocab_parallel_embedding import (
43
41
  ParallelLMHead,
44
42
  VocabParallelEmbedding,
@@ -47,9 +45,11 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
47
45
  from sglang.srt.model_loader.weight_utils import (
48
46
  default_weight_loader,
49
47
  kv_cache_scales_loader,
50
- maybe_remap_kv_scale_name,
51
48
  )
52
49
  from sglang.srt.utils import add_prefix, make_layers
50
+ from sglang.utils import get_exception_traceback
51
+
52
+ logger = logging.getLogger(__name__)
53
53
 
54
54
 
55
55
  def get_activation(name="relu"):
sglang/srt/models/phi.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi.py
2
- from typing import Iterable, Optional, Union
2
+ from typing import Iterable, Optional
3
3
 
4
4
  import torch
5
5
  from torch import nn
@@ -24,7 +24,7 @@ from typing import List, Optional, Tuple
24
24
  import numpy as np
25
25
  import torch
26
26
  from torch import nn
27
- from transformers import PretrainedConfig, SiglipVisionConfig
27
+ from transformers import PretrainedConfig
28
28
 
29
29
  from sglang.srt.layers.quantization import QuantizationConfig
30
30
  from sglang.srt.managers.mm_utils import (