sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +378 -160
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +10 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +136 -25
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +63 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +83 -80
- sglang/srt/entrypoints/grpc_server.py +430 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +195 -102
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +58 -6
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +33 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +20 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +10 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +24 -10
- sglang/srt/layers/attention/flashinfer_backend.py +258 -22
- sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
- sglang/srt/layers/attention/utils.py +89 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +12 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +64 -19
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +152 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
- sglang/srt/layers/moe/ep_moe/layer.py +154 -625
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
- sglang/srt/layers/moe/moe_runner/runner.py +6 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
- sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +7 -6
- sglang/srt/layers/moe/utils.py +20 -5
- sglang/srt/layers/quantization/__init__.py +5 -58
- sglang/srt/layers/quantization/awq.py +183 -9
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +27 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +152 -81
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +35 -68
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +23 -48
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +87 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +62 -9
- sglang/srt/layers/rotary_embedding.py +686 -17
- sglang/srt/layers/sampler.py +47 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +69 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +420 -514
- sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +375 -95
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +11 -2
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +517 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +71 -25
- sglang/srt/model_executor/model_runner.py +362 -270
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +418 -140
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +327 -382
- sglang/srt/models/glm4_moe_nextn.py +6 -16
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +32 -199
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +22 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3.py +34 -4
- sglang/srt/models/qwen3_moe.py +19 -37
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +7 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +2 -6
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +28 -2
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +846 -163
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +36 -31
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +272 -82
- sglang/srt/utils/hf_transformers_utils.py +44 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +463 -107
- sglang/test/test_deterministic_utils.py +74 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/models/vila.py +0 -306
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/check_env.py
CHANGED
sglang/compile_deep_gemm.py
CHANGED

@@ -19,6 +19,7 @@ import requests
 
 from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
 from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.environ import envs
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import ServerArgs
@@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup
 multiprocessing.set_start_method("spawn", force=True)
 
 # Reduce warning
-
+envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
 # Force enable deep gemm
-
+envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
 # Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
 os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
 
@@ -141,6 +142,9 @@ def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
     server_args.enable_torch_compile = False
     print(f"Disable CUDA Graph and Torch Compile to save time...")
 
+    server_args.load_format = "dummy"
+    print(f"Set load format to dummy to save time...")
+
     # Set watchdog timeout to compile_args.timeout because compilation will take a long time
     server_args.watchdog_timeout = compile_args.timeout
     server_args.warmups = "compile-deep-gemm"
sglang/global_config.py
CHANGED

@@ -1,14 +1,11 @@
 """Global configurations"""
 
-import os
+# FIXME: deprecate this file and move all usage to sglang.srt.environ or sglang.__init__.py
 
 
 class GlobalConfig:
     """
     Store some global constants.
-
-    See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
-    many global runtime arguments as well.
     """
 
     def __init__(self):
@@ -20,27 +17,6 @@ class GlobalConfig:
         # Default backend of the language
         self.default_backend = None
 
-        # Runtime constants: New generation token ratio estimation
-        self.default_init_new_token_ratio = float(
-            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
-        )
-        self.default_min_new_token_ratio_factor = float(
-            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
-        )
-        self.default_new_token_ratio_decay_steps = float(
-            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
-        )
-        self.torch_empty_cache_interval = float(
-            os.environ.get(
-                "SGLANG_EMPTY_CACHE_INTERVAL", -1
-            )  # in seconds. Set if you observe high memory accumulation over a long serving period.
-        )
-        # Runtime constants: others
-        self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = int(
-            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
-        )
-
         # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True
sglang/lang/api.py
CHANGED

@@ -79,6 +79,7 @@ def gen(
     n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
+    stop_regex: Optional[Union[str, List[str]]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
@@ -120,6 +121,7 @@ def gen(
         n,
         stop,
         stop_token_ids,
+        stop_regex,
         temperature,
         top_p,
         top_k,
@@ -143,6 +145,7 @@ def gen_int(
     n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
    stop_token_ids: Optional[List[int]] = None,
+    stop_regex: Optional[Union[str, List[str]]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
@@ -162,6 +165,7 @@ def gen_int(
         n,
         stop,
         stop_token_ids,
+        stop_regex,
         temperature,
         top_p,
         top_k,
@@ -184,6 +188,7 @@ def gen_string(
     n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
+    stop_regex: Optional[Union[str, List[str]]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     top_k: Optional[int] = None,
@@ -203,6 +208,7 @@ def gen_string(
         n,
         stop,
         stop_token_ids,
+        stop_regex,
         temperature,
         top_p,
         top_k,
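Note: the new `stop_regex` argument added above is threaded from the frontend `gen()` helpers into `SglSamplingParams` (see the `sglang/lang/ir.py` diff below). A minimal usage sketch, assuming a backend is already attached; the prompt text and the regex pattern are illustrative only, not taken from the diff:

```python
import sglang as sgl


@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    # stop_regex mirrors the existing `stop` argument but accepts a regex
    # pattern (or a list of patterns); generation stops when it matches.
    s += "A: " + sgl.gen("answer", max_tokens=64, stop_regex=r"\n\n")
```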
sglang/lang/interpreter.py
CHANGED
sglang/lang/ir.py
CHANGED

@@ -21,6 +21,7 @@ class SglSamplingParams:
     n: int = 1
     stop: Union[str, List[str]] = ()
     stop_token_ids: Optional[List[int]] = ()
+    stop_regex: Optional[Union[str, List[str]]] = ()
     temperature: float = 1.0
     top_p: float = 1.0
     top_k: int = -1  # -1 means disable
@@ -45,6 +46,7 @@ class SglSamplingParams:
             self.n,
             self.stop,
             self.stop_token_ids,
+            self.stop_regex,
             self.temperature,
             self.top_p,
             self.top_k,
@@ -123,6 +125,7 @@ class SglSamplingParams:
             "n": self.n,
             "stop": self.stop,
             "stop_token_ids": self.stop_token_ids,
+            "stop_regex": self.stop_regex,
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,
@@ -161,6 +164,7 @@ class SglFunction:
         n: int = 1,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        stop_regex: Optional[Union[str, List[str]]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -184,12 +188,15 @@ class SglFunction:
             stop = []
         if stop_token_ids is None:
             stop_token_ids = []
+        if stop_regex is None:
+            stop_regex = []
 
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             n=n,
             stop=stop,
             stop_token_ids=stop_token_ids,
+            stop_regex=stop_regex,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
@@ -221,6 +228,7 @@ class SglFunction:
         n: int = 1,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        stop_regex: Optional[Union[str, List[str]]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -243,6 +251,8 @@ class SglFunction:
             stop = []
         if stop_token_ids is None:
             stop_token_ids = []
+        if stop_regex is None:
+            stop_regex = []
 
         assert isinstance(batch_kwargs, (list, tuple))
         if len(batch_kwargs) == 0:
@@ -267,6 +277,7 @@ class SglFunction:
             n=n,
             stop=stop,
             stop_token_ids=stop_token_ids,
+            stop_regex=stop_regex,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
         n: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
+        stop_regex: Optional[Union[str, List[str]]] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
             min_new_tokens=min_new_tokens,
             n=n,
             stop=stop,
+            stop_regex=stop_regex,
             stop_token_ids=stop_token_ids,
             temperature=temperature,
             top_p=top_p,
sglang/launch_server.py
CHANGED

@@ -1,30 +1,25 @@
 """Launch the inference server."""
 
+import asyncio
 import os
 import sys
 
-from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_process_tree
 
-MOVE_ENVS_WARN = """
-########################################################################
-# For contributors and developers:                                     #
-# Please move environment variable definitions to sglang.srt.environ   #
-# using the following pattern:                                         #
-#     SGLANG_XXX = EnvBool(False)                                      #
-#                                                                      #
-########################################################################
-"""
-
 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])
 
-
+    try:
+        if server_args.grpc_mode:
+            # Handle gRPC server
+            from sglang.srt.entrypoints.grpc_server import serve_grpc
 
-
+            asyncio.run(serve_grpc(server_args))
+        else:
+            # Handle HTTP server
+            from sglang.srt.entrypoints.http_server import launch_server
 
-
-        launch_server(server_args)
+            launch_server(server_args)
     finally:
         kill_process_tree(os.getpid(), include_parent=False)
sglang/profiler.py
CHANGED

@@ -25,6 +25,7 @@ def _run_profile(
     output_dir: Optional[str] = None,
     profile_name: Optional[str] = None,
     profile_by_stage: bool = False,
+    merge_profiles: bool = False,
 ) -> str:
     if output_dir is None:
         output_dir = PROFILER_DIR
@@ -60,6 +61,7 @@ def _run_profile(
         "num_steps": str(num_steps),
         "activities": activities,
         "profile_by_stage": profile_by_stage,
+        "merge_profiles": merge_profiles,
     }
 
     response = requests.post(url=url + "/start_profile", json=json_data)
@@ -76,10 +78,17 @@ def run_profile(
     output_dir: Optional[str] = None,
     profile_name: Optional[str] = None,
     profile_by_stage: bool = False,
+    merge_profiles: bool = False,
 ):
     # step based profile will self terminate on num_steps constraints
     link = _run_profile(
-        url, num_steps, activities, output_dir, profile_name, profile_by_stage
+        url,
+        num_steps,
+        activities,
+        output_dir,
+        profile_name,
+        profile_by_stage,
+        merge_profiles,
     )
     return link
 
@@ -145,6 +154,13 @@ if __name__ == "__main__":
         default=False,
         help="Whether to use rpd profiler (https://github.com/ROCm/rocmProfileData)",
     )
+    parser.add_argument(
+        "--merge-profiles",
+        action=argparse.BooleanOptionalAction,
+        type=bool,
+        default=False,
+        help="Whether to merge profiles from all ranks into a single trace file",
+    )
 
     args = parser.parse_args()
     activities = []
@@ -163,4 +179,5 @@ if __name__ == "__main__":
         args.output_dir,
         args.profile_name,
         args.profile_by_stage,
+        args.merge_profiles,
     )
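A minimal sketch of driving the new option from Python rather than the CLI. The parameter names are taken from the diff above; the server URL is an assumption (the default local endpoint), and a profiling-capable server must already be running:

```python
from sglang.profiler import run_profile

# Profile 10 steps on CPU and GPU and ask the server to merge the per-rank
# traces into one file, via the new `merge_profiles` argument that is
# forwarded to the /start_profile endpoint.
run_profile(
    url="http://localhost:30000",
    num_steps=10,
    activities=["CPU", "GPU"],
    output_dir=None,          # falls back to the default profiler directory
    profile_name=None,
    profile_by_stage=False,
    merge_profiles=True,      # new in 0.5.4; see utils/profile_merger.py
)
```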
sglang/srt/_custom_ops.py
CHANGED

@@ -15,7 +15,7 @@ if not is_hpu():
     # ROCm does not use vllm custom allreduce
     if use_vllm_custom_allreduce and not is_hip():
         try:
-            import vllm._C
+            import vllm._C  # noqa: F401
         except ImportError as e:
             logger.warning("Failed to import from vllm._C with %r", e)
     else:

sglang/srt/batch_invariant_ops/batch_invariant_ops.py
CHANGED

@@ -9,6 +9,22 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
+from sglang.srt.utils.common import calc_diff, get_bool_env_var
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+_ENABLE_MM_DEEPGEMM = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM", "1"
+)
+_ENABLE_MM_COMPARISON_TEST = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
+)
+
+if not _ENABLE_MM_DEEPGEMM:
+    print("Disable DeepGEMM in batch invariant ops. Performance may be suboptimal.")
+
 __all__ = [
     "set_batch_invariant_mode",
     "is_batch_invariant_mode_enabled",
@@ -77,8 +93,6 @@ def matmul_kernel_persistent(
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n
 
-    tile_id_c = start_pid - NUM_SMS
-
     offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
     num_pid_in_group = GROUP_SIZE_M * num_pid_n
 
@@ -120,10 +134,6 @@ def matmul_kernel_persistent(
             )
             accumulator = tl.dot(a, b, accumulator)
 
-        tile_id_c += NUM_SMS
-        pid_m, pid_n = _compute_pid(
-            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
-        )
         offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
         offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
         if C_LARGE:
@@ -137,12 +147,16 @@ def matmul_kernel_persistent(
             accumulator += bias
         if c_ptr.dtype.element_ty == tl.float8e4nv:
             c = accumulator.to(tl.float8e4nv)
+        elif c_ptr.dtype.element_ty == tl.bfloat16:
+            c = accumulator.to(tl.bfloat16)
+        elif c_ptr.dtype.element_ty == tl.float32:
+            c = accumulator.to(tl.float32)
         else:
             c = accumulator.to(tl.float16)
         tl.store(c_ptrs, c, mask=c_mask)
 
 
-def matmul_persistent(
+def _matmul_persistent_triton(
     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
 ):
     # Check constraints.
@@ -219,6 +233,54 @@ def matmul_persistent(
     return c
 
 
+def _matmul_persistent_deepgemm(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    M, K = a.shape
+    K, N = b.shape
+    dtype = a.dtype
+    out = torch.empty((M, N), device=a.device, dtype=dtype)
+
+    deep_gemm.bf16_gemm_nn(a, b, out)
+
+    # TODO can this be put in DeepGEMM's `c`?
+    if bias is not None:
+        out += bias
+
+    return out
+
+
+def matmul_persistent(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    if (
+        _ENABLE_MM_DEEPGEMM
+        and ENABLE_JIT_DEEPGEMM
+        and (a.dtype == torch.bfloat16)
+        and (b.dtype == torch.bfloat16)
+        and a.is_contiguous()
+        and b.transpose(0, 1).is_contiguous()
+    ):
+        if _ENABLE_MM_COMPARISON_TEST:
+            out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
+            out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+            diff = calc_diff(out_triton, out_deepgemm)
+            assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
+            # can be enabled for debugging
+            # print(
+            #     f"{diff=} "
+            #     f"{(out_triton - out_deepgemm).abs().mean()=} "
+            #     f"{(out_triton - out_deepgemm).abs().sum()=} "
+            #     f"{torch.sum(out_triton != out_deepgemm)=} "
+            # )
+            # print(f"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}")
+            return out_deepgemm
+
+        return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+
+    return _matmul_persistent_triton(a=a, b=b, bias=bias)
+
+
 @triton.jit
 def _log_softmax_kernel(
     input_ptr,
@@ -497,16 +559,39 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
     return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
 
 
+def bmm_batch_invariant(a, b, *, out=None):
+    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
+    # Process each batch separately with our persistent kernel
+    if a.ndim == 3 and b.ndim == 3:
+        results = []
+        for i in range(a.shape[0]):
+            results.append(matmul_persistent(a[i], b[i]))
+        result = torch.stack(results, dim=0)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    else:
+        raise ValueError(
+            f"bmm_batch_invariant expects 3D tensors, "
+            f"got shapes {a.shape} and {b.shape}"
+        )
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
+_original_torch_bmm = None
 
 
 def is_batch_invariant_mode_enabled():
    return _batch_invariant_MODE
 
 
-def enable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+def enable_batch_invariant_mode(
+    enable_bmm: bool = True,
+):
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_MODE:
         return
 
@@ -519,11 +604,21 @@ def enable_batch_invariant_mode():
     )
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
+    if enable_bmm:
+        _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
+
+        # Also monkeypatch torch.bmm directly as a fallback
+        _original_torch_bmm = torch.bmm
+        torch.bmm = bmm_batch_invariant
+
 
 def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
+    if _original_torch_bmm is not None:
+        torch.bmm = _original_torch_bmm
+        _original_torch_bmm = None
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None
 
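A minimal sketch of what the new bmm coverage adds: while the mode is enabled, aten::bmm is overridden and torch.bmm is monkeypatched, so batched matmuls go through the batch-invariant persistent kernel. It assumes a CUDA device and imports the functions directly from the module shown above:

```python
import torch

from sglang.srt.batch_invariant_ops.batch_invariant_ops import (
    disable_batch_invariant_mode,
    enable_batch_invariant_mode,
)

a = torch.randn(4, 32, 64, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4, 64, 16, device="cuda", dtype=torch.bfloat16)

enable_batch_invariant_mode()  # enable_bmm defaults to True in 0.5.4
try:
    out = torch.bmm(a, b)  # routed through bmm_batch_invariant
finally:
    disable_batch_invariant_mode()  # restores the original torch.bmm
```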
sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
ADDED

@@ -0,0 +1,142 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Checkpoint-engine integration for SGLang.
+This module provides weight update functionality via IPC for checkpoint-engine compatibility.
+"""
+import logging
+from typing import Callable, Dict, Optional
+
+import torch
+import zmq
+
+try:
+    from checkpoint_engine.worker import update_weights_from_ipc
+except ImportError:
+    raise ImportError(
+        "checkpoint-engine is not installed. "
+        "Please install it with: pip install sglang[checkpoint-engine]"
+    )
+
+logger = logging.getLogger(__name__)
+
+
+class SGLangCheckpointEngineWorkerExtension:
+    """
+    Worker extension for SGLang to support checkpoint-engine IPC weight updates.
+    This class provides the interface needed for checkpoint-engine integration.
+    """
+
+    def __init__(self):
+        self._zmq_ctx: Optional[zmq.Context] = None
+
+    def get_device_uuid(self) -> str:
+        """Get the UUID of current device."""
+        # We need to implement this to get the device UUID
+        # This will be overridden when integrated into SGLang's worker
+        raise NotImplementedError(
+            "This method should be overridden by SGLang integration"
+        )
+
+    def get_device_id(self) -> int:
+        """Get the device ID."""
+        raise NotImplementedError(
+            "This method should be overridden by SGLang integration"
+        )
+
+    def get_model_loader(self) -> Callable:
+        """Get the model weight loader function."""
+        raise NotImplementedError(
+            "This method should be overridden by SGLang integration"
+        )
+
+    def get_post_hook(self) -> Optional[Callable]:
+        """Get the post-processing hook after weight loading."""
+        return None
+
+    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
+        """
+        Update weights from IPC communication.
+        Args:
+            zmq_handles: Dict mapping device UUID to ZMQ socket path
+        """
+        if self._zmq_ctx is None:
+            self._zmq_ctx = zmq.Context()
+        device_uuid = self.get_device_uuid()
+        device_id = self.get_device_id()
+        if device_uuid not in zmq_handles:
+            raise ValueError(
+                f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
+            )
+        update_weights_from_ipc(
+            self._zmq_ctx,
+            zmq_handles[device_uuid],
+            device_id=device_id,
+            run=self.get_model_loader(),
+            post_hook=self.get_post_hook(),
+        )
+
+
+class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
+    """
+    Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.
+    This class provides the concrete implementation for checkpoint-engine IPC weight updates.
+    """
+
+    def __init__(self, model_runner):
+        super().__init__()
+        self.model_runner = model_runner
+
+    def get_device_uuid(self) -> str:
+        """Get the UUID of current device."""
+        # Get device UUID for current device
+        device_id = torch.cuda.current_device()
+        try:
+            return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
+        except AssertionError as e:
+            raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e
+
+    def get_device_id(self) -> int:
+        """Get the device ID."""
+        return torch.cuda.current_device()
+
+    def get_model_loader(self) -> Callable:
+        """Get the model weight loader function."""
+        return self.model_runner.model.load_weights
+
+    def get_post_hook(self) -> Optional[Callable]:
+        """Get the post-processing hook after weight loading."""
+
+        def post_hook():
+            # Perform post-processing after weight loading similar to DefaultModelLoader
+            try:
+                from sglang.srt.model_loader.loader import device_loading_context
+
+                # Process quantization methods after loading weights
+                for _, module in self.model_runner.model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        # Move parameters to device if needed for quantization processing
+                        target_device = torch.device(
+                            "cuda", torch.cuda.current_device()
+                        )
+                        with device_loading_context(module, target_device):
+                            quant_method.process_weights_after_loading(module)
+                # Call model-specific post-loading hook if available
+                if hasattr(self.model_runner.model, "post_load_weights"):
+                    self.model_runner.model.post_load_weights()
+            except Exception as e:
+                logger.warning(f"Post-hook processing failed: {e}")
+
+        return post_hook