sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py

@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -18,6 +19,8 @@ if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
     from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 
+_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
+
 
 # TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
@@ -31,6 +34,9 @@ def grouped_gemm_nt_f8f8bf16_masked(
     _, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
@@ -53,6 +59,9 @@ def grouped_gemm_nt_f8f8bf16_contig(
     num_groups, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
 
@@ -67,6 +76,9 @@ def gemm_nt_f8f8bf16(
     num_groups = 1
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.fp8_gemm_nt(
             lhs,
@@ -90,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
         yield
     finally:
         deep_gemm.set_num_sms(original_num_sms)
+
+
+def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
+    if not _SANITY_CHECK:
+        return
+
+    x, x_scale = x_fp8
+
+    if x_scale.dtype == torch.int:
+        return
+
+    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
+
+    x_scale_ceil = ceil_to_ue8m0(x_scale)
+    assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
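The new `_sanity_check_input` helper, gated by the `SGLANG_DEEPGEMM_SANITY_CHECK` environment variable, asserts that every FP8 scale factor handed to a DeepGEMM kernel is already a power of two. Below is a minimal standalone sketch of that invariant; the `ceil_to_ue8m0` here is a local stand-in that rounds each scale up to the nearest power of two (the exponent-only UE8M0 format), not the sglang implementation.

```python
import torch


def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    # Round every scale up to the nearest power of two (UE8M0 keeps only an 8-bit exponent).
    return torch.exp2(torch.ceil(torch.log2(x.abs())))


def sanity_check_scales(x_scale: torch.Tensor) -> None:
    # A scale tensor is UE8M0-compatible iff rounding up to a power of two is a no-op.
    assert torch.all(x_scale == ceil_to_ue8m0(x_scale)), "scales are not pure powers of two"


# Powers of two pass; arbitrary floats trip the assertion.
sanity_check_scales(torch.tensor([0.5, 1.0, 2.0, 4.0]))
try:
    sanity_check_scales(torch.tensor([0.3, 1.5]))
except AssertionError as e:
    print("caught:", e)
```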
sglang/srt/layers/quantization/fp8.py

@@ -30,6 +30,9 @@ except ImportError:
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.token_dispatcher.base import DispatchOutputChecker
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -81,7 +84,11 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        DispatchOutput,
+        StandardDispatchOutput,
+    )
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 
@@ -345,11 +352,14 @@ class Fp8LinearMethod(LinearMethodBase):
                     _is_cpu_amx_available
                 ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
                 _amx_process_weight_after_loading(layer, ["weight"])
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    layer.weight_scale_inv.data, requires_grad=False
+                )
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-                layer.weight =
-                layer.weight_scale_inv =
+                layer.weight.data = weight.data
+                layer.weight_scale_inv.data = weight_scale.data
         else:
             layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
@@ -527,7 +537,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -543,18 +553,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            )
            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
-           if
+           if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
-                   f"{
+                   f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1:
                # Required by row parallel
-               if
+               if intermediate_size_per_partition % block_k != 0:
                    raise ValueError(
                        f"The input_size of down's weight = "
-                       f"{
+                       f"{intermediate_size_per_partition} is not divisible by "
                        f"weight quantization block_k = {block_k}."
                    )
 
@@ -564,7 +574,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            w13_weight = torch.nn.Parameter(
                torch.empty(
                    num_experts,
-                   2 *
+                   2 * intermediate_size_per_partition,
                    hidden_size // 8,
                    dtype=params_dtype,
                ),
@@ -572,20 +582,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            )
            w2_weight = torch.nn.Parameter(
                torch.empty(
-                   num_experts,
+                   num_experts,
+                   hidden_size,
+                   intermediate_size_per_partition // 8,
+                   dtype=params_dtype,
                ),
                requires_grad=False,
            )
        else:
            w13_weight = torch.nn.Parameter(
                torch.empty(
-                   num_experts,
+                   num_experts,
+                   2 * intermediate_size_per_partition,
+                   hidden_size,
+                   dtype=params_dtype,
                ),
                requires_grad=False,
            )
            w2_weight = torch.nn.Parameter(
                torch.empty(
-                   num_experts,
+                   num_experts,
+                   hidden_size,
+                   intermediate_size_per_partition,
+                   dtype=params_dtype,
                ),
                requires_grad=False,
            )
@@ -601,7 +620,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
-                   2 * ((
+                   2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
@@ -611,7 +630,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
-                   (
+                   (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
@@ -619,11 +638,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
-           if
-               get_bool_env_var("SGLANG_CUTLASS_MOE")
-               and self.cutlass_fp8_supported
-               and (is_sm100_supported() or is_sm90_supported())
-           ):
+           if self.use_cutlass_fused_experts_fp8:
                self.ab_strides1 = torch.full(
                    (num_experts,),
                    hidden_size,
@@ -632,13 +647,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                )
                self.c_strides1 = torch.full(
                    (num_experts,),
-                   2 *
+                   2 * intermediate_size_per_partition,
                    device=w13_weight.device,
                    dtype=torch.int64,
                )
                self.ab_strides2 = torch.full(
                    (num_experts,),
-
+                   intermediate_size_per_partition,
                    device=w2_weight.device,
                    dtype=torch.int64,
                )
@@ -691,7 +706,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
            # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
            w13_weight_scale1 = torch.nn.Parameter(
-               torch.ones(
+               torch.ones(
+                   num_experts,
+                   2 * intermediate_size_per_partition,
+                   dtype=torch.float32,
+               ),
                requires_grad=False,
            )
            w2_weight_scale1 = torch.nn.Parameter(
@@ -984,14 +1003,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        )
        torch.cuda.empty_cache()
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+        moe_runner_config = self.moe_runner_config
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
@@ -1001,7 +1029,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
 
-
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1017,6 +1045,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 None,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
 
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
@@ -1027,7 +1056,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 moe_runner_config.no_combine,
             )
             if ret is not None:
-                return ret
+                return StandardCombineInput(hidden_states=ret)
 
         if self.use_cutlass_fused_experts_fp8:
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
@@ -1056,17 +1085,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
-
-
-
-
-
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+            return StandardCombineInput(hidden_states=output)
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-
+            w13_scale=(
                 layer.w13_weight_scale_inv
                 if self.block_quant
                 else layer.w13_weight_scale
@@ -1074,20 +1099,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             w2_scale=(
                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
             ),
-
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
-
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-
-
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        activation = self.moe_runner_config.activation
+        routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
 
         from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
 
@@ -1108,10 +1135,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             and topk_config.topk_group is not None
         ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"
 
-
-
-
-        correction_bias
+        correction_bias = (
+            None
+            if topk_config.correction_bias is None
+            else topk_config.correction_bias.to(x.dtype)
+        )
+
         return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=correction_bias,
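As the hunks above show, the updated `Fp8MoEMethod.apply` takes a `DispatchOutput` and returns a `CombineInput`, with the backend runner created once in `create_moe_runner` instead of kernel arguments being threaded through every call. The sketch below is a simplified, self-contained illustration of that control flow only; `DispatchOutput`, `CombineInput`, and `ToyRunner` here are hypothetical stand-ins, not the sglang classes.

```python
from dataclasses import dataclass
import torch


@dataclass
class DispatchOutput:            # stand-in for the dispatcher's output bundle
    hidden_states: torch.Tensor
    topk_output: tuple           # (topk_weights, topk_ids, router_logits)


@dataclass
class CombineInput:              # stand-in for what apply() now hands back for combining
    hidden_states: torch.Tensor


class ToyRunner:
    """Stand-in for a MoeRunner: owns the backend-specific fused-MoE call."""

    def run(self, dispatch_output: DispatchOutput, quant_info: dict) -> CombineInput:
        # A real runner would launch the Triton fused-MoE kernel using quant_info.
        return CombineInput(hidden_states=dispatch_output.hidden_states)


class ToyQuantMethod:
    def create_moe_runner(self, moe_runner_config: dict) -> None:
        # Runner and config are bound once, not passed on every apply() call.
        self.moe_runner_config = moe_runner_config
        self.runner = ToyRunner()

    def apply(self, dispatch_output: DispatchOutput) -> CombineInput:
        quant_info = {"use_fp8_w8a8": True}   # weights/scales would be packed here
        return self.runner.run(dispatch_output, quant_info)


method = ToyQuantMethod()
method.create_moe_runner({"activation": "silu"})
out = method.apply(DispatchOutput(torch.randn(4, 8), topk_output=(None, None, None)))
print(out.hidden_states.shape)
```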
sglang/srt/layers/quantization/fp8_utils.py

@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 
+from sglang.srt import offloader
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
@@ -45,7 +46,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     import aiter
-    from aiter import gemm_a8w8_blockscale, get_hip_quant
+    from aiter import gemm_a8w8_blockscale, gemm_a8w8_bpreshuffle, get_hip_quant
 
     aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
@@ -248,11 +249,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
 
-    # NOTE(alcanderian): Useless when scale is packed to int32
-    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-    #     _check_ue8m0("x_scale", x_scale)
-    #     _check_ue8m0("weight_scale", ws)
-
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -261,11 +257,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
-def _check_ue8m0(name, x):
-    x_ceil = ceil_to_ue8m0(x)
-    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
-
-
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -427,10 +418,14 @@ def block_quant_dequant(
 def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
     assert isinstance(weight, torch.nn.Parameter)
     assert isinstance(weight_scale_inv, torch.nn.Parameter)
-
-
+
+    new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
+        weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
     )
 
+    offloader.update_param(weight, new_weight)
+    weight_scale_inv.data = new_weight_scale_inv
+
 
 def _requant_weight_ue8m0(
     weight: torch.Tensor,
@@ -652,25 +647,49 @@ def apply_fp8_linear(
             use_per_token_if_dynamic
             and not per_tensor_weights
             and not per_tensor_activations
-            and USE_ROWWISE_TORCH_SCALED_MM
+            and (USE_ROWWISE_TORCH_SCALED_MM or _use_aiter)
         ):
-            #
-            #
-            #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # into this sector means use dynamic per-token-per-channel quant
+            # per-token scale quant for input matrix, every row(one token) have one scale factor
+            # per-channel scale quant for weight matrix, every col(one channel) have one scale factor
+            if _use_aiter:
+                # gemm_a8w8_bpreshuffle(XQ, WQ, x_scale, w_scale, dtype)
+                # XQ -> input tensor, shape = (m, k)
+                # WQ -> weight tensor, shape = (n, k), with preshuffe get better perf
+                # x_scale -> input scale tensor, shape = (m, 1)
+                # w_scale -> weight scale tensor, shape = (n ,1)
+                # dtype -> output dtype
+                output = gemm_a8w8_bpreshuffle(
+                    XQ=qinput,
+                    WQ=weight,
+                    x_scale=x_scale,
+                    w_scale=weight_scale,
+                    dtype=input.dtype,
+                )
+                if bias is not None:
+                    output += bias
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, [*input.shape[:-1], weight.shape[0]]
+                )
+            else:
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+                return _process_scaled_mm_output(
+                    output, input_2d.shape, output_shape
+                )
         else:
             # Fallback for channelwise case, where we use unfused DQ
             # due to limitations with scaled_mm
@@ -713,7 +732,7 @@ def apply_fp8_linear(
     # final solution should be: 1. add support to per-tensor activation scaling.
     # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
     if _is_hip and weight_scale.numel() == 1:
-        qinput, x_scale =
+        qinput, x_scale = scaled_fp8_quant(
             input_2d,
             input_scale,
             use_per_token_if_dynamic=use_per_token_if_dynamic,
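The new branch in `apply_fp8_linear` routes dynamic per-token (activation) by per-channel (weight) scaling either to aiter's `gemm_a8w8_bpreshuffle` or to rowwise `torch._scaled_mm`. As a reference for what both fused kernels compute, here is a plain-PyTorch dequantize-then-matmul sketch under the shapes stated in the diff comments (x_scale is (m, 1), w_scale is (n, 1)); this is illustrative only, not the fused kernel, and the tensors are ordinary floats standing in for FP8 payloads.

```python
import torch


def rowwise_fp8_gemm_reference(qx, qw, x_scale, w_scale, bias=None, out_dtype=torch.bfloat16):
    """Reference for per-token x per-channel scaled GEMM.

    qx:      (m, k) quantized activations, one scale per row (token)
    qw:      (n, k) quantized weights, one scale per row (output channel)
    x_scale: (m, 1), w_scale: (n, 1)
    """
    # Dequantize each operand with its own scale, then do a normal matmul.
    out = (qx.float() * x_scale) @ (qw.float() * w_scale).t()
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)


m, k, n = 4, 16, 8
qx, qw = torch.randn(m, k), torch.randn(n, k)              # stand-ins for FP8 payloads
x_scale, w_scale = torch.rand(m, 1) + 0.5, torch.rand(n, 1) + 0.5
print(rowwise_fp8_gemm_reference(qx, qw, x_scale, w_scale).shape)  # torch.Size([4, 8])
```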
sglang/srt/layers/quantization/gptq.py

@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda
 
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
 
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-
-
-
-
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, router_logits = topk_output
 
-
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)