sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +378 -160
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +10 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +136 -25
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +63 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +83 -80
- sglang/srt/entrypoints/grpc_server.py +430 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +195 -102
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +58 -6
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +33 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +20 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +10 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +24 -10
- sglang/srt/layers/attention/flashinfer_backend.py +258 -22
- sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
- sglang/srt/layers/attention/utils.py +89 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +12 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +64 -19
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +152 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
- sglang/srt/layers/moe/ep_moe/layer.py +154 -625
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
- sglang/srt/layers/moe/moe_runner/runner.py +6 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
- sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +7 -6
- sglang/srt/layers/moe/utils.py +20 -5
- sglang/srt/layers/quantization/__init__.py +5 -58
- sglang/srt/layers/quantization/awq.py +183 -9
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +27 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +152 -81
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +35 -68
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +23 -48
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +87 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +62 -9
- sglang/srt/layers/rotary_embedding.py +686 -17
- sglang/srt/layers/sampler.py +47 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +69 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +420 -514
- sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +375 -95
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +11 -2
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +517 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +71 -25
- sglang/srt/model_executor/model_runner.py +362 -270
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +418 -140
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +327 -382
- sglang/srt/models/glm4_moe_nextn.py +6 -16
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +32 -199
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +22 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3.py +34 -4
- sglang/srt/models/qwen3_moe.py +19 -37
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +7 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +2 -6
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +28 -2
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +846 -163
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +36 -31
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +272 -82
- sglang/srt/utils/hf_transformers_utils.py +44 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +463 -107
- sglang/test/test_deterministic_utils.py +74 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/models/vila.py +0 -306
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/backend.py
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
import ast
|
|
5
|
+
import dataclasses
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import pprint
|
|
9
|
+
import time
|
|
10
|
+
from collections.abc import Sequence
|
|
11
|
+
from contextlib import contextmanager
|
|
12
|
+
from typing import Any, Callable, Optional
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import torch.fx as fx
|
|
16
|
+
from torch._dispatch.python import enable_python_dispatcher
|
|
17
|
+
|
|
18
|
+
from sglang.srt.compilation.compilation_config import CompilationConfig
|
|
19
|
+
from sglang.srt.compilation.compilation_counter import compilation_counter
|
|
20
|
+
from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
|
|
21
|
+
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
|
|
22
|
+
from sglang.srt.compilation.pass_manager import PostGradPassManager
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def make_compiler(config: CompilationConfig):
    """Build the compiler adapter named by ``config.compiler``.

    ``"inductor"`` selects :class:`InductorAdaptor` and ``"eager"`` selects
    :class:`EagerAdapter`.

    Raises:
        ValueError: if ``config.compiler`` is neither of the supported names.
    """
    requested = config.compiler
    if requested == "inductor":
        return InductorAdaptor()
    if requested == "eager":
        return EagerAdapter()
    raise ValueError(f"Unknown compiler: {config.compiler}")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class CompilerManager:
    """Compile FX graphs through a compiler adapter and track artifact handles.

    The adapter is chosen by ``make_compiler`` from the compilation config.
    Handles returned by the adapter are kept in ``self.cache``, keyed by
    ``(runtime_shape, graph_index, compiler_name)``. They can be persisted to
    ``<cache_dir>/sglang_compile_cache.py`` as a pretty-printed Python literal.
    """

    def __init__(
        self,
        config: CompilationConfig,
    ):
        self.cache = {}
        self.is_cache_updated = False
        self.compiler = make_compiler(config)
        # Safe defaults until ``initialize_cache`` runs. With caching treated
        # as disabled, ``save_to_file`` is a no-op instead of raising
        # AttributeError on attributes that would not exist yet.
        self.disable_cache = True
        self.cache_dir = None
        self.cache_file_path = None

    def compute_hash(self):
        """Return the hash reported by the underlying compiler adapter."""
        return self.compiler.compute_hash()

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """Point the manager at ``cache_dir`` and load any persisted handles.

        Args:
            cache_dir: directory that holds ``sglang_compile_cache.py``.
            disable_cache: when True, the handle file is neither read here nor
                written by ``save_to_file``.
            prefix: forwarded unchanged to the adapter's ``initialize_cache``.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "sglang_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # ``save_to_file`` writes a dict literal; ``ast.literal_eval``
            # parses it without executing arbitrary code.
            with open(self.cache_file_path) as f:
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Persist ``self.cache`` when caching is enabled and it has changed."""
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        # The cache directory may not exist yet. Create it so the write below
        # does not fail with FileNotFoundError. ``or "."`` handles a bare
        # relative file name, whose dirname is "".
        os.makedirs(os.path.dirname(self.cache_file_path) or ".", exist_ok=True)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Optional[Callable]:
        """Load a previously compiled graph through its cached handle.

        Args:
            graph: the FX graph the handle was produced for.
            example_inputs: inputs forwarded to the adapter's ``load``.
            graph_index: index of the subgraph within the split model.
            runtime_shape: concrete shape, or None for the dynamic-shape graph.

        Returns:
            Whatever the adapter's ``load`` returns for the handle.

        Raises:
            KeyError: if no handle is cached for this
                ``(runtime_shape, graph_index, compiler name)``.
        """
        handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via "
                "handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        inductor_config: dict[str, Any],
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: Optional[int] = None,
    ) -> Any:
        """Compile one subgraph and record its artifact handle.

        Args:
            graph: the FX subgraph to compile.
            example_inputs: example inputs passed to the adapter.
            inductor_config: compiler options passed to the adapter.
            graph_index: index of this subgraph; index 0 starts the timer.
            num_graphs: total number of subgraphs; the last one logs the
                elapsed compilation time.
            runtime_shape: concrete shape, or None for the dynamic-shape graph.

        Returns:
            The compiled graph returned by the adapter.
        """
        # Module-level timestamp shared across calls. It is set when the first
        # subgraph starts compiling and read when the last one finishes.
        global compilation_start_time

        if graph_index == 0:
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # TODO(Yuwei): support cache loading
        # No cached artifact is consulted yet, so the graph is always compiled.
        # InductorAdaptor receives no explicit key; other adapters get one
        # derived from the shape and subgraph index.
        if isinstance(self.compiler, InductorAdaptor):
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, inductor_config, runtime_shape, maybe_key
        )

        assert compiled_graph is not None, "Failed to compile the graph"

        # Store the artifact handle so it can be saved by ``save_to_file``.
        if handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # Info-level logging only for the first subgraph.
                if runtime_shape is None:
                    logger.info("Cache the graph for dynamic shape for later use")
                else:
                    logger.info(
                        "Cache the graph of shape %s for later use", str(runtime_shape)
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # After compiling the last subgraph, report the total elapsed time.
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
            else:
                logger.info(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                )

        return compiled_graph
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@dataclasses.dataclass
class SplitItem:
    """One top-level submodule produced by ``split_graph``."""

    # Attribute name of the submodule on the split module, e.g. "submod_3".
    submod_name: str
    # Integer partition id parsed from ``submod_name``.
    graph_id: int
    # True when this partition holds one of the ops the graph was split on.
    is_splitting_graph: bool
    # The submodule itself.
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into submodules around every call to one of ``ops``.

    Each ``call_function`` node whose ``str(node.target)`` is in ``ops`` is
    placed in a partition of its own. The runs of nodes between such calls
    form the remaining partitions. The original node order is preserved.

    Args:
        graph: the traced FX module to split.
        ops: string forms of the call targets to split on.

    Returns:
        ``(split_gm, items)``. ``split_gm`` is the module built by
        ``split_module``. ``items`` describes each top-level submodule, sorted
        by integer partition id.
    """
    # Import explicitly rather than relying on ``torch.fx.passes`` having been
    # loaded as a side effect of importing ``torch.fx``.
    from torch.fx.passes.split_module import split_module

    split_targets = set(ops)
    subgraph_id = 0
    node_to_subgraph_id = {}
    # A set gives O(1) membership checks below (previously a list).
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == "call_function" and str(node.target) in split_targets:
            # Give the split op a partition of its own, then start a fresh
            # partition for whatever follows it.
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    for name, _ in split_gm.named_modules():
        if "." in name or name == "":
            # recursive child module or the root module
            continue

        module = getattr(split_gm, name)
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
# Graph memory pool shared by all backends. It starts out as None; presumably
# it is assigned elsewhere once a pool exists — confirm against callers.
global_graph_pool: Optional[Any] = None

# Wall-clock timestamp (``time.time()``) captured when the first subgraph
# starts compiling. It is used to log total compile time after the last one.
compilation_start_time: float = 0.0
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """FX interpreter that compiles selected submodules of a split graph.

    ``run`` executes the module on fake copies of the given inputs. When a
    submodule whose name is in ``compile_submod_names`` is executed, it is
    compiled for dynamic shape through ``sglang_backend.compiler_manager``.
    A ``CUDAPiecewiseBackend`` wrapping the result is then stored in
    ``module.__dict__`` under that submodule's name.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        inductor_config: dict[str, Any],
        graph_pool,
        compile_config: CompilationConfig,
        sglang_backend: "SGLangBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # Fake mode detected from the current context. ``detect_fake_mode``
        # returns None when no fake mode is found; ``run`` checks for that.
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.graph_pool = graph_pool
        self.sglang_backend = sglang_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False
        self.inductor_config = inductor_config
        self.compile_config = compile_config

    def run(self, *args):
        """Run the module on fake versions of ``args`` under the python dispatcher.

        Raises:
            RuntimeError: if no fake mode was detected at construction time.
        """
        if self.fake_mode is None:
            # Without a fake mode, the conversion below would fail on ``None``
            # with an unhelpful error; report the actual cause instead.
            raise RuntimeError(
                "PiecewiseCompileInterpreter requires an active fake tensor "
                "mode, but none was detected when it was constructed."
            )
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Execute submodule ``target``; if it is selected, compile and wrap it.

        Returns:
            The output of running the original submodule on ``args``.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # Positions of SymInt arguments (presumably the dynamic sizes),
            # forwarded to the piecewise backend.
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_dynamic_shape = (
                self.sglang_backend.compiler_manager.compile(
                    submod,
                    args,
                    self.inductor_config,
                    graph_index=index,
                    num_graphs=len(self.compile_submod_names),
                    runtime_shape=None,
                )
            )

            # An instance-dict entry takes precedence over the registered
            # submodule on later attribute lookups of ``target``.
            self.module.__dict__[target] = CUDAPiecewiseBackend(
                submod,
                self.compile_config,
                self.inductor_config,
                self.graph_pool,
                index,
                len(self.compile_submod_names),
                sym_shape_indices,
                compiled_graph_for_dynamic_shape,
                self.sglang_backend,
            )

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
# Tag of the model currently being compiled. It is switched temporarily
# with set_model_tag() and defaults to "backbone".
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily replace the module-level ``model_tag`` with ``tag``.

    The previous tag is restored when the ``with`` block exits, whether it
    exits normally or by raising.

    Raises:
        AssertionError: if ``tag`` equals the currently active tag.
    """
    global model_tag
    assert (
        tag != model_tag
    ), f"Model tag {tag} is the same as the current tag {model_tag}."
    previous_tag, model_tag = model_tag, tag
    try:
        yield
    finally:
        model_tag = previous_tag
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class SGLangBackend:
    """``torch.compile`` backend that compiles a captured graph piecewise.

    ``__call__`` performs the following steps:

    - splits the graph at the ops in ``_SPLITTING_OPS``;
    - compiles every non-splitting piece through ``compiler_manager``, using
      ``PiecewiseCompileInterpreter``;
    - returns the stitched ``split_gm``.

    A backend instance may only be called once.
    """

    # Ops at which the captured graph is split into piecewise submodules.
    _SPLITTING_OPS = (
        "sglang.unified_attention_with_output",
        "sglang.inplace_all_reduce",
    )

    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        config: CompilationConfig,
        graph_pool: Any,
    ):
        """Create a backend.

        Args:
            config: compilation settings; handed to ``CompilerManager`` and to
                every ``CUDAPiecewiseBackend`` created during ``__call__``.
            graph_pool: graph pool shared by the piecewise backends; must not
                be ``None``.
        """
        assert graph_pool is not None
        self.graph_pool = graph_pool

        self.post_grad_pass_manager = PostGradPassManager()
        self.sym_tensor_indices = []
        self.input_buffers = []

        self.compiler_manager = CompilerManager(config)
        self.inductor_config = {
            "enable_auto_functionalized_v2": False,
        }
        self.compile_config = config

    def configure_post_pass(self):
        """Configure the post-grad pass manager and register it with Inductor."""
        self.post_grad_pass_manager.configure()
        self.inductor_config["post_grad_custom_post_pass"] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Compile ``graph`` piecewise and return the stitched module.

        Args:
            graph: the FX graph module to compile.
            example_inputs: example inputs for the graph; tensors among them
                are converted to fake tensors before interpretation.

        Returns:
            ``self.split_gm``, in which each compiled submodule has been
            replaced by a ``CUDAPiecewiseBackend``.

        Raises:
            AssertionError: if this backend instance was already called.
        """
        # Fail fast, before touching the cache directory or the counters.
        assert not self._called, "SGLangBackend can only be called once"

        local_cache_dir = self._init_local_cache_dir()
        compilation_counter.num_graphs_seen += 1

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, list(self._SPLITTING_OPS)
        )

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)

        # Pieces produced by the splitting ops themselves are not compiled.
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # Interpreting split_gm compiles each listed submodule and swaps it in place.
        PiecewiseCompileInterpreter(
            self.split_gm,
            submod_names_to_compile,
            self.inductor_config,
            self.graph_pool,
            self.compile_config,
            self,
        ).run(*example_inputs)

        self._dump_computation_graph(local_cache_dir)

        self._called = True
        return self.split_gm

    def _init_local_cache_dir(self) -> str:
        """Create this backend's cache directory, point the compiler manager at it, return it."""
        base_cache_dir = os.path.expanduser(
            os.getenv("SGLANG_CACHE_DIR", "~/.cache/sglang/")
        )
        cache_hash = self.compiler_manager.compute_hash()
        cache_dir = os.path.join(base_cache_dir, "torch_compile_cache", cache_hash)

        # NOTE(review): rank and dp_rank are fixed at 0. Every process that
        # reaches this code therefore shares the same rank_0_0 directory.
        # Confirm this is intended for multi-process (TP/DP) runs.
        rank = 0
        dp_rank = 0
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", model_tag)
        # makedirs also creates cache_dir and every other missing parent.
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache=False, prefix=""
        )
        return local_cache_dir

    def _dump_computation_graph(self, local_cache_dir: str) -> None:
        """Write ``split_gm`` as readable Python to ``computation_graph.py`` once."""
        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if os.path.exists(graph_path):
            return
        # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
        # use `print_readable` because it can include submodules
        src = (
            "from __future__ import annotations\nimport torch\n"
            + self.split_gm.print_readable(print_output=False)
        )
        src = src.replace("<lambda>", "GraphModule")
        with open(graph_path, "w", encoding="utf-8") as f:
            f.write(src)

        logger.debug("Computation graph saved to %s", graph_path)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# TODO(Yuwei): improve compile config support
