sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/npu_graph_runner.py

@@ -19,8 +19,10 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union
 
+import numpy as np
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
 logger = logging.getLogger(__name__)
@@ -73,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
         self.positions[: self.raw_num_token].copy_(forward_batch.positions)
 
         # Replay
-        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
-        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
-        thread.start()
-        self.graphs[self.bs].replay()
-        thread.join()
+        if self.model_runner.model_config.index_head_dim is None:
+            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                self.bs - self.raw_bs
+            )
+            thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+            thread.start()
+            self.graphs[self.bs].replay()
+            thread.join()
+        else:
+            self.graphs[self.bs].replay()
 
         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
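
For orientation, a minimal sketch of the overlapped-replay pattern that the default branch keeps, with graph and update_inputs as stand-ins for the runner's real members (assumed names, not sglang API):

    import threading

    def replay_with_host_update(graph, update_inputs, seq_lens):
        # Kick off the host-side input refresh concurrently, replay the
        # captured graph, then join so outputs are read only after both finish.
        t = threading.Thread(target=update_inputs, args=(seq_lens,))
        t.start()
        graph.replay()
        t.join()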

sglang/srt/model_loader/__init__.py

@@ -1,16 +1,22 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from torch import nn
 
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import (
     get_architecture_class_name,
     get_model_architecture,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+
 
 def get_model(
     *,
@@ -18,7 +24,7 @@ def get_model(
     load_config: LoadConfig,
     device_config: DeviceConfig,
 ) -> nn.Module:
-    loader = get_model_loader(load_config)
+    loader = get_model_loader(load_config, model_config)
     return loader.load_model(
         model_config=model_config,
         device_config=device_config,
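
The hunk above moves the config imports behind TYPE_CHECKING, which typically breaks a runtime import cycle; the accompanying from __future__ import annotations is what keeps the LoadConfig/DeviceConfig annotations legal once those names no longer exist at runtime. A minimal stand-alone sketch of the pattern (illustrative, not sglang code):

    from __future__ import annotations  # annotations are stored as strings

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers only; never imported at runtime.
        from sglang.srt.configs.load_config import LoadConfig

    def describe(load_config: LoadConfig) -> str:
        # The annotation is never evaluated at runtime, so this module
        # imports cleanly even though LoadConfig was not actually imported.
        return repr(load_config)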

sglang/srt/model_loader/loader.py

@@ -1,5 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
 
+from __future__ import annotations
+
 # ruff: noqa: SIM117
 import collections
 import concurrent
@@ -10,25 +12,50 @@ import json
 import logging
 import math
 import os
+import re
+import socket
+import threading
 import time
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
+from urllib.parse import urlparse
 
 import huggingface_hub
 import numpy as np
+import requests
 import safetensors.torch
 import torch
+
+# Try to import accelerate (optional dependency)
+try:
+    from accelerate import infer_auto_device_map, init_empty_weights
+    from accelerate.utils import get_max_memory
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    infer_auto_device_map = None
+    init_empty_weights = None
+    get_max_memory = None
+
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from tqdm.auto import tqdm
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
-from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.connector import (
     ConnectorType,
     create_remote_connector,
@@ -39,14 +66,24 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_transferring_weights_request,
+)
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
     post_load_weights,
     set_default_torch_dtype,
 )
+
+# Constants for memory management
+DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
+    0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
+)
 from sglang.srt.model_loader.weight_utils import (
     _BAR_FORMAT,
+    default_weight_loader,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -70,7 +107,14 @@ from sglang.srt.utils import (
     set_weight_attrs,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
 _is_npu = is_npu()
+# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
+# which contains the complete mapping of quantization config choices
 
 
 @contextmanager
@@ -183,7 +227,10 @@ def _initialize_model(
     if _is_npu:
         packed_modules_mapping.update(
             {
-                "visual": {"qkv_proj": ["qkv"]},
+                "visual": {
+                    "qkv_proj": ["qkv"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                },
                 "vision_model": {
                     "qkv_proj": ["q_proj", "k_proj", "v_proj"],
                     "proj": ["out_proj"],
@@ -451,12 +498,78 @@ class DefaultModelLoader(BaseModelLoader):
             model_config.model_path, model_config.revision, fall_back_to_pt=True
         )
 
+    def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
+        """Load and prepare the base model for ModelOpt quantization.
+
+        This method handles the common model loading logic shared between
+        DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).
+        """
+        if not HAS_ACCELERATE:
+            raise ImportError(
+                "accelerate is required for ModelOpt quantization. "
+                "Please install it with: pip install accelerate"
+            )
+
+        hf_config = AutoConfig.from_pretrained(
+            model_config.model_path, trust_remote_code=True
+        )
+        with init_empty_weights():
+            torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+            model = AutoModelForCausalLM.from_config(
+                hf_config, torch_dtype=torch_dtype, trust_remote_code=True
+            )
+        max_memory = get_max_memory()
+        inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
+
+        on_cpu = "cpu" in inferred_device_map.values()
+        model_kwargs = {"torch_dtype": "auto"}
+        device_map = "auto"
+
+        if on_cpu:
+            for device in max_memory.keys():
+                if isinstance(device, int):
+                    max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
+
+            logger.warning(
+                "Model does not fit to the GPU mem. "
+                f"We apply the following memory limit for calibration: \n{max_memory}\n"
+                f"If you hit GPU OOM issue, please adjust the memory fraction "
+                f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
+                "reduce the calibration `batch_size` manually."
+            )
+            model_kwargs["max_memory"] = max_memory
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_config.model_path,
+            device_map=device_map,
+            **model_kwargs,
+            trust_remote_code=True,
+        )
+        logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
+
+        quant_choice_str = model_config.modelopt_quant
+        if not isinstance(quant_choice_str, str):
+            raise TypeError(
+                f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
+                f"got {type(quant_choice_str)}"
+            )
+
+        return model
+
     def load_model(
         self,
         *,
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Load base model using shared method
+            model = self._load_modelopt_base_model(model_config)
+            # Note: DefaultModelLoader doesn't do additional quantization processing
+            # For full ModelOpt quantization, use ModelOptModelLoader
+            return model.eval()
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -465,9 +578,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-        self.load_weights_and_postprocess(
-            model, self._get_all_weights(model_config, model), target_device
-        )
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
         return model.eval()
 
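
Two changes land in the hunks above: _load_modelopt_base_model probes with accelerate whether the checkpoint fits on GPU, and load_weights_and_postprocess is re-indented to run inside the set_default_torch_dtype context. When the inferred device map spills to CPU, each GPU budget is scaled down before the model is reloaded for calibration. A toy walk-through of that scaling, with made-up byte figures:

    DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = 0.8

    # accelerate's get_max_memory() keys GPUs by int and host RAM by "cpu".
    max_memory = {0: 80_000_000_000, 1: 80_000_000_000, "cpu": 200_000_000_000}

    for device in max_memory.keys():
        if isinstance(device, int):  # only GPU budgets are reduced
            max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION

    # Each GPU is now capped at 64 GB; the overflow offloads to CPU while
    # ModelOpt calibration runs, trading throughput for headroom.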
@@ -1366,6 +1479,105 @@ class GGUFModelLoader(BaseModelLoader):
         return model
 
 
+class RemoteInstanceModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote sglang instance."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        if load_config.model_loader_extra_config:
+            raise ValueError(
+                f"Model loader extra config is not supported for "
+                f"load format {load_config.load_format}"
+            )
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        raise NotImplementedError
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote instance ...")
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+
+        with create_remote_connector(model_weights, device_config.device) as client:
+            connector_type = get_connector_type(client)
+            if connector_type == ConnectorType.INSTANCE:
+                self.load_model_from_remote_instance(
+                    model, client, model_config, device_config
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported connector type {connector_type} for "
+                    f"remote tensor model loading."
+                )
+        return model.eval()
+
+    def load_model_from_remote_instance(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+        load_config = self.load_config
+        instance_ip = socket.gethostbyname(socket.gethostname())
+        start_build_group_tic = time.time()
+        client.build_group(
+            gpu_id=device_config.gpu_id,
+            tp_rank=load_config.tp_rank,
+            instance_ip=instance_ip,
+        )
+        torch.cuda.synchronize()
+        end_build_group_tic = time.time()
+        logger.debug(
+            f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
+        )
+
+        if load_config.tp_rank == 0:
+            t = threading.Thread(
+                target=trigger_transferring_weights_request,
+                args=(
+                    load_config.remote_instance_weight_loader_seed_instance_ip,
+                    load_config.remote_instance_weight_loader_seed_instance_service_port,
+                    load_config.remote_instance_weight_loader_send_weights_group_ports,
+                    instance_ip,
+                ),
+            )
+            t.start()
+
+        start_get_weights_tic = time.time()
+        with set_default_torch_dtype(model_config.dtype):
+            for _, tensor in model.named_parameters():
+                torch.distributed.broadcast(
+                    tensor.data,
+                    src=0,
+                    group=client._model_update_group,
+                )
+            torch.cuda.synchronize()
+
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+        end_get_weights_tic = time.time()
+        logger.debug(
+            f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
+        )
+        # destroy the process group after loading weights
+        torch.distributed.distributed_c10d.destroy_process_group(
+            client._model_update_group
+        )
+        torch.cuda.empty_cache()
+
+
 class RemoteModelLoader(BaseModelLoader):
     """Model loader that can load Tensors from remote database."""
 
@@ -1543,9 +1755,103 @@ def load_model_with_cpu_quantization(
     return model.eval()
 
 
-def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
+class ModelOptModelLoader(DefaultModelLoader):
+    """
+    Model loader that applies NVIDIA Model Optimizer quantization
+    """
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # Any ModelOpt specific initialization if needed
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+
+        logger.info("ModelOptModelLoader: Loading base model...")
+
+        # Use shared method from parent class to load base model
+        model = self._load_modelopt_base_model(model_config)
+
+        # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
+        try:
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.utils.dataset_utils import create_forward_loop
+        except ImportError:
+            logger.error(
+                "NVIDIA Model Optimizer (modelopt) library not found. "
+                "Please install it to use 'modelopt_quant' feature."
+            )
+            raise
+
+        quant_choice_str = model_config.modelopt_quant
+
+        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
+        if not quant_cfg_name:
+            raise ValueError(
+                f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
+                f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
+                "Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
+                "attribute names of config objects in modelopt.torch.quantization."
+            )
+
+        try:
+            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
+            quant_cfg = getattr(mtq, quant_cfg_name)
+        except AttributeError:
+            raise AttributeError(
+                f"ModelOpt quantization config attribute '{quant_cfg_name}' "
+                f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
+                "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
+            )
+
+        # For now, assume no calibration. Calibration setup is a separate, more complex step.
+        use_calibration = False  # This would ideally be a configurable parameter
+        calib_dataloader = None  # This would need to be provided/configured
+
+        calibrate_loop = (
+            create_forward_loop(dataloader=calib_dataloader)
+            if use_calibration
+            else None
+        )
+
+        if use_calibration and calib_dataloader is None:
+            logger.warning(
+                "ModelOpt calibration requested but no calib_dataloader provided. "
+                "Proceeding without calibration. Quantization accuracy may be affected."
+            )
+
+        logger.info(
+            f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
+        )
+
+        try:
+            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+            logger.info("Model successfully quantized with ModelOpt.")
+        except Exception as e:
+            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
+            raise
+        mtq.print_quant_summary(model)
+
+        return model.eval()
+
+
+def get_model_loader(
+    load_config: LoadConfig, model_config: Optional[ModelConfig] = None
+) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
+    if (
+        model_config
+        and hasattr(model_config, "modelopt_quant")
+        and model_config.modelopt_quant
+    ):
+        logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
+        return ModelOptModelLoader(load_config)
+
     if isinstance(load_config.load_format, type):
         return load_config.load_format(load_config)
 
@@ -1567,4 +1873,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.REMOTE:
         return RemoteModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
+        return RemoteInstanceModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
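
A hedged sketch of the new call shape, using only names from this diff; model_config is optional, so existing get_model_loader(load_config) call sites keep the old format-based dispatch:

    loader = get_model_loader(load_config, model_config)
    # -> ModelOptModelLoader when model_config.modelopt_quant is set,
    #    otherwise the same format-based dispatch as before.

    model = loader.load_model(
        model_config=model_config,
        device_config=device_config,
    )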

sglang/srt/model_loader/remote_instance_weight_loader_utils.py (new file)

@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import List
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+def trigger_init_weights_send_group_for_remote_instance_request(
+    remote_instance_weight_loader_seed_instance_ip: str,
+    remote_instance_weight_loader_seed_instance_service_port: int,
+    remote_instance_weight_loader_send_weights_group_ports: List[int],
+    remote_instance_weight_loader_client_id: str,
+):
+    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
+    # Only support loading weights from instance with same parallelism strategy.
+    # Per TP rank pair between seed and dst instances will build a communication group for sending weights.
+    # i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
+    # Each communication group will have a world size 2.
+    try:
+        requests.post(
+            f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
+            json={
+                "master_address": remote_instance_weight_loader_seed_instance_ip,
+                "ports": (
+                    ",".join(
+                        str(p)
+                        for p in remote_instance_weight_loader_send_weights_group_ports
+                    )
+                ),
+                "group_rank": 0,
+                "world_size": 2,
+                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
+                "backend": "nccl",
+            },
+        )
+    except Exception as e:
+        logger.error(
+            f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
+        )
+        raise
+
+
+def trigger_transferring_weights_request(
+    remote_instance_weight_loader_seed_instance_ip: str,
+    remote_instance_weight_loader_seed_instance_service_port: int,
+    remote_instance_weight_loader_send_weights_group_ports: List[int],
+    remote_instance_weight_loader_client_id: str,
+):
+    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
+    try:
+        requests.post(
+            f"{seed_instance_service_url}/send_weights_to_remote_instance",
+            json={
+                "master_address": remote_instance_weight_loader_seed_instance_ip,
+                "ports": (
+                    ",".join(
+                        str(p)
+                        for p in remote_instance_weight_loader_send_weights_group_ports
+                    )
+                ),
+                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
+            },
+        )
+    except Exception as e:
+        logger.error(f"Failed to trigger send weights to remote instance request: {e}")
+        raise
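
For orientation, a hypothetical invocation from the destination instance (all values illustrative). TP rank 0 fires this request; the seed instance then broadcasts each parameter over per-TP-rank NCCL groups of world size 2, one port per rank pair, as the comments above describe:

    trigger_transferring_weights_request(
        remote_instance_weight_loader_seed_instance_ip="10.0.0.5",  # seed instance
        remote_instance_weight_loader_seed_instance_service_port=30000,  # its HTTP port
        remote_instance_weight_loader_send_weights_group_ports=[29500, 29501],  # one per TP rank
        remote_instance_weight_loader_client_id="10.0.0.7",  # destination instance IP
    )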