sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
Diff hunks follow for four of the changed files.

sglang/srt/model_executor/npu_graph_runner.py

```diff
@@ -19,8 +19,10 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union
 
+import numpy as np
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
 logger = logging.getLogger(__name__)
@@ -73,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
         self.positions[: self.raw_num_token].copy_(forward_batch.positions)
 
         # Replay
-        seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
-        thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
-        thread.start()
-        self.graphs[self.bs].replay()
-        thread.join()
+        if self.model_runner.model_config.index_head_dim is None:
+            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                self.bs - self.raw_bs
+            )
+            thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+            thread.start()
+            self.graphs[self.bs].replay()
+            thread.join()
+        else:
+            self.graphs[self.bs].replay()
 
         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
```
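The replay path above overlaps host work with device work: while the captured graph replays, a worker thread updates the padded sequence lengths, and the join guarantees the update lands before outputs are read. A minimal, runnable sketch of the same pattern, with a stub graph object and hypothetical helper names standing in for the sglang internals:

```python
import threading

class _StubGraph:
    """Stand-in for a captured device graph; only mimics .replay()."""

    def replay(self) -> None:
        print("graph replayed")

def pad_seq_lens(seq_lens, capture_bs):
    # Graphs are captured at fixed batch sizes, so the real sequence
    # lengths are zero-padded up to the capture batch size, as above.
    return seq_lens + [0] * (capture_bs - len(seq_lens))

def replay_overlapped(graph, update_inputs, seq_lens):
    # Run the host-side input update on a worker thread while the captured
    # graph replays, then join before the outputs are read.
    t = threading.Thread(target=update_inputs, args=(seq_lens,))
    t.start()
    graph.replay()
    t.join()

if __name__ == "__main__":
    lens = pad_seq_lens([5, 3], capture_bs=4)  # -> [5, 3, 0, 0]
    replay_overlapped(_StubGraph(), lambda s: print("updated", s), lens)
```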
sglang/srt/model_loader/__init__.py

```diff
@@ -1,16 +1,22 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from torch import nn
 
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import (
     get_architecture_class_name,
     get_model_architecture,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.load_config import LoadConfig
+    from sglang.srt.configs.model_config import ModelConfig
+
 
 def get_model(
     *,
@@ -18,7 +24,7 @@ def get_model(
     load_config: LoadConfig,
     device_config: DeviceConfig,
 ) -> nn.Module:
-    loader = get_model_loader(load_config)
+    loader = get_model_loader(load_config, model_config)
     return loader.load_model(
         model_config=model_config,
         device_config=device_config,
```
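Both model_loader modules now defer their config imports with the `from __future__ import annotations` plus `TYPE_CHECKING` idiom, so annotations no longer force runtime imports (avoiding import cost and circular-import risk). A self-contained illustration of the idiom, using `decimal.Decimal` as a stand-in for the config classes:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only by type checkers; no runtime import cost or cycles.
    from decimal import Decimal  # stand-in for DeviceConfig etc.

def describe(value: Decimal) -> str:
    # With the __future__ import, the annotation is a plain string at
    # runtime, so Decimal never has to be imported when this executes.
    return f"value={value}"

print(describe.__annotations__)  # {'value': 'Decimal', 'return': 'str'}
```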
sglang/srt/model_loader/loader.py

```diff
@@ -1,5 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
 
+from __future__ import annotations
+
 # ruff: noqa: SIM117
 import collections
 import concurrent
@@ -10,25 +12,50 @@ import json
 import logging
 import math
 import os
+import re
+import socket
+import threading
 import time
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
+from urllib.parse import urlparse
 
 import huggingface_hub
 import numpy as np
+import requests
 import safetensors.torch
 import torch
+
+# Try to import accelerate (optional dependency)
+try:
+    from accelerate import infer_auto_device_map, init_empty_weights
+    from accelerate.utils import get_max_memory
+
+    HAS_ACCELERATE = True
+except ImportError:
+    HAS_ACCELERATE = False
+    infer_auto_device_map = None
+    init_empty_weights = None
+    get_max_memory = None
+
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
-from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.connector import (
     ConnectorType,
     create_remote_connector,
@@ -39,14 +66,24 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_transferring_weights_request,
+)
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
     post_load_weights,
     set_default_torch_dtype,
 )
+
+# Constants for memory management
+DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
+    0.8  # Reserve 20% GPU memory headroom for ModelOpt calibration
+)
 from sglang.srt.model_loader.weight_utils import (
     _BAR_FORMAT,
+    default_weight_loader,
     download_safetensors_index_file_from_hf,
     download_weights_from_hf,
     filter_duplicate_safetensors_files,
@@ -70,7 +107,14 @@ from sglang.srt.utils import (
     set_weight_attrs,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.configs.device_config import DeviceConfig
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
 _is_npu = is_npu()
+# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
+# which contains the complete mapping of quantization config choices
 
 
 @contextmanager
```
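The hunk above makes accelerate an optional dependency: the import is attempted once at module load, the outcome is recorded in `HAS_ACCELERATE`, and a targeted `ImportError` is raised later only when a ModelOpt code path actually needs it. A minimal sketch of the same pattern (`require_accelerate` is an illustrative helper, not an sglang API):

```python
# Optional-dependency pattern: import if present, otherwise record the
# absence and fail with a clear message only when the feature is used.
try:
    from accelerate import init_empty_weights  # optional dependency

    HAS_ACCELERATE = True
except ImportError:
    HAS_ACCELERATE = False
    init_empty_weights = None

def require_accelerate() -> None:
    # Called at the start of any code path that needs accelerate.
    if not HAS_ACCELERATE:
        raise ImportError(
            "accelerate is required for this feature; "
            "install it with: pip install accelerate"
        )
```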
```diff
@@ -183,7 +227,10 @@ def _initialize_model(
     if _is_npu:
         packed_modules_mapping.update(
             {
-                "visual": {"qkv_proj": ["qkv"]},
+                "visual": {
+                    "qkv_proj": ["qkv"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                },
                 "vision_model": {
                     "qkv_proj": ["q_proj", "k_proj", "v_proj"],
                     "proj": ["out_proj"],
```
```diff
@@ -451,12 +498,78 @@ class DefaultModelLoader(BaseModelLoader):
             model_config.model_path, model_config.revision, fall_back_to_pt=True
         )
 
+    def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
+        """Load and prepare the base model for ModelOpt quantization.
+
+        This method handles the common model loading logic shared between
+        DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).
+        """
+        if not HAS_ACCELERATE:
+            raise ImportError(
+                "accelerate is required for ModelOpt quantization. "
+                "Please install it with: pip install accelerate"
+            )
+
+        hf_config = AutoConfig.from_pretrained(
+            model_config.model_path, trust_remote_code=True
+        )
+        with init_empty_weights():
+            torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+            model = AutoModelForCausalLM.from_config(
+                hf_config, torch_dtype=torch_dtype, trust_remote_code=True
+            )
+        max_memory = get_max_memory()
+        inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)
+
+        on_cpu = "cpu" in inferred_device_map.values()
+        model_kwargs = {"torch_dtype": "auto"}
+        device_map = "auto"
+
+        if on_cpu:
+            for device in max_memory.keys():
+                if isinstance(device, int):
+                    max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
+
+            logger.warning(
+                "Model does not fit to the GPU mem. "
+                f"We apply the following memory limit for calibration: \n{max_memory}\n"
+                f"If you hit GPU OOM issue, please adjust the memory fraction "
+                f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
+                "reduce the calibration `batch_size` manually."
+            )
+            model_kwargs["max_memory"] = max_memory
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_config.model_path,
+            device_map=device_map,
+            **model_kwargs,
+            trust_remote_code=True,
+        )
+        logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
+
+        quant_choice_str = model_config.modelopt_quant
+        if not isinstance(quant_choice_str, str):
+            raise TypeError(
+                f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
+                f"got {type(quant_choice_str)}"
+            )
+
+        return model
+
     def load_model(
         self,
         *,
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
+            # Load base model using shared method
+            model = self._load_modelopt_base_model(model_config)
+            # Note: DefaultModelLoader doesn't do additional quantization processing
+            # For full ModelOpt quantization, use ModelOptModelLoader
+            return model.eval()
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -465,9 +578,9 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config,
                 )
 
-                self.load_weights_and_postprocess(
-                    model, self._get_all_weights(model_config, model), target_device
-                )
+            self.load_weights_and_postprocess(
+                model, self._get_all_weights(model_config, model), target_device
+            )
 
         return model.eval()
 
```
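The `_load_modelopt_base_model` helper above scales every GPU budget reported by accelerate's `get_max_memory()` by `DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION` when the inferred device map spills to CPU, reserving 20% headroom for calibration. A small worked sketch of just that scaling rule, standalone and with made-up byte budgets:

```python
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = 0.8

def scale_gpu_budgets(max_memory: dict) -> dict:
    # Integer keys are GPU ids; string keys like "cpu" are left untouched,
    # mirroring the isinstance(device, int) check in the diff above.
    return {
        dev: (cap * DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
              if isinstance(dev, int) else cap)
        for dev, cap in max_memory.items()
    }

budgets = {0: 80_000_000_000, 1: 80_000_000_000, "cpu": 500_000_000_000}
print(scale_gpu_budgets(budgets))  # GPUs drop to 64 GB each; cpu unchanged
```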
```diff
@@ -1366,6 +1479,105 @@ class GGUFModelLoader(BaseModelLoader):
         return model
 
 
+class RemoteInstanceModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote sglang instance."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        if load_config.model_loader_extra_config:
+            raise ValueError(
+                f"Model loader extra config is not supported for "
+                f"load format {load_config.load_format}"
+            )
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        raise NotImplementedError
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote instance ...")
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+
+        with create_remote_connector(model_weights, device_config.device) as client:
+            connector_type = get_connector_type(client)
+            if connector_type == ConnectorType.INSTANCE:
+                self.load_model_from_remote_instance(
+                    model, client, model_config, device_config
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported connector type {connector_type} for "
+                    f"remote tensor model loading."
+                )
+        return model.eval()
+
+    def load_model_from_remote_instance(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+        load_config = self.load_config
+        instance_ip = socket.gethostbyname(socket.gethostname())
+        start_build_group_tic = time.time()
+        client.build_group(
+            gpu_id=device_config.gpu_id,
+            tp_rank=load_config.tp_rank,
+            instance_ip=instance_ip,
+        )
+        torch.cuda.synchronize()
+        end_build_group_tic = time.time()
+        logger.debug(
+            f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
+        )
+
+        if load_config.tp_rank == 0:
+            t = threading.Thread(
+                target=trigger_transferring_weights_request,
+                args=(
+                    load_config.remote_instance_weight_loader_seed_instance_ip,
+                    load_config.remote_instance_weight_loader_seed_instance_service_port,
+                    load_config.remote_instance_weight_loader_send_weights_group_ports,
+                    instance_ip,
+                ),
+            )
+            t.start()
+
+        start_get_weights_tic = time.time()
+        with set_default_torch_dtype(model_config.dtype):
+            for _, tensor in model.named_parameters():
+                torch.distributed.broadcast(
+                    tensor.data,
+                    src=0,
+                    group=client._model_update_group,
+                )
+            torch.cuda.synchronize()
+
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+        end_get_weights_tic = time.time()
+        logger.debug(
+            f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
+        )
+        # destroy the process group after loading weights
+        torch.distributed.distributed_c10d.destroy_process_group(
+            client._model_update_group
+        )
+        torch.cuda.empty_cache()
+
+
 class RemoteModelLoader(BaseModelLoader):
     """Model loader that can load Tensors from remote database."""
 
```
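In `load_model_from_remote_instance` above, the destination instance receives every parameter via `torch.distributed.broadcast` from the seed instance (rank 0 of a world-size-2 group), overwriting tensors in place. A minimal receive-side sketch, under the assumption that the process group has already been initialized and `group` is that send-weights group:

```python
import torch
import torch.distributed as dist

def receive_weights_from_seed(model: torch.nn.Module, group) -> None:
    # One broadcast per tensor; the iteration order of named_parameters()
    # must match on both ranks or weights land in the wrong tensors.
    for _, param in model.named_parameters():
        dist.broadcast(param.data, src=0, group=group)
    # Ensure all device copies have landed before the model is used.
    torch.cuda.synchronize()
```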
```diff
@@ -1543,9 +1755,103 @@ def load_model_with_cpu_quantization(
     return model.eval()
 
 
-def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
+class ModelOptModelLoader(DefaultModelLoader):
+    """
+    Model loader that applies NVIDIA Model Optimizer quantization
+    """
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # Any ModelOpt specific initialization if needed
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+
+        logger.info("ModelOptModelLoader: Loading base model...")
+
+        # Use shared method from parent class to load base model
+        model = self._load_modelopt_base_model(model_config)
+
+        # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
+        try:
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.utils.dataset_utils import create_forward_loop
+        except ImportError:
+            logger.error(
+                "NVIDIA Model Optimizer (modelopt) library not found. "
+                "Please install it to use 'modelopt_quant' feature."
+            )
+            raise
+
+        quant_choice_str = model_config.modelopt_quant
+
+        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
+        if not quant_cfg_name:
+            raise ValueError(
+                f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
+                f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
+                "Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
+                "attribute names of config objects in modelopt.torch.quantization."
+            )
+
+        try:
+            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
+            quant_cfg = getattr(mtq, quant_cfg_name)
+        except AttributeError:
+            raise AttributeError(
+                f"ModelOpt quantization config attribute '{quant_cfg_name}' "
+                f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
+                "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
+            )
+
+        # For now, assume no calibration. Calibration setup is a separate, more complex step.
+        use_calibration = False  # This would ideally be a configurable parameter
+        calib_dataloader = None  # This would need to be provided/configured
+
+        calibrate_loop = (
+            create_forward_loop(dataloader=calib_dataloader)
+            if use_calibration
+            else None
+        )
+
+        if use_calibration and calib_dataloader is None:
+            logger.warning(
+                "ModelOpt calibration requested but no calib_dataloader provided. "
+                "Proceeding without calibration. Quantization accuracy may be affected."
+            )
+
+        logger.info(
+            f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
+        )
+
+        try:
+            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+            logger.info("Model successfully quantized with ModelOpt.")
+        except Exception as e:
+            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
+            raise
+        mtq.print_quant_summary(model)
+
+        return model.eval()
+
+
+def get_model_loader(
+    load_config: LoadConfig, model_config: Optional[ModelConfig] = None
+) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
+    if (
+        model_config
+        and hasattr(model_config, "modelopt_quant")
+        and model_config.modelopt_quant
+    ):
+        logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
+        return ModelOptModelLoader(load_config)
+
     if isinstance(load_config.load_format, type):
         return load_config.load_format(load_config)
@@ -1567,4 +1873,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.REMOTE:
         return RemoteModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
+        return RemoteInstanceModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
```
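`get_model_loader` now takes an optional `model_config` and checks `modelopt_quant` before falling back to load-format dispatch. A compact, runnable sketch of that selection order, using illustrative stand-in config classes rather than sglang's:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class LoadCfg:            # stand-in for LoadConfig
    load_format: str = "auto"

@dataclass
class ModelCfg:           # stand-in for ModelConfig
    modelopt_quant: Optional[str] = None

def pick_loader(load_cfg: LoadCfg, model_cfg: Optional[ModelCfg] = None) -> str:
    # modelopt_quant wins over every load format, as in the diff above.
    if model_cfg is not None and model_cfg.modelopt_quant:
        return "ModelOptModelLoader"
    if load_cfg.load_format == "remote":
        return "RemoteModelLoader"
    if load_cfg.load_format == "remote_instance":
        return "RemoteInstanceModelLoader"
    return "DefaultModelLoader"

assert pick_loader(LoadCfg(), ModelCfg("fp8")) == "ModelOptModelLoader"
assert pick_loader(LoadCfg("remote_instance")) == "RemoteInstanceModelLoader"
```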
sglang/srt/model_loader/remote_instance_weight_loader_utils.py (new file)

```diff
@@ -0,0 +1,69 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import List
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+def trigger_init_weights_send_group_for_remote_instance_request(
+    remote_instance_weight_loader_seed_instance_ip: str,
+    remote_instance_weight_loader_seed_instance_service_port: int,
+    remote_instance_weight_loader_send_weights_group_ports: List[int],
+    remote_instance_weight_loader_client_id: str,
+):
+    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
+    # Only support loading weights from instance with same parallelism strategy.
+    # Per TP rank pair between seed and dst instances will build a communication group for sending weights.
+    # i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
+    # Each communication group will have a world size 2.
+    try:
+        requests.post(
+            f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
+            json={
+                "master_address": remote_instance_weight_loader_seed_instance_ip,
+                "ports": (
+                    ",".join(
+                        str(p)
+                        for p in remote_instance_weight_loader_send_weights_group_ports
+                    )
+                ),
+                "group_rank": 0,
+                "world_size": 2,
+                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
+                "backend": "nccl",
+            },
+        )
+    except Exception as e:
+        logger.error(
+            f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
+        )
+        raise
+
+
+def trigger_transferring_weights_request(
+    remote_instance_weight_loader_seed_instance_ip: str,
+    remote_instance_weight_loader_seed_instance_service_port: int,
+    remote_instance_weight_loader_send_weights_group_ports: List[int],
+    remote_instance_weight_loader_client_id: str,
+):
+    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
+    try:
+        requests.post(
+            f"{seed_instance_service_url}/send_weights_to_remote_instance",
+            json={
+                "master_address": remote_instance_weight_loader_seed_instance_ip,
+                "ports": (
+                    ",".join(
+                        str(p)
+                        for p in remote_instance_weight_loader_send_weights_group_ports
+                    )
+                ),
+                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
+            },
+        )
+    except Exception as e:
+        logger.error(f"Failed to trigger send weights to remote instance request: {e}")
+        raise
```
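A hypothetical driver for these helpers, from the destination instance's point of view: first ask the seed instance to open the rank-0 ends of the per-TP-rank send-weights groups, then request the transfer that feeds the broadcast loop in `RemoteInstanceModelLoader`. All addresses and ports below are placeholders, and both instances must be running for the HTTP calls to succeed:

```python
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
    trigger_init_weights_send_group_for_remote_instance_request,
    trigger_transferring_weights_request,
)

SEED_IP = "10.0.0.1"          # seed instance that already holds the weights
SEED_SERVICE_PORT = 30000     # its HTTP service port (placeholder)
GROUP_PORTS = [29500, 29501]  # one send-weights group port per TP rank
CLIENT_ID = "10.0.0.2"        # identifies this destination instance

# 1) Ask the seed to join rank 0 of each world-size-2 NCCL group.
trigger_init_weights_send_group_for_remote_instance_request(
    SEED_IP, SEED_SERVICE_PORT, GROUP_PORTS, CLIENT_ID
)
# 2) Once the destination ranks have joined, request the transfer; each
#    parameter is then torch.distributed.broadcast from the seed side.
trigger_transferring_weights_request(
    SEED_IP, SEED_SERVICE_PORT, GROUP_PORTS, CLIENT_ID
)
```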