sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the contents of publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/nemotron_h.py ADDED
@@ -0,0 +1,511 @@
+ # Copyright 2023-2025 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
+
+ """Inference-only NemotronH model."""
+
+ from collections.abc import Iterable
+ from typing import Optional, Union
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.configs import NemotronHConfig
+ from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
+ from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+ from sglang.srt.layers.activation import ReLU2
+ from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+     HybridLinearAttnBackend,
+     Mamba2AttnBackend,
+ )
+ from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     DEFAULT_VOCAB_PADDING_SIZE,
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     maybe_remap_kv_scale_name,
+     replace_prefix,
+     replace_substrings,
+ )
+ from sglang.srt.utils import add_prefix, make_layers_non_pp
+ from sglang.utils import logger
+
+
+ class NemotronHMLP(nn.Module):
+     def __init__(
+         self,
+         config: NemotronHConfig,
+         layer_idx: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         bias: bool = False,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         hybrid_override_pattern = config.hybrid_override_pattern
+         mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
+         if isinstance(config.intermediate_size, list):
+             if len(config.intermediate_size) == 1:
+                 intermediate_size = config.intermediate_size[0]
+             else:
+                 intermediate_size = config.intermediate_size[mlp_index]
+         else:
+             intermediate_size = config.intermediate_size
+
+         self.up_proj = ColumnParallelLinear(
+             input_size=config.hidden_size,
+             output_size=intermediate_size,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=f"{prefix}.up_proj",
+         )
+         self.down_proj = RowParallelLinear(
+             input_size=intermediate_size,
+             output_size=config.hidden_size,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=f"{prefix}.down_proj",
+         )
+         self.act_fn = ReLU2()
+
+     def forward(self, x: torch.Tensor):
+         x, _ = self.up_proj(x)
+         x = self.act_fn(x)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class NemotronHMLPDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: NemotronHConfig,
+         layer_idx: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+
+         self.mixer = NemotronHMLP(
+             config,
+             quant_config=quant_config,
+             bias=config.mlp_bias,
+             prefix=f"{prefix}.mixer",
+             layer_idx=layer_idx,
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         *,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         forward_batch: ForwardBatch,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.norm(hidden_states)
+         else:
+             hidden_states, residual = self.norm(hidden_states, residual)
+
+         hidden_states = self.mixer.forward(hidden_states)
+         return hidden_states, residual
+
+
+ class NemotronHMambaDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: NemotronHConfig,
+         layer_idx: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.layer_id = layer_idx
+         self.mixer = MambaMixer2(
+             cache_params=config.mamba2_cache_params,
+             hidden_size=config.hidden_size,
+             use_conv_bias=config.use_conv_bias,
+             use_bias=config.use_bias,
+             n_groups=config.mamba_n_groups,
+             rms_norm_eps=config.rms_norm_eps,
+             activation=config.mamba_hidden_act,
+             quant_config=quant_config,
+             prefix=f"{prefix}.mixer",
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         *,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         forward_batch: ForwardBatch,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.norm(hidden_states)
+         else:
+             hidden_states, residual = self.norm(hidden_states, residual)
+
+         output = torch.empty_like(hidden_states)
+         attn_backend = forward_batch.attn_backend
+         assert isinstance(attn_backend, HybridLinearAttnBackend)
+         assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
+         attn_backend.linear_attn_backend.forward(
+             mixer=self.mixer,
+             layer_id=self.layer_id,
+             hidden_states=hidden_states,
+             output=output,
+             use_triton_causal_conv=True,  # TODO: investigate need of `use_triton_causal_conv`
+         )
+         return output, residual
+
+
+ class NemotronHAttention(nn.Module):
+     def __init__(
+         self,
+         config: NemotronHConfig,
+         layer_idx: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = config.num_key_value_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         if hasattr(config, "head_dim") and config.head_dim is not None:
+             self.head_dim = config.head_dim
+         else:
+             self.head_dim = config.hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+
+         self.qkv_proj = QKVParallelLinear(
+             config.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.qkv_proj",
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             config.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=f"{prefix}.o_proj",
+         )
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_idx,
+             quant_config=quant_config,
+             prefix=add_prefix("attn", prefix),
+         )
+
+     def forward(
+         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         attn_output = self.attn.forward(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class NemotronHAttentionDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: NemotronHConfig,
+         layer_idx: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.mixer = NemotronHAttention(
+             config,
+             layer_idx,
+             quant_config,
+             prefix=f"{prefix}.mixer",
+         )
+
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         *,
+         hidden_states: torch.Tensor,
+         residual: Optional[torch.Tensor],
+         forward_batch: ForwardBatch,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.norm(hidden_states)
+         else:
+             hidden_states, residual = self.norm(hidden_states, residual)
+
+         hidden_states = self.mixer.forward(
+             hidden_states=hidden_states, forward_batch=forward_batch
+         )
+         return hidden_states, residual
+
+
+ Layers = (
+     NemotronHAttentionDecoderLayer
+     | NemotronHMLPDecoderLayer
+     | NemotronHMambaDecoderLayer
+ )
+ ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
+     ATTENTION: NemotronHAttentionDecoderLayer,
+     MLP: NemotronHMLPDecoderLayer,
+     MAMBA: NemotronHMambaDecoderLayer,
+ }
+
+
+ class NemotronHModel(nn.Module):
+     def __init__(
+         self,
+         *,
+         config: NemotronHConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+
+         lora_config = None
+         self.config = config
+         lora_vocab = (
+             (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+             if lora_config
+             else 0
+         )
+         self.vocab_size = config.vocab_size + lora_vocab
+         self.org_vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             self.vocab_size,
+             config.hidden_size,
+             org_num_embeddings=config.vocab_size,
+         )
+
+         def get_layer(idx: int, prefix: str):
+             layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
+             return layer_class(config, idx, quant_config=quant_config, prefix=prefix)
+
+         self.layers = make_layers_non_pp(
+             len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
+         )
+         self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.embed_tokens(input_ids)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, PPProxyTensors]:
+         if get_pp_group().is_first_rank:
+             if inputs_embeds is not None:
+                 hidden_states = inputs_embeds
+             else:
+                 hidden_states = self.get_input_embeddings(input_ids)
+             residual = None
+         else:
+             assert pp_proxy_tensors is not None
+             hidden_states = pp_proxy_tensors["hidden_states"]
+             residual = pp_proxy_tensors["residual"]
+
+         residual = None
+         for layer in self.layers:
+             if not isinstance(layer, Layers):
+                 raise ValueError(f"Unknown layer type: {type(layer)}")
+             hidden_states, residual = layer.forward(
+                 hidden_states=hidden_states,
+                 residual=residual,
+                 forward_batch=forward_batch,
+             )
+
+         if not get_pp_group().is_last_rank:
+             return PPProxyTensors(
+                 {"hidden_states": hidden_states, "residual": residual}
+             )
+         hidden_states, _ = self.norm_f(hidden_states, residual)
+         return hidden_states
+
+
+ class NemotronHForCausalLM(nn.Module):
+     stacked_params_mapping = [
+         # (param_name, shard_name, shard_id)
+         ("qkv_proj", "q_proj", "q"),
+         ("qkv_proj", "k_proj", "k"),
+         ("qkv_proj", "v_proj", "v"),
+     ]
+     packed_modules_mapping = {
+         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+     }
+
+     remap_prefix = {"backbone": "model"}
+     remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
+
+     def __init__(
+         self,
+         *,
+         config: NemotronHConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         lora_config = None
+         self.config = config
+         self.model = self._init_model(
+             config=config, quant_config=quant_config, prefix=prefix
+         )
+         if self.config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.unpadded_vocab_size = config.vocab_size
+             if lora_config:
+                 self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+             self.lm_head = ParallelLMHead(
+                 self.unpadded_vocab_size,
+                 config.hidden_size,
+                 org_num_embeddings=config.vocab_size,
+                 padding_size=(
+                     DEFAULT_VOCAB_PADDING_SIZE
+                     # We need bigger padding if using lora for kernel
+                     # compatibility
+                     if not lora_config
+                     else lora_config.lora_vocab_padding_size
+                 ),
+                 quant_config=quant_config,
+                 prefix=add_prefix("lm_head", prefix),
+             )
+         self.logits_processor = LogitsProcessor(config)
+
+     def _init_model(
+         self,
+         config: NemotronHConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         return NemotronHModel(
+             config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+         )
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: Optional[torch.Tensor] = None,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+     ):
+         hidden_states = self.model.forward(
+             input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
+         )
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+         return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
+
+     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
+         updated_weights = []
+         for name, loaded_weight in weights:
+             name = replace_prefix(name, self.remap_prefix)
+             name = replace_substrings(name, self.remap_substr)
+             updated_weights.append((name, loaded_weight))
+         params_dict = dict(self.named_parameters())
+
+         for name, loaded_weight in updated_weights:
+             if "scale" in name:
+                 name = maybe_remap_kv_scale_name(name, params_dict)
+                 if name is None:
+                     continue
+
+             for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name in params_dict.keys():
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+                 else:
+                     logger.warning(f"Parameter {name} not found in params_dict")
+
+
+ EntryClass = [NemotronHForCausalLM]
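
For orientation, a minimal sketch of how the new NemotronHModel maps config.hybrid_override_pattern to decoder-layer classes and how NemotronHMLP indexes into a per-block intermediate_size list. It assumes the single-character codes "*" (attention), "-" (MLP) and "M" (Mamba) behind the ATTENTION / MLP / MAMBA constants; the pattern string itself is hypothetical.

# Sketch only, not part of the diff; pattern characters and the example pattern are assumptions.
ALL_DECODER_LAYER_TYPES = {"*": "attention", "-": "mlp", "M": "mamba"}

hybrid_override_pattern = "M-M*-"  # hypothetical pattern
layer_kinds = [ALL_DECODER_LAYER_TYPES[c] for c in hybrid_override_pattern]
# -> ['mamba', 'mlp', 'mamba', 'attention', 'mlp']

# NemotronHMLP counts how many MLP blocks ("-") occur up to and including its index:
layer_idx = 4
mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1  # -> 1 (second MLP block)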
sglang/srt/models/olmo2.py CHANGED
@@ -48,6 +48,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import add_prefix, make_layers


+ # Aligned with HF's implementation, using sliding window inclusive with the last token
+ # SGLang assumes exclusive
+ def get_attention_sliding_window_size(config):
+     return config.sliding_window - 1 if hasattr(config, "sliding_window") else None
+
+
  class Olmo2Attention(nn.Module):
      """
      This is the attention block where the output is computed as
@@ -85,6 +91,8 @@ class Olmo2Attention(nn.Module):
          self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)

          self.head_dim = self.hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
          self.max_position_embeddings = config.max_position_embeddings
          self.rope_theta = config.rope_theta

@@ -104,12 +112,26 @@ class Olmo2Attention(nn.Module):
              eps=self.config.rms_norm_eps,
          )
          self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
-         # Rotary embeddings.
+
+         sliding_window = None
+         if (
+             layer_types := getattr(self.config, "layer_types", None)
+         ) is not None and layer_types[layer_id] == "sliding_attention":
+             sliding_window = get_attention_sliding_window_size(self.config)
+
+         # Rotary embeddings. Rope scaling is only applied on full attention
+         # layers.
+         self.rope_scaling = (
+             self.config.rope_scaling
+             if sliding_window is None
+             else {"rope_type": "default"}
+         )
          self.rotary_emb = get_rope(
              self.head_dim,
              rotary_dim=self.head_dim,
              max_position=self.max_position_embeddings,
              base=self.rope_theta,
+             rope_scaling=self.rope_scaling,
          )
          self.scaling = self.head_dim**-0.5
          self.attn = RadixAttention(
@@ -118,6 +140,7 @@ class Olmo2Attention(nn.Module):
              self.scaling,
              num_kv_heads=self.num_kv_heads,
              layer_id=layer_id,
+             sliding_window_size=sliding_window,
              quant_config=quant_config,
              prefix=add_prefix("attn", prefix),
          )
@@ -152,7 +175,7 @@ class Olmo2Attention(nn.Module):
          forward_batch: ForwardBatch,
      ) -> torch.Tensor:
          qkv, _ = self.qkv_proj(hidden_states)
-         q, k, v = qkv.chunk(chunks=3, dim=-1)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
          q, k = self._apply_qk_norm(q, k)
          q, k = self.rotary_emb(positions, q, k)
          attn_output = self.attn(q, k, v, forward_batch)
@@ -224,6 +247,7 @@ class Olmo2DecoderLayer(nn.Module):
          prefix: str = "",
      ):
          super().__init__()
+         self.layer_id = layer_id
          # Attention block.
          self.self_attn = Olmo2Attention(
              config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
@@ -280,8 +304,8 @@ class Olmo2Model(nn.Module):
          self.layers = make_layers(
              config.num_hidden_layers,
              lambda idx, prefix: Olmo2DecoderLayer(
-                 layer_id=idx,
                  config=config,
+                 layer_id=idx,
                  quant_config=quant_config,
                  prefix=prefix,
              ),
@@ -294,7 +318,7 @@ class Olmo2Model(nn.Module):
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          forward_batch: ForwardBatch,
-         input_embeds: torch.Tensor = None,
+         input_embeds: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """
          :param input_ids: A tensor of shape `(batch_size, seq_len)`.
@@ -351,6 +375,9 @@ class Olmo2ForCausalLM(nn.Module):
          )
          self.logits_processor = LogitsProcessor(config)

+     def get_attention_sliding_window_size(self):
+         return get_attention_sliding_window_size(self.config)
+
      def forward(
          self,
          input_ids: torch.Tensor,
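
The get_attention_sliding_window_size helper added above encodes a convention difference: HF configs treat sliding_window as inclusive of the current token, while SGLang's sliding_window_size is exclusive, hence the minus one. A minimal sketch of that conversion (the config value below is hypothetical):

# Sketch only, not part of the diff; the sliding_window value is an assumed HF-style setting.
class _Cfg:
    sliding_window = 4096  # HF convention: window includes the current token

def get_attention_sliding_window_size(config):
    return config.sliding_window - 1 if hasattr(config, "sliding_window") else None

assert get_attention_sliding_window_size(_Cfg()) == 4095  # SGLang's exclusive window size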
sglang/srt/models/opt.py CHANGED
@@ -13,11 +13,11 @@
  # ==============================================================================

  """Inference-only OPT model compatible with HuggingFace weights."""
+ import logging
  from collections.abc import Iterable
  from typing import Optional, Union

  import torch
- import torch.nn.functional as F
  from torch import nn
  from transformers import OPTConfig

@@ -26,10 +26,8 @@ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
  )
- from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
-     MergedColumnParallelLinear,
      QKVParallelLinear,
      ReplicatedLinear,
      RowParallelLinear,
@@ -38,7 +36,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+ from sglang.srt.layers.utils import get_layer_id
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -47,9 +45,11 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
  from sglang.srt.model_loader.weight_utils import (
      default_weight_loader,
      kv_cache_scales_loader,
-     maybe_remap_kv_scale_name,
  )
  from sglang.srt.utils import add_prefix, make_layers
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger(__name__)


  def get_activation(name="relu"):
sglang/srt/models/phi.py CHANGED
@@ -1,5 +1,5 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi.py
- from typing import Iterable, Optional, Union
+ from typing import Iterable, Optional

  import torch
  from torch import nn
sglang/srt/models/phi4mm.py CHANGED
@@ -24,7 +24,7 @@ from typing import List, Optional, Tuple
  import numpy as np
  import torch
  from torch import nn
- from transformers import PretrainedConfig, SiglipVisionConfig
+ from transformers import PretrainedConfig

  from sglang.srt.layers.quantization import QuantizationConfig
  from sglang.srt.managers.mm_utils import (
sglang/srt/models/phimoe.py CHANGED
@@ -18,7 +18,6 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.rotary_embedding import get_rope
- from sglang.srt.layers.utils import PPMissingLayer
  from sglang.srt.layers.vocab_parallel_embedding import (
      DEFAULT_VOCAB_PADDING_SIZE,
      ParallelLMHead,
sglang/srt/models/pixtral.py CHANGED
@@ -16,13 +16,10 @@
  Using mistral-community/pixtral-12b as reference.
  """

- import logging
- import math
  from typing import Iterable, List, Optional, Set, Tuple, Union

  import torch
  import torch.nn as nn
- import torch.nn.functional as F
  from transformers import PixtralVisionConfig, PretrainedConfig
  from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
  from transformers.models.pixtral.modeling_pixtral import (