sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/model_loader/loader.py
+++ b/sglang/srt/model_loader/loader.py
@@ -4,7 +4,6 @@ from __future__ import annotations
 
 # ruff: noqa: SIM117
 import collections
-import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -12,13 +11,11 @@ import json
 import logging
 import math
 import os
-import re
 import socket
 import threading
 import time
 from abc import ABC, abstractmethod
-from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -30,17 +27,28 @@ from typing import (
     Tuple,
     cast,
 )
-from urllib.parse import urlparse
 
 import huggingface_hub
 import numpy as np
-import requests
-import safetensors.torch
 import torch
+
+from sglang.srt.server_args import get_global_server_args
+
+# Try to import accelerate (optional dependency)
+try:
+    from accelerate import infer_auto_device_map, init_empty_weights
+    from accelerate.utils import get_max_memory
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    infer_auto_device_map = None
+    init_empty_weights = None
+    get_max_memory = None
+
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from tqdm.auto import tqdm
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
@@ -54,14 +62,23 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_transferring_weights_request,
+)
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
     post_load_weights,
     set_default_torch_dtype,
 )
+
+# Constants for memory management
+DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
+    0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
+)
+from sglang.srt.environ import envs
 from sglang.srt.model_loader.weight_utils import (
-    _BAR_FORMAT,
-    default_weight_loader,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -77,14 +94,12 @@ from sglang.srt.model_loader.weight_utils import (
     safetensors_weights_iterator,
     set_runai_streamer_env,
 )
-from sglang.srt.remote_instance_weight_loader_utils import (
-    trigger_transferring_weights_request,
-)
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_capability,
     is_npu,
     is_pin_memory_available,
+    rank0_log,
     set_weight_attrs,
 )
 
@@ -94,6 +109,8 @@ if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 _is_npu = is_npu()
+# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
+# which contains the complete mapping of quantization config choices
 
 
 @contextmanager
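Note: the new accelerate import block above is the usual optional-dependency guard: import once at module load, record availability in a flag, and raise an actionable error only at the point of use (as _load_modelopt_base_model does later in this diff). A self-contained sketch of the same pattern (illustrative, not sglang code):

    # Generic optional-dependency guard: fail late, with install guidance.
    try:
        import accelerate  # optional: only needed for ModelOpt calibration
        HAS_ACCELERATE = True
    except ImportError:
        accelerate = None
        HAS_ACCELERATE = False

    def require_accelerate() -> None:
        if not HAS_ACCELERATE:
            raise ImportError(
                "accelerate is required for this feature. "
                "Install it with: pip install accelerate"
            )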
@@ -163,11 +180,12 @@ def _get_quantization_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
+    remap_prefix: Dict[str, str] | None = None,
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(
-            model_config, load_config, packed_modules_mapping
+            model_config, load_config, packed_modules_mapping, remap_prefix
         )
         # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
         if quant_config is None:
@@ -203,10 +221,14 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    remap_prefix = getattr(model_class, "remap_prefix", None)
     if _is_npu:
         packed_modules_mapping.update(
             {
-                "visual": {"qkv_proj": ["qkv"]},
+                "visual": {
+                    "qkv_proj": ["qkv"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                },
                 "vision_model": {
                     "qkv_proj": ["q_proj", "k_proj", "v_proj"],
                     "proj": ["out_proj"],
@@ -223,13 +245,22 @@
         )
 
     quant_config = _get_quantization_config(
-        model_config, load_config, packed_modules_mapping
-    )
-    return model_class(
-        config=model_config.hf_config,
-        quant_config=quant_config,
+        model_config, load_config, packed_modules_mapping, remap_prefix
     )
 
+    # Build kwargs conditionally
+    kwargs = {
+        "config": model_config.hf_config,
+        "quant_config": quant_config,
+    }
+
+    # Only add sparse head kwargs if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
+    if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set():
+        kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.value
+        kwargs["model_path"] = model_config.model_path
+
+    return model_class(**kwargs)
+
 
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
@@ -421,10 +452,8 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-            weight_loader_disable_mmap = global_server_args_dict.get(
-                "weight_loader_disable_mmap"
+            weight_loader_disable_mmap = (
+                get_global_server_args().weight_loader_disable_mmap
             )
 
             if extra_config.get("enable_multithread_load"):
@@ -474,12 +503,87 @@ class DefaultModelLoader(BaseModelLoader):
             model_config.model_path, model_config.revision, fall_back_to_pt=True
         )
 
+    def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
+        """Load and prepare the base model for ModelOpt quantization.
+
+        This method handles the common model loading logic shared between
+        DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).
+        """
+        if not HAS_ACCELERATE:
+            raise ImportError(
+                "accelerate is required for ModelOpt quantization. "
+                "Please install it with: pip install accelerate"
+            )
+
+        hf_config = AutoConfig.from_pretrained(
+            model_config.model_path, trust_remote_code=True
+        )
+        with init_empty_weights():
+            torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+            model = AutoModelForCausalLM.from_config(
+                hf_config, torch_dtype=torch_dtype, trust_remote_code=True
+            )
+        max_memory = get_max_memory()
+        inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
+
+        on_cpu = "cpu" in inferred_device_map.values()
+        model_kwargs = {"torch_dtype": "auto"}
+        device_map = "auto"
+
+        if on_cpu:
+            for device in max_memory.keys():
+                if isinstance(device, int):
+                    max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
+
+            logger.warning(
+                "Model does not fit to the GPU mem. "
+                f"We apply the following memory limit for calibration: \n{max_memory}\n"
+                f"If you hit GPU OOM issue, please adjust the memory fraction "
+                f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
+                "reduce the calibration `batch_size` manually."
+            )
+            model_kwargs["max_memory"] = max_memory
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_config.model_path,
+            device_map=device_map,
+            **model_kwargs,
+            trust_remote_code=True,
+        )
+        # Handle both legacy modelopt_quant and unified quantization flags
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Legacy approach
+            quant_choice_str = model_config.modelopt_quant
+            rank0_log(f"ModelOpt quantization requested (legacy): {quant_choice_str}")
+        else:
+            # Unified approach - extract quantization type
+            quant_choice_str = model_config._get_modelopt_quant_type()
+            rank0_log(
+                f"ModelOpt quantization requested (unified): {model_config.quantization} -> {quant_choice_str}"
+            )
+
+        if not isinstance(quant_choice_str, str):
+            raise TypeError(
+                f"Quantization type must be a string (e.g., 'fp8'), "
+                f"got {type(quant_choice_str)}"
+            )
+
+        return model
+
     def load_model(
         self,
         *,
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Load base model using shared method
+            model = self._load_modelopt_base_model(model_config)
+            # Note: DefaultModelLoader doesn't do additional quantization processing
+            # For full ModelOpt quantization, use ModelOptModelLoader
+            return model.eval()
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
            with target_device:
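To make the calibration memory cap in _load_modelopt_base_model concrete, here is a small worked example of the scaling loop (the byte values are invented; get_max_memory() from accelerate returns a mapping of device to available bytes, with GPU ordinals as integer keys):

    # Only integer keys (GPU ordinals) are scaled; the "cpu" entry is untouched.
    DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = 0.8
    max_memory = {0: 80 * 2**30, 1: 80 * 2**30, "cpu": 512 * 2**30}
    for device in max_memory.keys():
        if isinstance(device, int):
            max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
    # Each 80 GiB GPU is now capped at 64 GiB for ModelOpt calibration.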
@@ -488,9 +592,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-        self.load_weights_and_postprocess(
-            model, self._get_all_weights(model_config, model), target_device
-        )
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
         return model.eval()
 
@@ -508,6 +612,8 @@
             # parameters onto device for processing and back off after.
             with device_loading_context(module, target_device):
                 quant_method.process_weights_after_loading(module)
+        if _is_npu:
+            torch.npu.empty_cache()
 
 
 class LayeredModelLoader(DefaultModelLoader):
@@ -526,9 +632,9 @@ class LayeredModelLoader(DefaultModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.server_args import get_global_server_args
 
-        torchao_config = global_server_args_dict.get("torchao_config")
+        torchao_config = get_global_server_args().torchao_config
         target_device = torch.device(device_config.device)
 
         with set_default_torch_dtype(model_config.dtype):
@@ -1417,7 +1523,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
             f"load format {load_config.load_format}"
         )
 
-        model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
+        model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
 
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
@@ -1439,11 +1545,12 @@
     def load_model_from_remote_instance(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
+        load_config = self.load_config
         instance_ip = socket.gethostbyname(socket.gethostname())
         start_build_group_tic = time.time()
         client.build_group(
             gpu_id=device_config.gpu_id,
-            tp_rank=model_config.tp_rank,
+            tp_rank=load_config.tp_rank,
             instance_ip=instance_ip,
         )
         torch.cuda.synchronize()
@@ -1452,13 +1559,13 @@
             f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
         )
 
-        if model_config.tp_rank == 0:
+        if load_config.tp_rank == 0:
             t = threading.Thread(
                 target=trigger_transferring_weights_request,
                 args=(
-                    model_config.remote_instance_weight_loader_seed_instance_ip,
-                    model_config.remote_instance_weight_loader_seed_instance_service_port,
-                    model_config.remote_instance_weight_loader_send_weights_group_ports,
+                    load_config.remote_instance_weight_loader_seed_instance_ip,
+                    load_config.remote_instance_weight_loader_seed_instance_service_port,
+                    load_config.remote_instance_weight_loader_send_weights_group_ports,
                     instance_ip,
                 ),
             )
@@ -1664,9 +1771,303 @@ def load_model_with_cpu_quantization(
     return model.eval()
 
 
-def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
+class ModelOptModelLoader(DefaultModelLoader):
+    """
+    Model loader that applies NVIDIA Model Optimizer quantization
+    """
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # Any ModelOpt specific initialization if needed
+
+    def _setup_modelopt_quantization(
+        self,
+        model,
+        tokenizer,
+        quant_cfg,
+        quantized_ckpt_restore_path: str | None = None,
+        quantized_ckpt_save_path: str | None = None,
+        export_path: str | None = None,
+    ) -> None:
+        """
+        Set up ModelOpt quantization for the given model.
+
+        Args:
+            model: The model to quantize
+            tokenizer: The tokenizer associated with the model
+            quant_cfg: The quantization configuration
+            quantized_ckpt_restore_path: Path to restore quantized checkpoint from
+            quantized_ckpt_save_path: Path to save quantized checkpoint to
+            export_path: Path to export the quantized model in HuggingFace format
+
+        Raises:
+            ImportError: If ModelOpt is not available
+            Exception: If quantization setup fails
+        """
+        try:
+            import modelopt.torch.opt as mto
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.quantization.utils import is_quantized
+        except ImportError as e:
+            raise ImportError(
+                "ModelOpt is not available. Please install modelopt."
+            ) from e
+
+        if is_quantized(model):
+            rank0_log("Model is already quantized, skipping quantization setup.")
+            return
+        # Restore from checkpoint if provided
+        if quantized_ckpt_restore_path:
+            try:
+                mto.restore(model, quantized_ckpt_restore_path)
+                rank0_log(
+                    f"Restored quantized model from {quantized_ckpt_restore_path}"
+                )
+
+                # Export model if path provided (even when restoring from checkpoint)
+                self._maybe_export_modelopt(model, export_path)
+                return
+            except Exception as e:
+                logger.warning(
+                    f"Failed to restore from {quantized_ckpt_restore_path}: {e}"
+                )
+                rank0_log("Proceeding with calibration-based quantization...")
+
+        # Set up calibration-based quantization
+        try:
+            # Left padding tends to work better for batched generation with decoder-only LMs
+            with suppress(Exception):
+                tokenizer.padding_side = "left"
+
+            from modelopt.torch.utils.dataset_utils import (
+                create_forward_loop,
+                get_dataset_dataloader,
+            )
+
+            # Create calibration dataloader
+            calib_dataloader = get_dataset_dataloader(
+                dataset_name="cnn_dailymail",  # TODO: Consider making this configurable
+                tokenizer=tokenizer,
+                batch_size=36,  # TODO: Consider making this configurable
+                num_samples=512,  # TODO: Consider making this configurable
+                device=model.device,
+                include_labels=False,
+            )
+
+            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+
+            # Apply quantization
+            mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+
+            if get_tensor_model_parallel_rank() == 0:
+                mtq.print_quant_summary(model)
+
+            # Save checkpoint if path provided
+            if quantized_ckpt_save_path:
+                try:
+                    mto.save(model, quantized_ckpt_save_path)
+                    rank0_log(f"Quantized model saved to {quantized_ckpt_save_path}")
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
+                    )
+
+            # Export model if path provided
+            self._maybe_export_modelopt(model, export_path)
+
+        except Exception as e:
+            raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
+
+    def _maybe_export_modelopt(self, model, export_path: str | None) -> None:
+        """Export model to HuggingFace format if export_path is provided."""
+        if export_path:
+            try:
+                # Get the original model path from the model config
+                original_model_path = getattr(self, "_original_model_path", None)
+                self._export_modelopt_checkpoint(
+                    model, export_path, original_model_path
+                )
+                rank0_log(
+                    f"Quantized model exported to HuggingFace format at {export_path}"
+                )
+            except Exception as e:
+                rank0_log(
+                    f"Warning: Failed to export quantized model to {export_path}: {e}"
+                )
+
+    def _export_modelopt_checkpoint(
+        self,
+        model,
+        export_path: str,
+        model_path: str = None,
+        trust_remote_code: bool = True,
+    ) -> None:
+        """
+        Export the quantized model to HuggingFace format using ModelOpt export API.
+
+        Args:
+            model: The quantized model to export
+            export_path: Directory path to export the model to
+            model_path: Path to the original model (for tokenizer export)
+            trust_remote_code: Whether to trust remote code for tokenizer loading
+
+        Raises:
+            ImportError: If ModelOpt export functionality is not available
+            Exception: If export fails
+        """
+        try:
+            from modelopt.torch.export import export_hf_checkpoint
+            from transformers import AutoTokenizer
+        except ImportError as e:
+            raise ImportError(
+                "ModelOpt export functionality is not available. "
+                "Please ensure you have the latest version of modelopt installed."
+            ) from e
+
+        # Create export directory if it doesn't exist
+        os.makedirs(export_path, exist_ok=True)
+
+        # Export the quantized model
+        export_hf_checkpoint(model, export_dir=export_path)
+
+        # Export the tokenizer if model_path is provided
+        if model_path:
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    model_path, trust_remote_code=trust_remote_code
+                )
+                tokenizer.save_pretrained(export_path)
+                rank0_log(f"Tokenizer exported to {export_path}")
+            except Exception as e:
+                rank0_log(f"Warning: Failed to export tokenizer: {e}")
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+
+        logger.info("ModelOptModelLoader: Loading base model...")
+
+        # Store the original model path for tokenizer export
+        self._original_model_path = model_config.model_path
+
+        # Check if model is already quantized
+        if model_config._is_already_quantized():
+            logger.info("Model is already quantized, loading directly...")
+            # Use default loading for pre-quantized models
+            return super().load_model(
+                model_config=model_config, device_config=device_config
+            )
+
+        # TODO: Quantize-and-serve mode has been disabled at the ModelConfig level
+        # All quantization now uses the standard workflow (quantize + export/save)
+        logger.info("Standard quantization mode: Will quantize and export/save")
+        return self._standard_quantization_workflow(model_config, device_config)
+
+    def _standard_quantization_workflow(
+        self, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+        """Standard quantization workflow: quantize, save checkpoint, export, then return model."""
+        # Use shared method from parent class to load base model for quantization
+        model = self._load_modelopt_base_model(model_config)
+
+        # Import ModelOpt modules
+        try:
+            import modelopt.torch.quantization as mtq
+        except ImportError:
+            logger.error(
+                "NVIDIA Model Optimizer (modelopt) library not found. "
+                "Please install it to use ModelOpt quantization."
+            )
+            raise
+
+        # Handle both old modelopt_quant and new unified quantization flags
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Legacy modelopt_quant flag
+            quant_choice_str = model_config.modelopt_quant
+        else:
+            # Unified quantization flag - extract the type (fp8/fp4)
+            quant_choice_str = model_config._get_modelopt_quant_type()
+
+        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
+        if not quant_cfg_name:
+            raise ValueError(
+                f"Invalid quantization choice: '{quant_choice_str}'. "
+                f"Available choices: {list(QUANT_CFG_CHOICES.keys())}"
+            )
+
+        try:
+            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
+            quant_cfg = getattr(mtq, quant_cfg_name)
+        except AttributeError:
+            raise AttributeError(
+                f"ModelOpt quantization config '{quant_cfg_name}' not found. "
+                "Please verify the ModelOpt library installation."
+            )
+
+        logger.info(
+            f"Quantizing model with ModelOpt using config: mtq.{quant_cfg_name}"
+        )
+
+        # Get ModelOpt configuration from LoadConfig
+        modelopt_config = self.load_config.modelopt_config
+        quantized_ckpt_restore_path = (
+            modelopt_config.checkpoint_restore_path if modelopt_config else None
+        )
+        quantized_ckpt_save_path = (
+            modelopt_config.checkpoint_save_path if modelopt_config else None
+        )
+        export_path = modelopt_config.export_path if modelopt_config else None
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_config.model_path, use_fast=True
+        )
+
+        try:
+            self._setup_modelopt_quantization(
+                model,
+                tokenizer,
+                quant_cfg,
+                quantized_ckpt_restore_path=quantized_ckpt_restore_path,
+                quantized_ckpt_save_path=quantized_ckpt_save_path,
+                export_path=export_path,
+            )
+        except Exception as e:
+            logger.warning(f"ModelOpt quantization failed: {e}")
+            rank0_log("Proceeding without quantization...")
+
+        return model.eval()
+
+
+def get_model_loader(
+    load_config: LoadConfig, model_config: Optional[ModelConfig] = None
+) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
+    if model_config and (
+        (hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant)
+        or model_config.quantization in ["modelopt_fp8", "modelopt_fp4", "modelopt"]
+    ):
+        logger.info("Using ModelOptModelLoader due to ModelOpt quantization config.")
+        return ModelOptModelLoader(load_config)
+
+    # Use ModelOptModelLoader for unified quantization flags
+    if (
+        model_config
+        and hasattr(model_config, "quantization")
+        and model_config.quantization in ["modelopt_fp8", "modelopt_fp4"]
+    ):
+        if model_config._is_already_quantized():
+            logger.info(
+                f"Using ModelOptModelLoader for pre-quantized model: {model_config.quantization}"
+            )
+        else:
+            logger.info(
+                f"Using ModelOptModelLoader for quantization: {model_config.quantization}"
+            )
+        return ModelOptModelLoader(load_config)
+
     if isinstance(load_config.load_format, type):
         return load_config.load_format(load_config)
 
--- a/sglang/srt/model_loader/utils.py
+++ b/sglang/srt/model_loader/utils.py
@@ -99,7 +99,6 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
     if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS:
         architectures = resolve_transformers_arch(model_config, architectures)
-
     return ModelRegistry.resolve_model_cls(architectures)
 
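Taken together, the loader.py changes mean call sites can now hand get_model_loader the model configuration so ModelOpt-quantized models are routed to ModelOptModelLoader automatically. A hedged sketch of the new call shape (construction of the config objects is assumed; only the two call signatures come from the diff above):

    # model_config / load_config / device_config setup omitted. The optional
    # model_config parameter of get_model_loader and the keyword-only
    # load_model signature are taken from the diff.
    loader = get_model_loader(load_config, model_config=model_config)
    model = loader.load_model(model_config=model_config, device_config=device_config)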