class CompilationConfig:
    """Settings for piecewise ``torch.compile``.

    Args:
        capture_sizes: list of capture sizes, stored as given (not copied).
            Presumably these are CUDA-graph capture batch sizes -- confirm.
        compiler: name of the compiler to use; ``"eager"`` by default.
    """

    def __init__(self, capture_sizes: List[int], compiler: str = "eager"):
        self.capture_sizes = capture_sizes
        self.compiler = compiler
        # File paths recorded through add_traced_file(); duplicates collapse.
        self.traced_files = set()

    def add_traced_file(self, file_path: str):
        """Record ``file_path`` as a traced source file."""
        self.traced_files.add(file_path)

    def get_traced_files(self):
        """Return the live set of recorded traced files (not a copy)."""
        return self.traced_files

    def get_capture_sizes(self):
        """Return the capture sizes this config was created with."""
        return self.capture_sizes
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_counter.py
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import dataclasses
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclasses.dataclass
class CompilationCounter:
    """Integer counters for compilation events.

    Call ``expect`` as a ``with`` block to assert how much selected counters
    change while the block runs.
    """

    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # counts the splitting-op graphs too
    num_piecewise_graphs_seen: int = 0
    # excludes the splitting-op graphs
    num_piecewise_capturable_graphs_seen: int = 0
    num_backend_compilations: int = 0
    # How many times gpu_model_runner tried to trigger CUDAGraphs capture
    num_gpu_runner_capture_triggers: int = 0
    # How many CUDAGraphs were captured
    num_cudagraph_captured: int = 0
    # Calls to InductorAdapter.compile
    num_inductor_compiles: int = 0
    # Calls to EagerAdapter.compile
    num_eager_compiles: int = 0
    # How many times the compiler cache entry was updated
    num_cache_entries_updated: int = 0
    # How many standalone_compile compiled artifacts were saved
    num_compiled_artifacts_saved: int = 0
    # How many times a model was loaded with CompilationLevel.DYNAMO_AS_IS
    dynamo_as_is_count: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent deep copy of the current counter values."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs):
        """Assert each named counter grows by exactly the given amount inside the block.

        Keyword names are counter fields; values are the expected differences
        (after minus before). Checks run only if the block exits normally.
        """
        snapshot = self.clone()
        yield
        for name, expected_delta in kwargs.items():
            before = getattr(snapshot, name)
            after = getattr(self, name)
            assert after - before == expected_delta, (
                f"{name} not as expected, before it is {before}"
                f", after it is {after}, "
                f"expected diff is {expected_delta}"
            )


# Module-level counter instance.
compilation_counter = CompilationCounter()
|