sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import MixtralConfig
 
 from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
@@ -36,7 +35,6 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -94,8 +92,7 @@ class MixtralMoE(nn.Module):
             renormalize=True,
         )
 
-        MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
-        self.experts = MoEImpl(
+        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            layer_id=layer_id,
sglang/srt/models/mllama.py
@@ -901,7 +901,7 @@ class MllamaForConditionalGeneration(nn.Module):
                 img = pixel_values[0, j]
                 num_tiles = img.shape[0]
                 batched_images[i, j, :num_tiles] = img
-                batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
+                batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_ids[0, j]
 
                 batched_ar_mask[i, j, :num_tiles] = mm_input.mm_items[
                     0
sglang/srt/models/mllama4.py
@@ -2,6 +2,7 @@ import json as json_lib
 import logging
 import math
 import os
+import re
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
 
@@ -30,9 +31,9 @@ from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
     MultimodalInputs,
-    global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cpu
 
 _is_cpu = is_cpu()
@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
         "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
+    # Pattern to match language model layers only (skip vision_model and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
     def __init__(
         self,
         config: Llama4Config,
@@ -442,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         self.has_vision = (
-            self.has_vision_weights and global_server_args_dict["enable_multimodal"]
+            self.has_vision_weights and get_global_server_args().enable_multimodal
         )
 
         if self.has_vision:
@@ -555,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         return projected_vision_flat
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision model and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -700,7 +710,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         """Handle scale parameter remapping. Returns True if handled."""
         if "scale" in name and "expert" not in name:
             remapped_name = maybe_remap_kv_scale_name(name, params_dict)
-            return remapped_name is not None and remapped_name != name
+            return remapped_name != name
         return False
 
     def _handle_stacked_params(
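
The two LoRA-related hunks above (the class-level lora_pattern and should_apply_lora) restrict LoRA application to projection modules of the language model; anything under the vision tower or the multi-modal projector falls through to False. A minimal standalone sketch of how that filter behaves, using made-up module names purely for illustration:

import re

# Same pattern as the Llama4ForConditionalGeneration hunk above.
lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)"
    r"\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

def should_apply_lora(module_name: str) -> bool:
    # Mirrors the new method: only language-model projection modules match.
    return bool(lora_pattern.match(module_name))

# Hypothetical module names, not taken from a real checkpoint:
print(should_apply_lora("language_model.model.layers.3.self_attn.qkv_proj"))  # True
print(should_apply_lora("vision_model.patch_embedding.linear"))               # False
print(should_apply_lora("multi_modal_projector.linear_1"))                    # False
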
sglang/srt/models/nemotron_h.py (new file)
@@ -0,0 +1,511 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
+
+"""Inference-only NemotronH model."""
+
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from sglang.srt.configs import NemotronHConfig
+from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import ReLU2
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+    replace_prefix,
+    replace_substrings,
+)
+from sglang.srt.utils import add_prefix, make_layers_non_pp
+from sglang.utils import logger
+
+
+class NemotronHMLP(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        hybrid_override_pattern = config.hybrid_override_pattern
+        mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
+        if isinstance(config.intermediate_size, list):
+            if len(config.intermediate_size) == 1:
+                intermediate_size = config.intermediate_size[0]
+            else:
+                intermediate_size = config.intermediate_size[mlp_index]
+        else:
+            intermediate_size = config.intermediate_size
+
+        self.up_proj = ColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_size=intermediate_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=config.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        self.act_fn = ReLU2()
+
+    def forward(self, x: torch.Tensor):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class NemotronHMLPDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+            layer_idx=layer_idx,
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer.forward(hidden_states)
+        return hidden_states, residual
+
+
+class NemotronHMambaDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.layer_id = layer_idx
+        self.mixer = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
+            hidden_size=config.hidden_size,
+            use_conv_bias=config.use_conv_bias,
+            use_bias=config.use_bias,
+            n_groups=config.mamba_n_groups,
+            rms_norm_eps=config.rms_norm_eps,
+            activation=config.mamba_hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mixer",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        output = torch.empty_like(hidden_states)
+        attn_backend = forward_batch.attn_backend
+        assert isinstance(attn_backend, HybridLinearAttnBackend)
+        assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
+        attn_backend.linear_attn_backend.forward(
+            mixer=self.mixer,
+            layer_id=self.layer_id,
+            hidden_states=hidden_states,
+            output=output,
+            use_triton_causal_conv=True,  # TODO: investigate need of `use_triton_causal_conv`
+        )
+        return output, residual
+
+
+class NemotronHAttention(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        if hasattr(config, "head_dim") and config.head_dim is not None:
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_idx,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        attn_output = self.attn.forward(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class NemotronHAttentionDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.mixer = NemotronHAttention(
+            config,
+            layer_idx,
+            quant_config,
+            prefix=f"{prefix}.mixer",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer.forward(
+            hidden_states=hidden_states, forward_batch=forward_batch
+        )
+        return hidden_states, residual
+
+
+Layers = (
+    NemotronHAttentionDecoderLayer
+    | NemotronHMLPDecoderLayer
+    | NemotronHMambaDecoderLayer
+)
+ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
+    ATTENTION: NemotronHAttentionDecoderLayer,
+    MLP: NemotronHMLPDecoderLayer,
+    MAMBA: NemotronHMambaDecoderLayer,
+}
+
+
+class NemotronHModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        lora_config = None
+        self.config = config
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
+            return layer_class(config, idx, quant_config=quant_config, prefix=prefix)
+
+        self.layers = make_layers_non_pp(
+            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
+        )
+        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        residual = None
+        for layer in self.layers:
+            if not isinstance(layer, Layers):
+                raise ValueError(f"Unknown layer type: {type(layer)}")
+            hidden_states, residual = layer.forward(
+                hidden_states=hidden_states,
+                residual=residual,
+                forward_batch=forward_batch,
+            )
+
+        if not get_pp_group().is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+        hidden_states, _ = self.norm_f(hidden_states, residual)
+        return hidden_states
+
+
+class NemotronHForCausalLM(nn.Module):
+    stacked_params_mapping = [
+        # (param_name, shard_name, shard_id)
+        ("qkv_proj", "q_proj", "q"),
+        ("qkv_proj", "k_proj", "k"),
+        ("qkv_proj", "v_proj", "v"),
+    ]
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    }
+
+    remap_prefix = {"backbone": "model"}
+    remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
+
+    def __init__(
+        self,
+        *,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        lora_config = None
+        self.config = config
+        self.model = self._init_model(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=(
+                    DEFAULT_VOCAB_PADDING_SIZE
+                    # We need bigger padding if using lora for kernel
+                    # compatibility
+                    if not lora_config
+                    else lora_config.lora_vocab_padding_size
+                ),
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def _init_model(
+        self,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return NemotronHModel(
+            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
+        hidden_states = self.model.forward(
+            input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
+        updated_weights = []
+        for name, loaded_weight in weights:
+            name = replace_prefix(name, self.remap_prefix)
+            name = replace_substrings(name, self.remap_substr)
+            updated_weights.append((name, loaded_weight))
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in updated_weights:
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+EntryClass = [NemotronHForCausalLM]
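
The hybrid stack in NemotronHModel above is driven entirely by config.hybrid_override_pattern: one decoder layer is built per pattern character via ALL_DECODER_LAYER_TYPES, and NemotronHMLP picks its intermediate size by counting how many MLP markers precede it. A minimal sketch of that dispatch logic, assuming "M" marks Mamba, "*" attention, and "-" MLP (an assumption consistent with the count("-") bookkeeping above; the real marker constants live in sglang/srt/configs/nemotron_h.py), with a hypothetical pattern and sizes:

# Hypothetical pattern and per-MLP sizes, for illustration only.
hybrid_override_pattern = "MM-M*-"
intermediate_size = [2048, 4096]

for layer_idx, marker in enumerate(hybrid_override_pattern):
    if marker == "-":
        # Same arithmetic as NemotronHMLP: the i-th "-" seen so far selects
        # intermediate_size[i] when a list of sizes is configured.
        mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
        size = (
            intermediate_size[0]
            if len(intermediate_size) == 1
            else intermediate_size[mlp_index]
        )
        print(f"layer {layer_idx}: MLP block, intermediate_size={size}")
    elif marker == "M":
        print(f"layer {layer_idx}: Mamba2 mixer block")
    else:  # "*"
        print(f"layer {layer_idx}: attention block")
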