sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
|
|
34
34
|
|
35
35
|
if TYPE_CHECKING:
|
36
36
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
37
|
-
from sglang.srt.layers.moe.
|
37
|
+
from sglang.srt.layers.moe.token_dispatcher import (
|
38
|
+
StandardDispatchOutput,
|
39
|
+
CombineInput,
|
40
|
+
)
|
38
41
|
|
39
42
|
from sglang.srt.utils import is_cuda, is_hip
|
40
43
|
|
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|
736
739
|
)
|
737
740
|
replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
|
738
741
|
|
742
|
+
def create_moe_runner(
|
743
|
+
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
744
|
+
):
|
745
|
+
self.moe_runner_config = moe_runner_config
|
746
|
+
|
739
747
|
def apply(
|
740
748
|
self,
|
741
749
|
layer: torch.nn.Module,
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
750
|
+
dispatch_output: StandardDispatchOutput,
|
751
|
+
) -> CombineInput:
|
752
|
+
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
753
|
+
|
746
754
|
assert (
|
747
|
-
moe_runner_config.activation == "silu"
|
755
|
+
self.moe_runner_config.activation == "silu"
|
748
756
|
), "Only SiLU activation is supported."
|
749
757
|
|
750
758
|
# The input must currently be float16
|
759
|
+
x = dispatch_output.hidden_states
|
760
|
+
topk_output = dispatch_output.topk_output
|
761
|
+
|
751
762
|
orig_dtype = x.dtype
|
752
763
|
x = x.half()
|
753
764
|
|
754
765
|
topk_weights, topk_ids, router_logits = topk_output
|
755
766
|
|
756
|
-
|
767
|
+
output = fused_marlin_moe(
|
757
768
|
x,
|
758
769
|
layer.w13_qweight,
|
759
770
|
layer.w2_qweight,
|
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|
768
779
|
w2_zeros=layer.w2_qzeros,
|
769
780
|
num_bits=self.quant_config.weight_bits,
|
770
781
|
).to(orig_dtype)
|
782
|
+
return StandardCombineInput(hidden_states=output)
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
|
4
4
|
import inspect
|
5
5
|
from abc import ABC, abstractmethod
|
6
|
+
from dataclasses import dataclass
|
6
7
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
7
8
|
|
8
9
|
import torch
|
@@ -10,7 +11,7 @@ from torch import nn
|
|
10
11
|
|
11
12
|
if TYPE_CHECKING:
|
12
13
|
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
13
|
-
from sglang.srt.layers.moe.
|
14
|
+
from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
|
14
15
|
|
15
16
|
|
16
17
|
class QuantizeMethodBase(ABC):
|
@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|
89
90
|
layer: torch.nn.Module,
|
90
91
|
num_experts: int,
|
91
92
|
hidden_size: int,
|
92
|
-
|
93
|
+
intermediate_size_per_partition: int,
|
93
94
|
params_dtype: torch.dtype,
|
94
95
|
**extra_weight_attrs,
|
95
96
|
):
|
96
97
|
raise NotImplementedError
|
97
98
|
|
99
|
+
@abstractmethod
|
100
|
+
def create_moe_runner(
|
101
|
+
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
102
|
+
):
|
103
|
+
raise NotImplementedError
|
104
|
+
|
98
105
|
@abstractmethod
|
99
106
|
def apply(
|
100
107
|
self,
|
101
108
|
layer: torch.nn.Module,
|
102
|
-
|
103
|
-
|
104
|
-
moe_runner_config: MoeRunnerConfig,
|
105
|
-
) -> torch.Tensor:
|
109
|
+
dispatch_output: DispatchOutput,
|
110
|
+
) -> CombineInput:
|
106
111
|
raise NotImplementedError
|
107
112
|
|
108
113
|
|
@@ -9,6 +9,8 @@ import torch
|
|
9
9
|
from torch.nn import Module
|
10
10
|
|
11
11
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
12
|
+
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
13
|
+
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
12
14
|
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
|
13
15
|
from sglang.srt.layers.quantization.base_config import (
|
14
16
|
FusedMoEMethodBase,
|
@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
|
|
22
24
|
from sglang.srt.utils import set_weight_attrs
|
23
25
|
|
24
26
|
if TYPE_CHECKING:
|
25
|
-
from sglang.srt.layers.moe.
|
26
|
-
|
27
|
+
from sglang.srt.layers.moe.token_dispatcher import (
|
28
|
+
CombineInput,
|
29
|
+
StandardDispatchOutput,
|
30
|
+
)
|
27
31
|
|
28
32
|
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
29
33
|
|
@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
257
261
|
layer: Module,
|
258
262
|
num_experts: int,
|
259
263
|
hidden_size: int,
|
260
|
-
|
264
|
+
intermediate_size_per_partition: int,
|
261
265
|
params_dtype: torch.dtype,
|
262
266
|
**extra_weight_attrs,
|
263
267
|
):
|
@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
273
277
|
)
|
274
278
|
# NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
275
279
|
# Required by column parallel or enabling merged weights
|
276
|
-
if
|
280
|
+
if intermediate_size_per_partition % block_n != 0:
|
277
281
|
raise ValueError(
|
278
282
|
f"The output_size of gate's and up's weight = "
|
279
|
-
f"{
|
283
|
+
f"{intermediate_size_per_partition} is not divisible by "
|
280
284
|
f"weight quantization block_n = {block_n}."
|
281
285
|
)
|
282
286
|
if tp_size > 1:
|
283
287
|
# Required by row parallel
|
284
|
-
if
|
288
|
+
if intermediate_size_per_partition % block_k != 0:
|
285
289
|
raise ValueError(
|
286
290
|
f"The input_size of down's weight = "
|
287
|
-
f"{
|
291
|
+
f"{intermediate_size_per_partition} is not divisible by "
|
288
292
|
f"weight quantization block_k = {block_k}."
|
289
293
|
)
|
290
294
|
|
291
295
|
# WEIGHTS
|
292
296
|
w13_weight = torch.nn.Parameter(
|
293
297
|
torch.empty(
|
294
|
-
num_experts,
|
298
|
+
num_experts,
|
299
|
+
2 * intermediate_size_per_partition,
|
300
|
+
hidden_size,
|
301
|
+
dtype=params_dtype,
|
295
302
|
),
|
296
303
|
requires_grad=False,
|
297
304
|
)
|
@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
300
307
|
|
301
308
|
w2_weight = torch.nn.Parameter(
|
302
309
|
torch.empty(
|
303
|
-
num_experts,
|
310
|
+
num_experts,
|
311
|
+
hidden_size,
|
312
|
+
intermediate_size_per_partition,
|
313
|
+
dtype=params_dtype,
|
304
314
|
),
|
305
315
|
requires_grad=False,
|
306
316
|
)
|
@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
311
321
|
w13_weight_scale = torch.nn.Parameter(
|
312
322
|
torch.ones(
|
313
323
|
num_experts,
|
314
|
-
2 * ((
|
324
|
+
2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
|
315
325
|
(hidden_size + block_k - 1) // block_k,
|
316
326
|
dtype=torch.float32,
|
317
327
|
),
|
@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
321
331
|
torch.ones(
|
322
332
|
num_experts,
|
323
333
|
(hidden_size + block_n - 1) // block_n,
|
324
|
-
(
|
334
|
+
(intermediate_size_per_partition + block_k - 1) // block_k,
|
325
335
|
dtype=torch.float32,
|
326
336
|
),
|
327
337
|
requires_grad=False,
|
@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
|
344
354
|
# Block quant doesn't need to process weights after loading
|
345
355
|
return
|
346
356
|
|
357
|
+
def create_moe_runner(
|
358
|
+
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
359
|
+
):
|
360
|
+
self.moe_runner_config = moe_runner_config
|
361
|
+
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
362
|
+
|
347
363
|
def apply(
|
348
364
|
self,
|
349
365
|
layer: torch.nn.Module,
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
# Expert fusion with INT8 quantization
|
357
|
-
return fused_experts(
|
358
|
-
x,
|
359
|
-
layer.w13_weight,
|
360
|
-
layer.w2_weight,
|
361
|
-
topk_output=topk_output,
|
362
|
-
moe_runner_config=moe_runner_config,
|
366
|
+
dispatch_output: StandardDispatchOutput,
|
367
|
+
) -> CombineInput:
|
368
|
+
|
369
|
+
quant_info = TritonMoeQuantInfo(
|
370
|
+
w13_weight=layer.w13_weight,
|
371
|
+
w2_weight=layer.w2_weight,
|
363
372
|
use_int8_w8a8=True,
|
364
|
-
|
365
|
-
w2_scale=
|
366
|
-
|
373
|
+
w13_scale=layer.w13_weight_scale_inv,
|
374
|
+
w2_scale=layer.w2_weight_scale_inv,
|
375
|
+
a13_scale=layer.w13_input_scale,
|
367
376
|
a2_scale=layer.w2_input_scale,
|
368
377
|
block_shape=self.quant_config.weight_block_size,
|
369
378
|
)
|
379
|
+
|
380
|
+
return self.runner.run(dispatch_output, quant_info)
|
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
|
|
30
30
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
31
31
|
CompressedTensorsScheme,
|
32
32
|
CompressedTensorsW8A8Fp8,
|
33
|
+
CompressedTensorsW8A8Int8,
|
33
34
|
CompressedTensorsW8A16Fp8,
|
34
35
|
)
|
35
36
|
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
@@ -11,6 +11,8 @@ import torch
|
|
11
11
|
from compressed_tensors import CompressionFormat
|
12
12
|
from compressed_tensors.quantization import QuantizationStrategy
|
13
13
|
|
14
|
+
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
15
|
+
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
14
16
|
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
15
17
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
16
18
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
@@ -30,8 +32,10 @@ from sglang.srt.utils import (
|
|
30
32
|
|
31
33
|
if TYPE_CHECKING:
|
32
34
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
33
|
-
from sglang.srt.layers.moe.
|
34
|
-
|
35
|
+
from sglang.srt.layers.moe.token_dispatcher import (
|
36
|
+
CombineInput,
|
37
|
+
StandardDispatchOutput,
|
38
|
+
)
|
35
39
|
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
36
40
|
CompressedTensorsConfig,
|
37
41
|
)
|
@@ -293,14 +297,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|
293
297
|
)
|
294
298
|
torch.cuda.empty_cache()
|
295
299
|
|
300
|
+
def create_moe_runner(
|
301
|
+
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
302
|
+
):
|
303
|
+
self.moe_runner_config = moe_runner_config
|
304
|
+
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
305
|
+
|
296
306
|
def apply(
|
297
307
|
self,
|
298
308
|
layer: torch.nn.Module,
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
309
|
+
dispatch_output: StandardDispatchOutput,
|
310
|
+
) -> CombineInput:
|
311
|
+
|
312
|
+
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
313
|
+
|
314
|
+
x = dispatch_output.hidden_states
|
315
|
+
topk_output = dispatch_output.topk_output
|
316
|
+
|
317
|
+
moe_runner_config = self.moe_runner_config
|
304
318
|
|
305
319
|
if (
|
306
320
|
_use_aiter
|
@@ -308,7 +322,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|
308
322
|
and moe_runner_config.apply_router_weight_on_input
|
309
323
|
):
|
310
324
|
topk_weights, topk_ids, _ = topk_output
|
311
|
-
|
325
|
+
output = rocm_fused_experts_tkw1(
|
312
326
|
hidden_states=x,
|
313
327
|
w1=layer.w13_weight,
|
314
328
|
w2=layer.w2_weight,
|
@@ -324,21 +338,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|
324
338
|
a1_scale=layer.w13_input_scale,
|
325
339
|
a2_scale=layer.w2_input_scale,
|
326
340
|
)
|
341
|
+
return StandardCombineInput(hidden_states=output)
|
327
342
|
else:
|
328
|
-
|
329
|
-
|
330
|
-
layer.
|
331
|
-
layer.w2_weight,
|
332
|
-
topk_output=topk_output,
|
333
|
-
moe_runner_config=moe_runner_config,
|
343
|
+
quant_info = TritonMoeQuantInfo(
|
344
|
+
w13_weight=layer.w13_weight,
|
345
|
+
w2_weight=layer.w2_weight,
|
334
346
|
use_fp8_w8a8=True,
|
335
347
|
per_channel_quant=self.weight_quant.strategy
|
336
348
|
== QuantizationStrategy.CHANNEL,
|
337
|
-
|
349
|
+
w13_scale=layer.w13_weight_scale,
|
338
350
|
w2_scale=layer.w2_weight_scale,
|
339
|
-
|
351
|
+
a13_scale=layer.w13_input_scale,
|
340
352
|
a2_scale=layer.w2_input_scale,
|
341
353
|
)
|
354
|
+
return self.runner.run(dispatch_output, quant_info)
|
342
355
|
|
343
356
|
|
344
357
|
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
@@ -380,8 +393,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|
380
393
|
params_dtype == torch.float16
|
381
394
|
), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
|
382
395
|
|
383
|
-
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
|
384
|
-
|
385
396
|
# Will transpose the loaded weight along the
|
386
397
|
# intermediate and hidden dim sizes. Will
|
387
398
|
# shard for TP along the transposed dims
|
@@ -415,13 +426,13 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|
415
426
|
# In the case where we have actorder/g_idx,
|
416
427
|
# we do not partition the w2 scales
|
417
428
|
load_full_w2 = self.actorder and self.group_size != -1
|
418
|
-
w2_scales_size = (
|
419
|
-
intermediate_size_full if load_full_w2 else intermediate_size_per_partition
|
420
|
-
)
|
421
429
|
|
422
|
-
|
423
|
-
intermediate_size_per_partition
|
424
|
-
|
430
|
+
if load_full_w2:
|
431
|
+
w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
|
432
|
+
else:
|
433
|
+
w2_scales_size = intermediate_size_per_partition
|
434
|
+
|
435
|
+
self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1
|
425
436
|
|
426
437
|
if self.strategy == "channel":
|
427
438
|
num_groups_w2 = num_groups_w13 = 1
|
@@ -640,21 +651,29 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|
640
651
|
)
|
641
652
|
replace_tensor("w2_weight_scale", marlin_w2_scales)
|
642
653
|
|
654
|
+
def create_moe_runner(
|
655
|
+
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
656
|
+
):
|
657
|
+
self.moe_runner_config = moe_runner_config
|
658
|
+
|
643
659
|
def apply(
|
644
660
|
self,
|
645
661
|
layer: torch.nn.Module,
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
662
|
+
dispatch_output: StandardDispatchOutput,
|
663
|
+
) -> CombineInput:
|
664
|
+
|
665
|
+
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
650
666
|
|
651
667
|
assert (
|
652
|
-
moe_runner_config.activation == "silu"
|
668
|
+
self.moe_runner_config.activation == "silu"
|
653
669
|
), "Only SiLU activation is supported."
|
654
670
|
|
671
|
+
x = dispatch_output.hidden_states
|
672
|
+
topk_output = dispatch_output.topk_output
|
673
|
+
|
655
674
|
topk_weights, topk_ids, router_logits = topk_output
|
656
675
|
|
657
|
-
|
676
|
+
output = torch.ops.vllm.fused_marlin_moe(
|
658
677
|
x,
|
659
678
|
layer.w13_weight_packed,
|
660
679
|
layer.w2_weight_packed,
|
@@ -670,3 +689,4 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|
670
689
|
num_bits=self.num_bits,
|
671
690
|
is_k_full=self.is_k_full,
|
672
691
|
)
|
692
|
+
return StandardCombineInput(hidden_states=output)
|
@@ -2,10 +2,12 @@
|
|
2
2
|
|
3
3
|
from .compressed_tensors_scheme import CompressedTensorsScheme
|
4
4
|
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
5
|
+
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
5
6
|
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
6
7
|
|
7
8
|
__all__ = [
|
8
9
|
"CompressedTensorsScheme",
|
9
10
|
"CompressedTensorsW8A8Fp8",
|
10
11
|
"CompressedTensorsW8A16Fp8",
|
12
|
+
"CompressedTensorsW8A8Int8",
|
11
13
|
]
|
@@ -21,9 +21,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|
21
21
|
normalize_e4m3fn_to_e4m3fnuz,
|
22
22
|
)
|
23
23
|
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
|
24
|
+
from sglang.srt.utils import get_bool_env_var, is_hip
|
24
25
|
|
25
26
|
__all__ = ["CompressedTensorsW8A8Fp8"]
|
26
27
|
|
28
|
+
_is_hip = is_hip()
|
29
|
+
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
30
|
+
if _use_aiter:
|
31
|
+
from aiter.ops.shuffle import shuffle_weight
|
32
|
+
|
27
33
|
|
28
34
|
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
29
35
|
|
@@ -76,7 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|
76
82
|
else:
|
77
83
|
weight_scale = layer.weight_scale.data
|
78
84
|
|
79
|
-
|
85
|
+
if _use_aiter:
|
86
|
+
layer.weight = Parameter(
|
87
|
+
shuffle_weight(weight, (16, 16)), requires_grad=False
|
88
|
+
)
|
89
|
+
else:
|
90
|
+
layer.weight = Parameter(weight.t(), requires_grad=False)
|
91
|
+
|
80
92
|
# required by torch.compile to be torch.nn.Parameter
|
81
93
|
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
82
94
|
|
@@ -0,0 +1,173 @@
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
3
|
+
|
4
|
+
from typing import Callable, Optional
|
5
|
+
|
6
|
+
import torch
|
7
|
+
from compressed_tensors.quantization import QuantizationStrategy
|
8
|
+
from torch.nn import Parameter
|
9
|
+
|
10
|
+
from sglang.srt.layers.parameter import (
|
11
|
+
ChannelQuantScaleParameter,
|
12
|
+
ModelWeightParameter,
|
13
|
+
PerTensorScaleParameter,
|
14
|
+
)
|
15
|
+
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
16
|
+
CompressedTensorsScheme,
|
17
|
+
)
|
18
|
+
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
19
|
+
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
|
20
|
+
from sglang.srt.utils import is_cuda
|
21
|
+
|
22
|
+
_is_cuda = is_cuda()
|
23
|
+
if _is_cuda:
|
24
|
+
from sgl_kernel import int8_scaled_mm
|
25
|
+
|
26
|
+
|
27
|
+
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
28
|
+
|
29
|
+
def __init__(
|
30
|
+
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
31
|
+
):
|
32
|
+
self.strategy = strategy
|
33
|
+
self.is_static_input_scheme = is_static_input_scheme
|
34
|
+
self.input_symmetric = input_symmetric
|
35
|
+
|
36
|
+
@classmethod
|
37
|
+
def get_min_capability(cls) -> int:
|
38
|
+
# lovelace and up
|
39
|
+
return 89
|
40
|
+
|
41
|
+
def process_weights_after_loading(self, layer) -> None:
|
42
|
+
# If per tensor, when we have a fused module (e.g. QKV) with per
|
43
|
+
# tensor scales (thus N scales being passed to the kernel),
|
44
|
+
# requantize so we can always run per channel
|
45
|
+
if self.strategy == QuantizationStrategy.TENSOR:
|
46
|
+
max_w_scale, weight = requantize_with_max_scale(
|
47
|
+
weight=layer.weight,
|
48
|
+
weight_scale=layer.weight_scale,
|
49
|
+
logical_widths=layer.logical_widths,
|
50
|
+
)
|
51
|
+
|
52
|
+
layer.weight = Parameter(weight.t(), requires_grad=False)
|
53
|
+
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
54
|
+
|
55
|
+
# If channelwise, scales are already lined up, so just transpose.
|
56
|
+
elif self.strategy == QuantizationStrategy.CHANNEL:
|
57
|
+
weight = layer.weight
|
58
|
+
weight_scale = layer.weight_scale.data
|
59
|
+
|
60
|
+
layer.weight = Parameter(weight.t(), requires_grad=False)
|
61
|
+
# required by torch.compile to be torch.nn.Parameter
|
62
|
+
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
63
|
+
|
64
|
+
else:
|
65
|
+
raise ValueError(f"Unknown quantization strategy {self.strategy}")
|
66
|
+
|
67
|
+
# INPUT SCALE
|
68
|
+
if self.is_static_input_scheme and hasattr(layer, "input_scale"):
|
69
|
+
if self.input_symmetric:
|
70
|
+
layer.input_scale = Parameter(
|
71
|
+
layer.input_scale.max(), requires_grad=False
|
72
|
+
)
|
73
|
+
else:
|
74
|
+
input_scale = layer.input_scale
|
75
|
+
input_zero_point = layer.input_zero_point
|
76
|
+
|
77
|
+
# reconstruct the ranges
|
78
|
+
int8_traits = torch.iinfo(torch.int8)
|
79
|
+
azps = input_zero_point.to(dtype=torch.int32)
|
80
|
+
range_max = (input_scale * (int8_traits.max - azps)).max()
|
81
|
+
range_min = (input_scale * (int8_traits.min - azps)).min()
|
82
|
+
|
83
|
+
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
84
|
+
|
85
|
+
# AZP loaded as int8 but used as int32
|
86
|
+
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
87
|
+
|
88
|
+
layer.input_scale = Parameter(scale, requires_grad=False)
|
89
|
+
layer.input_zero_point = Parameter(azp, requires_grad=False)
|
90
|
+
else:
|
91
|
+
layer.input_scale = None
|
92
|
+
layer.input_zero_point = None
|
93
|
+
|
94
|
+
# azp_adj is the AZP adjustment term, used to account for weights.
|
95
|
+
# It does not depend on scales or azp, so it is the same for
|
96
|
+
# static and dynamic quantization.
|
97
|
+
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
|
98
|
+
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
|
99
|
+
if not self.input_symmetric:
|
100
|
+
weight = layer.weight
|
101
|
+
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
102
|
+
if self.is_static_input_scheme:
|
103
|
+
# cutlass_w8a8 requires azp to be folded into azp_adj
|
104
|
+
# in the per-tensor case
|
105
|
+
azp_adj = layer.input_zero_point * azp_adj
|
106
|
+
layer.azp_adj = Parameter(azp_adj, requires_grad=False)
|
107
|
+
else:
|
108
|
+
layer.azp_adj = None
|
109
|
+
|
110
|
+
def create_weights(
|
111
|
+
self,
|
112
|
+
layer: torch.nn.Module,
|
113
|
+
output_partition_sizes: list[int],
|
114
|
+
input_size_per_partition: int,
|
115
|
+
params_dtype: torch.dtype,
|
116
|
+
weight_loader: Callable,
|
117
|
+
**kwargs,
|
118
|
+
):
|
119
|
+
output_size_per_partition = sum(output_partition_sizes)
|
120
|
+
layer.logical_widths = output_partition_sizes
|
121
|
+
|
122
|
+
# WEIGHT
|
123
|
+
weight = ModelWeightParameter(
|
124
|
+
data=torch.empty(
|
125
|
+
output_size_per_partition, input_size_per_partition, dtype=torch.int8
|
126
|
+
),
|
127
|
+
input_dim=1,
|
128
|
+
output_dim=0,
|
129
|
+
weight_loader=weight_loader,
|
130
|
+
)
|
131
|
+
|
132
|
+
layer.register_parameter("weight", weight)
|
133
|
+
|
134
|
+
# WEIGHT SCALE
|
135
|
+
if self.strategy == QuantizationStrategy.CHANNEL:
|
136
|
+
weight_scale = ChannelQuantScaleParameter(
|
137
|
+
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
138
|
+
output_dim=0,
|
139
|
+
weight_loader=weight_loader,
|
140
|
+
)
|
141
|
+
else:
|
142
|
+
assert self.strategy == QuantizationStrategy.TENSOR
|
143
|
+
weight_scale = PerTensorScaleParameter(
|
144
|
+
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
145
|
+
weight_loader=weight_loader,
|
146
|
+
)
|
147
|
+
layer.register_parameter("weight_scale", weight_scale)
|
148
|
+
|
149
|
+
# INPUT SCALE
|
150
|
+
if self.is_static_input_scheme:
|
151
|
+
input_scale = PerTensorScaleParameter(
|
152
|
+
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
153
|
+
)
|
154
|
+
layer.register_parameter("input_scale", input_scale)
|
155
|
+
|
156
|
+
if not self.input_symmetric:
|
157
|
+
# Note: compressed-tensors stores the zp using the same dtype
|
158
|
+
# as the weights
|
159
|
+
# AZP loaded as int8 but used as int32
|
160
|
+
input_zero_point = PerTensorScaleParameter(
|
161
|
+
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
162
|
+
)
|
163
|
+
layer.register_parameter("input_zero_point", input_zero_point)
|
164
|
+
|
165
|
+
def apply_weights(
|
166
|
+
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
167
|
+
) -> torch.Tensor:
|
168
|
+
# TODO: add cutlass_scaled_mm_azp support
|
169
|
+
x_q, x_scale = per_token_quant_int8(x)
|
170
|
+
|
171
|
+
return int8_scaled_mm(
|
172
|
+
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
|
173
|
+
)
|
@@ -1,8 +1,6 @@
|
|
1
1
|
import logging
|
2
2
|
|
3
|
-
import
|
4
|
-
|
5
|
-
from sglang.srt.utils import get_bool_env_var, get_device_sm
|
3
|
+
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
|
6
4
|
|
7
5
|
logger = logging.getLogger(__name__)
|
8
6
|
|
@@ -15,18 +13,12 @@ def _compute_enable_deep_gemm():
|
|
15
13
|
try:
|
16
14
|
import deep_gemm
|
17
15
|
except ImportError:
|
18
|
-
logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
|
19
16
|
return False
|
20
17
|
|
21
18
|
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
|
22
19
|
|
23
20
|
|
24
|
-
def _is_blackwell_arch() -> bool:
|
25
|
-
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
|
26
|
-
return major == 10
|
27
|
-
|
28
|
-
|
29
21
|
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
|
30
22
|
|
31
|
-
DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and
|
23
|
+
DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell()
|
32
24
|
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
|