sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/sarashina2_vision.py (new file)
@@ -0,0 +1,269 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only Sarashina2Vision model compatible with HuggingFace weights."""

import logging
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers import LlamaConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
    MultimodalDataItem,
    MultimodalInputs,
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.qwen2_vl import Qwen2VisionTransformer
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class Sarashina2VisionForCausalLM(nn.Module):
    """
    Sarashina2Vision model that combines:
    - Llama text backbone (sbintuitions/sarashina2-7b)
    - Qwen2VL vision encoder
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        # Extract text and vision configurations
        text_config = getattr(config, "text_config", config)
        vision_config = getattr(config, "vision_config", None)

        # Create vision transformer first (like original model)
        if vision_config is not None:
            self.visual = Qwen2VisionTransformer(
                vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-5),
                quant_config=quant_config,
                prefix=add_prefix("visual", prefix),
            )
        else:
            self.visual = None

        # Layer norm for vision outputs (matching original model)
        self.norm = nn.LayerNorm(text_config.hidden_size)

        # Create Llama text model (using 'llm' name to match original)
        if hasattr(text_config, "model_type") and text_config.model_type == "llama":
            llama_config = LlamaConfig(**text_config.__dict__)
            # Set vocab_size from main config if available
            if hasattr(config, "vocab_size"):
                llama_config.vocab_size = config.vocab_size
            self.llm = LlamaForCausalLM(
                llama_config,
                quant_config=quant_config,
                prefix=add_prefix("llm", prefix),
            )
        else:
            # Set vocab_size from main config if available
            if hasattr(config, "vocab_size"):
                config.vocab_size = config.vocab_size
            self.llm = LlamaForCausalLM(
                config,
                quant_config=quant_config,
                prefix=add_prefix("llm", prefix),
            )

        # Image token indices from config
        self.image_token_index = getattr(config, "image_token_index", 14)
        self.start_image_token_index = getattr(
            config, "start_image_token_index", 102397
        )
        self.end_image_token_index = getattr(config, "end_image_token_index", 102398)

        # Ensure vocabulary size matches
        if hasattr(config, "vocab_size"):
            self.llm.config.vocab_size = config.vocab_size

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad input tokens with multimodal data hashes for RadixAttention."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_input_embeddings(self):
        """Get input embeddings from the language model."""
        return self.llm.get_input_embeddings()

    def get_image_embeds(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Extract image embeddings using the vision transformer."""
        if self.visual is None:
            raise ValueError("Visual encoder not initialized")

        # Use the existing Qwen2VisionTransformer forward method
        hidden_states = self.visual(pixel_values, image_grid_thw)

        # Apply normalization layer
        return self.norm(hidden_states)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Extract image features for SGLang compatibility."""
        if self.visual is None:
            raise ValueError("Visual encoder not initialized")

        # Concatenate pixel values and grid_thw from all items
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thw = torch.cat([item.image_grid_thw for item in items], dim=0)

        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()

        # Use the get_image_embeds method
        return self.get_image_embeds(pixel_values, image_grid_thw)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        """Forward pass through the model."""
        # Handles token-to-feature mapping for expanded tokens
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm.model,
            multimodal_model=self,
            positions=positions,
        )

        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        else:
            return self.logits_processor(
                input_ids, hidden_states, self.llm.lm_head, forward_batch
            )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load model weights."""
        params_dict = dict(self.named_parameters())
        loaded_params = set()

        # Collect weights that need to be fused
        qkv_weights = {}
        gate_up_weights = {}

        for name, loaded_weight in weights:
            # Handle weight name mappings

            # Map visual attention weights: qkv -> qkv_proj
            if ".attn.qkv." in name:
                mapped_name = name.replace(".attn.qkv.", ".attn.qkv_proj.")
                if mapped_name in params_dict:
                    param = params_dict[mapped_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(mapped_name)
                continue

            # Handle Llama attention weights - need to fuse q, k, v into qkv
            if ".self_attn.q_proj.weight" in name:
                base = name.replace(".q_proj.weight", "")
                qkv_weights[base] = qkv_weights.get(base, {})
                qkv_weights[base]["q"] = loaded_weight
                continue
            elif ".self_attn.k_proj.weight" in name:
                base = name.replace(".k_proj.weight", "")
                qkv_weights[base] = qkv_weights.get(base, {})
                qkv_weights[base]["k"] = loaded_weight
                continue
            elif ".self_attn.v_proj.weight" in name:
                base = name.replace(".v_proj.weight", "")
                qkv_weights[base] = qkv_weights.get(base, {})
                qkv_weights[base]["v"] = loaded_weight
                continue

            # Handle Llama MLP weights - need to fuse gate and up projections
            if ".mlp.gate_proj.weight" in name:
                base = name.replace(".gate_proj.weight", "")
                gate_up_weights[base] = gate_up_weights.get(base, {})
                gate_up_weights[base]["gate"] = loaded_weight
                continue
            elif ".mlp.up_proj.weight" in name:
                base = name.replace(".up_proj.weight", "")
                gate_up_weights[base] = gate_up_weights.get(base, {})
                gate_up_weights[base]["up"] = loaded_weight
                continue

            # Direct mapping for other weights
            if name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        # Fuse QKV weights for Llama attention layers
        for base, weights_dict in qkv_weights.items():
            if "q" in weights_dict and "k" in weights_dict and "v" in weights_dict:
                qkv_name = f"{base}.qkv_proj.weight"
                if qkv_name in params_dict:
                    # Concatenate q, k, v weights
                    q, k, v = weights_dict["q"], weights_dict["k"], weights_dict["v"]
                    qkv = torch.cat([q, k, v], dim=0)
                    param = params_dict[qkv_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, qkv)
                    loaded_params.add(qkv_name)

        # Fuse gate and up weights for Llama MLP layers
        for base, weights_dict in gate_up_weights.items():
            if "gate" in weights_dict and "up" in weights_dict:
                gate_up_name = f"{base}.gate_up_proj.weight"
                if gate_up_name in params_dict:
                    # Concatenate gate and up weights
                    gate, up = weights_dict["gate"], weights_dict["up"]
                    gate_up = torch.cat([gate, up], dim=0)
                    param = params_dict[gate_up_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, gate_up)
                    loaded_params.add(gate_up_name)


# Register the model
EntryClass = Sarashina2VisionForCausalLM
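A note on the weight-fusion logic in load_weights above: the per-checkpoint q_proj/k_proj/v_proj (and gate_proj/up_proj) tensors are concatenated row-wise before being handed to the fused qkv_proj/gate_up_proj parameters. The standalone sketch below is not part of the diff; it only illustrates that concatenation on dummy tensors, with made-up shapes.

# Illustrative sketch only (not from the sglang source); shapes are assumptions.
import torch

hidden = 64
q = torch.randn(hidden, hidden)          # q_proj.weight
k = torch.randn(hidden // 2, hidden)     # k_proj.weight (e.g. fewer KV heads under GQA)
v = torch.randn(hidden // 2, hidden)     # v_proj.weight
gate = torch.randn(4 * hidden, hidden)   # gate_proj.weight
up = torch.randn(4 * hidden, hidden)     # up_proj.weight

# The fused parameters are simply row-wise concatenations of the originals,
# matching torch.cat([...], dim=0) in load_weights.
qkv = torch.cat([q, k, v], dim=0)        # -> qkv_proj.weight
gate_up = torch.cat([gate, up], dim=0)   # -> gate_up_proj.weight

assert qkv.shape == (hidden + hidden // 2 + hidden // 2, hidden)
assert gate_up.shape == (8 * hidden, hidden)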