sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/loader.py

@@ -4,7 +4,6 @@ from __future__ import annotations
 
 # ruff: noqa: SIM117
 import collections
-import concurrent
 import dataclasses
 import fnmatch
 import glob
@@ -12,13 +11,11 @@ import json
 import logging
 import math
 import os
-import re
 import socket
 import threading
 import time
 from abc import ABC, abstractmethod
-from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -30,17 +27,28 @@ from typing import (
     Tuple,
     cast,
 )
-from urllib.parse import urlparse
 
 import huggingface_hub
 import numpy as np
-import requests
-import safetensors.torch
 import torch
+
+from sglang.srt.server_args import get_global_server_args
+
+# Try to import accelerate (optional dependency)
+try:
+    from accelerate import infer_auto_device_map, init_empty_weights
+    from accelerate.utils import get_max_memory
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    infer_auto_device_map = None
+    init_empty_weights = None
+    get_max_memory = None
+
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from tqdm.auto import tqdm
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
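The import hunk above makes accelerate an optional dependency: on ImportError the three names are bound to None and a module-level HAS_ACCELERATE flag records availability, so the ModelOpt code paths further down can fail with an actionable message at call time instead of breaking import of the whole loader module. A minimal standalone sketch of the same guard pattern (the compute_device_map helper is illustrative, not part of sglang):

# Optional-dependency guard (illustrative sketch of the pattern used above).
try:
    from accelerate import infer_auto_device_map  # optional extra

    HAS_ACCELERATE = True
except ImportError:
    HAS_ACCELERATE = False
    infer_auto_device_map = None


def compute_device_map(model):
    """Fail at the call site with install guidance, not at import time."""
    if not HAS_ACCELERATE:
        raise ImportError(
            "accelerate is required for this feature. "
            "Install it with: pip install accelerate"
        )
    return infer_auto_device_map(model)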
@@ -54,6 +62,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
     trigger_transferring_weights_request,
 )
@@ -62,9 +72,13 @@ from sglang.srt.model_loader.utils import (
     post_load_weights,
     set_default_torch_dtype,
 )
+
+# Constants for memory management
+DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
+    0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
+)
+from sglang.srt.environ import envs
 from sglang.srt.model_loader.weight_utils import (
-    _BAR_FORMAT,
-    default_weight_loader,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -85,6 +99,7 @@ from sglang.srt.utils import (
     get_device_capability,
     is_npu,
     is_pin_memory_available,
+    rank0_log,
     set_weight_attrs,
 )
 
@@ -94,6 +109,8 @@ if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 _is_npu = is_npu()
+# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
+# which contains the complete mapping of quantization config choices
 
 
 @contextmanager
@@ -163,11 +180,12 @@ def _get_quantization_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
+    remap_prefix: Dict[str, str] | None = None,
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(
-            model_config, load_config, packed_modules_mapping
+            model_config, load_config, packed_modules_mapping, remap_prefix
         )
         # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
         if quant_config is None:
@@ -203,6 +221,7 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    remap_prefix = getattr(model_class, "remap_prefix", None)
    if _is_npu:
         packed_modules_mapping.update(
             {
@@ -226,13 +245,22 @@ def _initialize_model(
     )
 
     quant_config = _get_quantization_config(
-        model_config, load_config, packed_modules_mapping
-    )
-    return model_class(
-        config=model_config.hf_config,
-        quant_config=quant_config,
+        model_config, load_config, packed_modules_mapping, remap_prefix
     )
 
+    # Build kwargs conditionally
+    kwargs = {
+        "config": model_config.hf_config,
+        "quant_config": quant_config,
+    }
+
+    # Only add sparse head kwargs if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set()
+    if envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.is_set():
+        kwargs["sparse_head"] = envs.SGLANG_EMBEDDINGS_SPARSE_HEAD.value
+        kwargs["model_path"] = model_config.model_path
+
+    return model_class(**kwargs)
+
 
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
@@ -424,10 +452,8 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-            weight_loader_disable_mmap = global_server_args_dict.get(
-                "weight_loader_disable_mmap"
+            weight_loader_disable_mmap = (
+                get_global_server_args().weight_loader_disable_mmap
             )
 
             if extra_config.get("enable_multithread_load"):
@@ -477,12 +503,87 @@ class DefaultModelLoader(BaseModelLoader):
             model_config.model_path, model_config.revision, fall_back_to_pt=True
         )
 
+    def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
+        """Load and prepare the base model for ModelOpt quantization.
+
+        This method handles the common model loading logic shared between
+        DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).
+        """
+        if not HAS_ACCELERATE:
+            raise ImportError(
+                "accelerate is required for ModelOpt quantization. "
+                "Please install it with: pip install accelerate"
+            )
+
+        hf_config = AutoConfig.from_pretrained(
+            model_config.model_path, trust_remote_code=True
+        )
+        with init_empty_weights():
+            torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+            model = AutoModelForCausalLM.from_config(
+                hf_config, torch_dtype=torch_dtype, trust_remote_code=True
+            )
+        max_memory = get_max_memory()
+        inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
+
+        on_cpu = "cpu" in inferred_device_map.values()
+        model_kwargs = {"torch_dtype": "auto"}
+        device_map = "auto"
+
+        if on_cpu:
+            for device in max_memory.keys():
+                if isinstance(device, int):
+                    max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
+
+            logger.warning(
+                "Model does not fit to the GPU mem. "
+                f"We apply the following memory limit for calibration: \n{max_memory}\n"
+                f"If you hit GPU OOM issue, please adjust the memory fraction "
+                f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
+                "reduce the calibration `batch_size` manually."
+            )
+            model_kwargs["max_memory"] = max_memory
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_config.model_path,
+            device_map=device_map,
+            **model_kwargs,
+            trust_remote_code=True,
+        )
+        # Handle both legacy modelopt_quant and unified quantization flags
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Legacy approach
+            quant_choice_str = model_config.modelopt_quant
+            rank0_log(f"ModelOpt quantization requested (legacy): {quant_choice_str}")
+        else:
+            # Unified approach - extract quantization type
+            quant_choice_str = model_config._get_modelopt_quant_type()
+            rank0_log(
+                f"ModelOpt quantization requested (unified): {model_config.quantization} -> {quant_choice_str}"
+            )
+
+        if not isinstance(quant_choice_str, str):
+            raise TypeError(
+                f"Quantization type must be a string (e.g., 'fp8'), "
+                f"got {type(quant_choice_str)}"
+            )
+
+        return model
+
     def load_model(
         self,
         *,
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Load base model using shared method
+            model = self._load_modelopt_base_model(model_config)
+            # Note: DefaultModelLoader doesn't do additional quantization processing
+            # For full ModelOpt quantization, use ModelOptModelLoader
+            return model.eval()
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
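For orientation, the new _load_modelopt_base_model above decides whether the checkpoint fits on the available GPUs before materializing any weights: it instantiates the model on the meta device, asks accelerate to infer a device map, and if any module would spill to CPU it scales each per-GPU budget by the 0.8 calibration fraction before the real from_pretrained call. A condensed sketch of that decision, using the same accelerate APIs as the diff (the model id is a placeholder):

import torch
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from transformers import AutoConfig, AutoModelForCausalLM

GPU_FRACTION = 0.8  # mirrors DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION above

config = AutoConfig.from_pretrained("org/model")  # placeholder model id
with init_empty_weights():  # meta tensors only; no memory is allocated
    empty_model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

max_memory = get_max_memory()  # e.g. {0: <bytes>, 1: <bytes>, "cpu": <bytes>}
device_map = infer_auto_device_map(empty_model, max_memory=max_memory)

kwargs = {"torch_dtype": "auto"}
if "cpu" in device_map.values():  # the model would spill to CPU
    # Shrink only the GPU budgets (integer keys) to leave calibration headroom.
    max_memory = {
        dev: limit * GPU_FRACTION if isinstance(dev, int) else limit
        for dev, limit in max_memory.items()
    }
    kwargs["max_memory"] = max_memory

model = AutoModelForCausalLM.from_pretrained("org/model", device_map="auto", **kwargs)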
@@ -491,9 +592,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-            self.load_weights_and_postprocess(
-                model, self._get_all_weights(model_config, model), target_device
-            )
+        self.load_weights_and_postprocess(
+            model, self._get_all_weights(model_config, model), target_device
+        )
 
         return model.eval()
 
@@ -511,6 +612,8 @@ class DefaultModelLoader(BaseModelLoader):
             # parameters onto device for processing and back off after.
             with device_loading_context(module, target_device):
                 quant_method.process_weights_after_loading(module)
+        if _is_npu:
+            torch.npu.empty_cache()
 
 
 class LayeredModelLoader(DefaultModelLoader):
@@ -529,9 +632,9 @@ class LayeredModelLoader(DefaultModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.server_args import get_global_server_args
 
-        torchao_config = global_server_args_dict.get("torchao_config")
+        torchao_config = get_global_server_args().torchao_config
         target_device = torch.device(device_config.device)
 
         with set_default_torch_dtype(model_config.dtype):
@@ -1668,9 +1771,303 @@ def load_model_with_cpu_quantization(
     return model.eval()
 
 
-def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
+class ModelOptModelLoader(DefaultModelLoader):
+    """
+    Model loader that applies NVIDIA Model Optimizer quantization
+    """
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # Any ModelOpt specific initialization if needed
+
+    def _setup_modelopt_quantization(
+        self,
+        model,
+        tokenizer,
+        quant_cfg,
+        quantized_ckpt_restore_path: str | None = None,
+        quantized_ckpt_save_path: str | None = None,
+        export_path: str | None = None,
+    ) -> None:
+        """
+        Set up ModelOpt quantization for the given model.
+
+        Args:
+            model: The model to quantize
+            tokenizer: The tokenizer associated with the model
+            quant_cfg: The quantization configuration
+            quantized_ckpt_restore_path: Path to restore quantized checkpoint from
+            quantized_ckpt_save_path: Path to save quantized checkpoint to
+            export_path: Path to export the quantized model in HuggingFace format
+
+        Raises:
+            ImportError: If ModelOpt is not available
+            Exception: If quantization setup fails
+        """
+        try:
+            import modelopt.torch.opt as mto
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.quantization.utils import is_quantized
+        except ImportError as e:
+            raise ImportError(
+                "ModelOpt is not available. Please install modelopt."
+            ) from e
+
+        if is_quantized(model):
+            rank0_log("Model is already quantized, skipping quantization setup.")
+            return
+        # Restore from checkpoint if provided
+        if quantized_ckpt_restore_path:
+            try:
+                mto.restore(model, quantized_ckpt_restore_path)
+                rank0_log(
+                    f"Restored quantized model from {quantized_ckpt_restore_path}"
+                )
+
+                # Export model if path provided (even when restoring from checkpoint)
+                self._maybe_export_modelopt(model, export_path)
+                return
+            except Exception as e:
+                logger.warning(
+                    f"Failed to restore from {quantized_ckpt_restore_path}: {e}"
+                )
+                rank0_log("Proceeding with calibration-based quantization...")
+
+        # Set up calibration-based quantization
+        try:
+            # Left padding tends to work better for batched generation with decoder-only LMs
+            with suppress(Exception):
+                tokenizer.padding_side = "left"
+
+            from modelopt.torch.utils.dataset_utils import (
+                create_forward_loop,
+                get_dataset_dataloader,
+            )
+
+            # Create calibration dataloader
+            calib_dataloader = get_dataset_dataloader(
+                dataset_name="cnn_dailymail",  # TODO: Consider making this configurable
+                tokenizer=tokenizer,
+                batch_size=36,  # TODO: Consider making this configurable
+                num_samples=512,  # TODO: Consider making this configurable
+                device=model.device,
+                include_labels=False,
+            )
+
+            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+
+            # Apply quantization
+            mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+
+            if get_tensor_model_parallel_rank() == 0:
+                mtq.print_quant_summary(model)
+
+            # Save checkpoint if path provided
+            if quantized_ckpt_save_path:
+                try:
+                    mto.save(model, quantized_ckpt_save_path)
+                    rank0_log(f"Quantized model saved to {quantized_ckpt_save_path}")
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
+                    )
+
+            # Export model if path provided
+            self._maybe_export_modelopt(model, export_path)
+
+        except Exception as e:
+            raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
+
+    def _maybe_export_modelopt(self, model, export_path: str | None) -> None:
+        """Export model to HuggingFace format if export_path is provided."""
+        if export_path:
+            try:
+                # Get the original model path from the model config
+                original_model_path = getattr(self, "_original_model_path", None)
+                self._export_modelopt_checkpoint(
+                    model, export_path, original_model_path
+                )
+                rank0_log(
+                    f"Quantized model exported to HuggingFace format at {export_path}"
+                )
+            except Exception as e:
+                rank0_log(
+                    f"Warning: Failed to export quantized model to {export_path}: {e}"
+                )
+
+    def _export_modelopt_checkpoint(
+        self,
+        model,
+        export_path: str,
+        model_path: str = None,
+        trust_remote_code: bool = True,
+    ) -> None:
+        """
+        Export the quantized model to HuggingFace format using ModelOpt export API.
+
+        Args:
+            model: The quantized model to export
+            export_path: Directory path to export the model to
+            model_path: Path to the original model (for tokenizer export)
+            trust_remote_code: Whether to trust remote code for tokenizer loading
+
+        Raises:
+            ImportError: If ModelOpt export functionality is not available
+            Exception: If export fails
+        """
+        try:
+            from modelopt.torch.export import export_hf_checkpoint
+            from transformers import AutoTokenizer
+        except ImportError as e:
+            raise ImportError(
+                "ModelOpt export functionality is not available. "
+                "Please ensure you have the latest version of modelopt installed."
+            ) from e
+
+        # Create export directory if it doesn't exist
+        os.makedirs(export_path, exist_ok=True)
+
+        # Export the quantized model
+        export_hf_checkpoint(model, export_dir=export_path)
+
+        # Export the tokenizer if model_path is provided
+        if model_path:
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    model_path, trust_remote_code=trust_remote_code
+                )
+                tokenizer.save_pretrained(export_path)
+                rank0_log(f"Tokenizer exported to {export_path}")
+            except Exception as e:
+                rank0_log(f"Warning: Failed to export tokenizer: {e}")
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+
+        logger.info("ModelOptModelLoader: Loading base model...")
+
+        # Store the original model path for tokenizer export
+        self._original_model_path = model_config.model_path
+
+        # Check if model is already quantized
+        if model_config._is_already_quantized():
+            logger.info("Model is already quantized, loading directly...")
+            # Use default loading for pre-quantized models
+            return super().load_model(
+                model_config=model_config, device_config=device_config
+            )
+
+        # TODO: Quantize-and-serve mode has been disabled at the ModelConfig level
+        # All quantization now uses the standard workflow (quantize + export/save)
+        logger.info("Standard quantization mode: Will quantize and export/save")
+        return self._standard_quantization_workflow(model_config, device_config)
+
+    def _standard_quantization_workflow(
+        self, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+        """Standard quantization workflow: quantize, save checkpoint, export, then return model."""
+        # Use shared method from parent class to load base model for quantization
+        model = self._load_modelopt_base_model(model_config)
+
+        # Import ModelOpt modules
+        try:
+            import modelopt.torch.quantization as mtq
+        except ImportError:
+            logger.error(
+                "NVIDIA Model Optimizer (modelopt) library not found. "
+                "Please install it to use ModelOpt quantization."
+            )
+            raise
+
+        # Handle both old modelopt_quant and new unified quantization flags
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Legacy modelopt_quant flag
+            quant_choice_str = model_config.modelopt_quant
+        else:
+            # Unified quantization flag - extract the type (fp8/fp4)
+            quant_choice_str = model_config._get_modelopt_quant_type()
+
+        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
+        if not quant_cfg_name:
+            raise ValueError(
+                f"Invalid quantization choice: '{quant_choice_str}'. "
+                f"Available choices: {list(QUANT_CFG_CHOICES.keys())}"
+            )
+
+        try:
+            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
+            quant_cfg = getattr(mtq, quant_cfg_name)
+        except AttributeError:
+            raise AttributeError(
+                f"ModelOpt quantization config '{quant_cfg_name}' not found. "
+                "Please verify the ModelOpt library installation."
+            )
+
+        logger.info(
+            f"Quantizing model with ModelOpt using config: mtq.{quant_cfg_name}"
+        )
+
+        # Get ModelOpt configuration from LoadConfig
+        modelopt_config = self.load_config.modelopt_config
+        quantized_ckpt_restore_path = (
+            modelopt_config.checkpoint_restore_path if modelopt_config else None
+        )
+        quantized_ckpt_save_path = (
+            modelopt_config.checkpoint_save_path if modelopt_config else None
+        )
+        export_path = modelopt_config.export_path if modelopt_config else None
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_config.model_path, use_fast=True
+        )
+
+        try:
+            self._setup_modelopt_quantization(
+                model,
+                tokenizer,
+                quant_cfg,
+                quantized_ckpt_restore_path=quantized_ckpt_restore_path,
+                quantized_ckpt_save_path=quantized_ckpt_save_path,
+                export_path=export_path,
+            )
+        except Exception as e:
+            logger.warning(f"ModelOpt quantization failed: {e}")
+            rank0_log("Proceeding without quantization...")
+
+        return model.eval()
+
+
+def get_model_loader(
+    load_config: LoadConfig, model_config: Optional[ModelConfig] = None
+) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
+    if model_config and (
+        (hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant)
+        or model_config.quantization in ["modelopt_fp8", "modelopt_fp4", "modelopt"]
+    ):
+        logger.info("Using ModelOptModelLoader due to ModelOpt quantization config.")
+        return ModelOptModelLoader(load_config)
+
+    # Use ModelOptModelLoader for unified quantization flags
+    if (
+        model_config
+        and hasattr(model_config, "quantization")
+        and model_config.quantization in ["modelopt_fp8", "modelopt_fp4"]
+    ):
+        if model_config._is_already_quantized():
+            logger.info(
+                f"Using ModelOptModelLoader for pre-quantized model: {model_config.quantization}"
+            )
+        else:
+            logger.info(
+                f"Using ModelOptModelLoader for quantization: {model_config.quantization}"
+            )
+        return ModelOptModelLoader(load_config)
+
     if isinstance(load_config.load_format, type):
         return load_config.load_format(load_config)
 
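Taken together, get_model_loader now accepts an optional ModelConfig and routes any ModelOpt quantization setting ("modelopt", "modelopt_fp8", "modelopt_fp4", or the legacy modelopt_quant attribute) to ModelOptModelLoader, which loads pre-quantized checkpoints directly and otherwise runs the calibrate-quantize-save/export workflow defined above. A hedged usage sketch; the ModelConfig constructor arguments shown are assumptions for illustration, not the exact sglang signature:

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import ModelOptModelLoader, get_model_loader

# Constructor arguments are assumed for illustration; check the sglang
# docs for the exact ModelConfig/LoadConfig signatures.
model_config = ModelConfig("org/model", quantization="modelopt_fp8")
load_config = LoadConfig()

loader = get_model_loader(load_config, model_config)
assert isinstance(loader, ModelOptModelLoader)  # ModelOpt settings select this loader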
sglang/srt/model_loader/utils.py

@@ -99,7 +99,6 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
 
     if not is_native_supported or model_config.model_impl == ModelImpl.TRANSFORMERS:
         architectures = resolve_transformers_arch(model_config, architectures)
-
     return ModelRegistry.resolve_model_cls(architectures)
 
 