sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/nemotron_h.py (new file)
@@ -0,0 +1,514 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
+
+"""Inference-only NemotronH model."""
+
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from sglang.srt.configs import NemotronHConfig
+from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import ReLU2
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers_non_pp
+from sglang.utils import logger
+
+
+class NemotronHMLP(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        hybrid_override_pattern = config.hybrid_override_pattern
+        mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
+        if isinstance(config.intermediate_size, list):
+            if len(config.intermediate_size) == 1:
+                intermediate_size = config.intermediate_size[0]
+            else:
+                intermediate_size = config.intermediate_size[mlp_index]
+        else:
+            intermediate_size = config.intermediate_size
+
+        self.up_proj = ColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_size=intermediate_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=config.hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        self.act_fn = ReLU2()
+
+    def forward(self, x: torch.Tensor):
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class NemotronHMLPDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+            layer_idx=layer_idx,
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer.forward(hidden_states)
+        return hidden_states, residual
+
+
+class NemotronHMambaDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.layer_id = layer_idx
+        self.mixer = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
+            hidden_size=config.hidden_size,
+            use_conv_bias=config.use_conv_bias,
+            use_bias=config.use_bias,
+            n_groups=config.mamba_n_groups,
+            rms_norm_eps=config.rms_norm_eps,
+            activation=config.mamba_hidden_act,
+            quant_config=quant_config,
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        output = torch.empty_like(hidden_states)
+        attn_backend = forward_batch.attn_backend
+        assert isinstance(attn_backend, HybridLinearAttnBackend)
+        assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
+        attn_backend.linear_attn_backend.forward(
+            mixer=self.mixer,
+            layer_id=self.layer_id,
+            hidden_states=hidden_states,
+            output=output,
+            use_triton_causal_conv=True,  # TODO: investigate need of `use_triton_causal_conv`
+        )
+        return output, residual
+
+
+class NemotronHAttention(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        if hasattr(config, "head_dim") and config.head_dim is not None:
+            self.head_dim = config.head_dim
+        else:
+            self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_idx,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        attn_output = self.attn.forward(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class NemotronHAttentionDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: NemotronHConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.mixer = NemotronHAttention(
+            config,
+            layer_idx,
+            quant_config,
+            prefix=f"{prefix}.mixer",
+        )
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        *,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer.forward(
+            hidden_states=hidden_states, forward_batch=forward_batch
+        )
+        return hidden_states, residual
+
+
+Layers = (
+    NemotronHAttentionDecoderLayer
+    | NemotronHMLPDecoderLayer
+    | NemotronHMambaDecoderLayer
+)
+ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
+    ATTENTION: NemotronHAttentionDecoderLayer,
+    MLP: NemotronHMLPDecoderLayer,
+    MAMBA: NemotronHMambaDecoderLayer,
+}
+
+
+class NemotronHModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        lora_config = None
+        self.config = config
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
+            return layer_class(config, idx, quant_config=quant_config, prefix=prefix)
+
+        self.layers = make_layers_non_pp(
+            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
+        )
+        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        residual = None
+        for layer in self.layers:
+            if not isinstance(layer, Layers):
+                raise ValueError(f"Unknown layer type: {type(layer)}")
+            hidden_states, residual = layer.forward(
+                hidden_states=hidden_states,
+                residual=residual,
+                forward_batch=forward_batch,
+            )
+
+        if not get_pp_group().is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+        hidden_states, _ = self.norm_f(hidden_states, residual)
+        return hidden_states
+
+
+class NemotronHForCausalLM(nn.Module):
+    remap_prefix = {"backbone": "model"}
+    remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
+
+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(
+        self,
+        *,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        lora_config = None
+        self.config = config
+        self.model = self._init_model(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=(
+                    DEFAULT_VOCAB_PADDING_SIZE
+                    # We need bigger padding if using lora for kernel
+                    # compatibility
+                    if not lora_config
+                    else lora_config.lora_vocab_padding_size
+                ),
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def _init_model(
+        self,
+        config: NemotronHConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
+        hidden_states = self.model.forward(
+            input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        updated_weights = []
+        for name, loaded_weight in weights:
+            for prefix, new_key in self.remap_prefix.items():
+                if name.startswith(prefix):
+                    name = name.replace(prefix, new_key)
+            for substr, new_key in self.remap_substr.items():
+                if substr in name:
+                    name = name.replace(substr, new_key)
+            updated_weights.append((name, loaded_weight))
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in updated_weights:
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+EntryClass = [NemotronHForCausalLM]
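
For reference, a minimal sketch of exercising the newly registered NemotronH entry class through sglang's offline Engine API. The checkpoint name below is an assumption for illustration; substitute whichever HF-format NemotronH checkpoint you intend to serve.

# Hypothetical usage sketch for the new NemotronH support.
# The model path is an assumed example, not confirmed by this diff.
import sglang as sgl

if __name__ == "__main__":
    # Engine loads the model and dispatches to NemotronHForCausalLM via EntryClass.
    llm = sgl.Engine(model_path="nvidia/Nemotron-H-8B-Base-8K")
    out = llm.generate(
        "The capital of France is",
        {"temperature": 0.0, "max_new_tokens": 16},
    )
    print(out["text"])
    llm.shutdown()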