sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,357 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3
|
+
|
4
|
+
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
|
5
|
+
#
|
6
|
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
7
|
+
# and OPT implementations in this library. It has been modified from its
|
8
|
+
# original forms to accommodate minor architectural differences compared
|
9
|
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
10
|
+
#
|
11
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
12
|
+
# you may not use this file except in compliance with the License.
|
13
|
+
# You may obtain a copy of the License at
|
14
|
+
#
|
15
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
16
|
+
#
|
17
|
+
# Unless required by applicable law or agreed to in writing, software
|
18
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
19
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
20
|
+
# See the License for the specific language governing permissions and
|
21
|
+
# limitations under the License.
|
22
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/starcoder2.py
|
23
|
+
""" PyTorch Starcoder2 model."""
|
24
|
+
from collections.abc import Iterable
|
25
|
+
from typing import Optional, Tuple
|
26
|
+
|
27
|
+
import torch
|
28
|
+
from torch import nn
|
29
|
+
from transformers import Starcoder2Config
|
30
|
+
|
31
|
+
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
32
|
+
from sglang.srt.layers.activation import get_act_fn
|
33
|
+
from sglang.srt.layers.linear import (
|
34
|
+
ColumnParallelLinear,
|
35
|
+
QKVParallelLinear,
|
36
|
+
RowParallelLinear,
|
37
|
+
)
|
38
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
39
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
40
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
41
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
42
|
+
from sglang.srt.layers.vocab_parallel_embedding import (
|
43
|
+
DEFAULT_VOCAB_PADDING_SIZE,
|
44
|
+
ParallelLMHead,
|
45
|
+
VocabParallelEmbedding,
|
46
|
+
)
|
47
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
48
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
49
|
+
from sglang.srt.utils import add_prefix, make_layers
|
50
|
+
|
51
|
+
|
52
|
+
class Starcoder2Attention(nn.Module):
    """Self-attention block for Starcoder2.

    Projects hidden states to fused Q/K/V, applies rotary position
    embeddings, runs radix attention over the batch, and projects the
    result back to the hidden size. Query and KV heads are partitioned
    across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: Starcoder2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        layer_id: int = 0,
    ):
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        # Query heads must divide evenly across tensor-parallel ranks.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank flattened Q and KV widths, used to split the fused QKV.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run one attention pass over ``hidden_states`` for ``forward_batch``."""
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection into this rank's Q, K, V chunks.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output
|
133
|
+
|
134
|
+
class Starcoder2MLP(nn.Module):
    """Feed-forward block for Starcoder2.

    Column-parallel up-projection (``c_fc``), activation, then
    row-parallel down-projection (``c_proj``).
    """

    def __init__(
        self,
        config: Starcoder2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply c_fc -> activation -> c_proj and return the result."""
        x, _ = self.c_fc(hidden_states)
        x = self.act(x)
        x, _ = self.c_proj(x)
        return x
|
168
|
+
|
169
|
+
class Starcoder2DecoderLayer(nn.Module):
    """One Starcoder2 transformer layer.

    Pre-norm self-attention followed by a pre-norm MLP, each wrapped in
    a residual connection.
    """

    def __init__(
        self,
        config: Starcoder2Config,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(
            config=config,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = Starcoder2MLP(
            config, quant_config=quant_config, prefix=f"{prefix}.mlp"
        )
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.norm_epsilon
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-blocks with pre-norm residuals."""
        # Attention sub-block: normalize, attend, add back the residual.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            forward_batch=forward_batch,
        )
        hidden_states = hidden_states + attn_out

        # MLP sub-block: normalize, transform, add back the residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
|
218
|
+
|
219
|
+
class Starcoder2Model(nn.Module):
    """Starcoder2 transformer backbone.

    Token embedding, this pipeline rank's slice of decoder layers, and a
    final LayerNorm.
    """

    def __init__(
        self,
        config: Starcoder2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        # Split the layer range evenly across pipeline-parallel ranks;
        # this rank runs layers [start_layer, end_layer).
        pp_group = get_pp_group()
        pp_size = pp_group.world_size
        pp_rank = pp_group.rank
        self.start_layer = pp_rank * config.num_hidden_layers // pp_size
        self.end_layer = (pp_rank + 1) * config.num_hidden_layers // pp_size

        # NOTE(review): assumes make_layers returns an indexable container
        # usable as self.layers[i] for i in [start_layer, end_layer) --
        # confirm against sglang.srt.utils.make_layers.
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Starcoder2DecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed tokens (unless ``inputs_embeds`` is given), run this rank's
        layers, and apply the final LayerNorm."""
        if inputs_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = inputs_embeds
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                forward_batch,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
|
275
|
+
|
276
|
+
class Starcoder2ForCausalLM(nn.Module):
    """Starcoder2 with a causal language-modeling head.

    Wraps :class:`Starcoder2Model` and either ties the LM head to the
    input embeddings or builds a separate :class:`ParallelLMHead`.
    """

    def __init__(
        self,
        config: Starcoder2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.model = Starcoder2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # Reuse the input embedding weights as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
                prefix=f"{prefix}.lm_head",
            )
        self.logits_processor = LogitsProcessor(config=config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and compute logits for ``forward_batch``."""
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            inputs_embeds=inputs_embeds,
        )
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, fusing q/k/v projections into qkv_proj.

        Weights whose (possibly remapped) name is absent from this module
        (e.g. layers owned by another pipeline rank) are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # HF checkpoints name this buffer "rotary_emb.inv_freq"
            # (singular); the previous plural filter never matched and the
            # skip only worked via the params_dict.get fallthrough below.
            # Matching the singular substring also covers the plural form.
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    # Checkpoint stores q/k/v separately; load each shard
                    # into the fused qkv_proj parameter.
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                param = params_dict.get(name)
                if param is None:
                    # Not part of this (possibly PP-partitioned) module.
                    continue

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
356
|
+
|
357
|
+
EntryClass = Starcoder2ForCausalLM
|
sglang/srt/models/step3_vl.py
CHANGED
@@ -133,7 +133,7 @@ class Step3TextMoEMLP(nn.Module):
|
|
133
133
|
use_grouped_topk=False,
|
134
134
|
)
|
135
135
|
|
136
|
-
self.experts = get_moe_impl_class()(
|
136
|
+
self.experts = get_moe_impl_class(quant_config)(
|
137
137
|
num_experts=config.moe_num_experts,
|
138
138
|
top_k=config.moe_top_k,
|
139
139
|
hidden_size=config.hidden_size,
|
@@ -66,8 +66,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
66
66
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
67
67
|
from sglang.srt.utils import add_prefix
|
68
68
|
|
69
|
-
tp_size =
|
70
|
-
tp_rank =
|
69
|
+
tp_size: Optional[int] = None
|
70
|
+
tp_rank: Optional[int] = None
|
71
71
|
|
72
72
|
|
73
73
|
def gate_up_proj_weight_loader(
|
@@ -341,6 +341,13 @@ class LlamaModel(nn.Module):
|
|
341
341
|
quant_config: Optional[QuantizationConfig] = None,
|
342
342
|
) -> None:
|
343
343
|
super().__init__()
|
344
|
+
|
345
|
+
global tp_size, tp_rank
|
346
|
+
if tp_size is None:
|
347
|
+
tp_size = get_tensor_model_parallel_world_size()
|
348
|
+
if tp_rank is None:
|
349
|
+
tp_rank = get_tensor_model_parallel_rank()
|
350
|
+
|
344
351
|
self.config = config
|
345
352
|
self.padding_idx = config.pad_token_id
|
346
353
|
self.vocab_size = config.vocab_size
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# Copyright 2023-2025 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
import torch
|
16
|
+
|
17
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
18
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
19
|
+
from sglang.srt.utils import is_cuda
|
20
|
+
|
21
|
+
_is_cuda = is_cuda()
|
22
|
+
|
23
|
+
|
24
|
+
if _is_cuda:
|
25
|
+
from sgl_kernel import FusedSetKVBufferArg
|
26
|
+
|
27
|
+
|
28
|
+
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
|
29
|
+
"""Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
|
30
|
+
return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
|
31
|
+
|
32
|
+
|
33
|
+
def create_fused_set_kv_buffer_arg(
|
34
|
+
value: torch.Tensor,
|
35
|
+
layer: RadixAttention,
|
36
|
+
forward_batch: ForwardBatch,
|
37
|
+
):
|
38
|
+
layer_id = layer.layer_id
|
39
|
+
token_to_kv_pool = forward_batch.token_to_kv_pool
|
40
|
+
|
41
|
+
k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
|
42
|
+
v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
|
43
|
+
|
44
|
+
return FusedSetKVBufferArg(
|
45
|
+
value=value,
|
46
|
+
k_buffer=k_buffer.view(k_buffer.shape[0], -1),
|
47
|
+
v_buffer=v_buffer.view(v_buffer.shape[0], -1),
|
48
|
+
k_scale=layer.k_scale,
|
49
|
+
v_scale=layer.v_scale,
|
50
|
+
cache_loc=forward_batch.out_cache_loc,
|
51
|
+
)
|
@@ -234,19 +234,27 @@ class BaseMultimodalProcessor(ABC):
|
|
234
234
|
and isinstance(processor.image_processor, BaseImageProcessorFast)
|
235
235
|
and not self.server_args.disable_fast_image_processor
|
236
236
|
):
|
237
|
-
|
237
|
+
if not _is_npu:
|
238
|
+
kwargs["device"] = "cuda"
|
239
|
+
elif processor.__class__.__name__ not in {
|
240
|
+
"Qwen2_5_VLProcessor",
|
241
|
+
"Qwen3VLProcessor",
|
242
|
+
}:
|
243
|
+
# Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
|
244
|
+
kwargs["device"] = "npu"
|
238
245
|
result = processor.__call__(
|
239
246
|
text=[input_text],
|
240
247
|
padding=True,
|
241
248
|
return_tensors="pt",
|
242
249
|
**kwargs,
|
243
250
|
)
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
251
|
+
if not self.server_args.keep_mm_feature_on_device:
|
252
|
+
# move feature tensors to cpu
|
253
|
+
for feature_name in self.FEATURE_NAMES:
|
254
|
+
if feature_name in result and isinstance(
|
255
|
+
result[feature_name], torch.Tensor
|
256
|
+
):
|
257
|
+
result[feature_name] = result[feature_name].to("cpu")
|
250
258
|
|
251
259
|
return result
|
252
260
|
|
@@ -0,0 +1,98 @@
|
|
1
|
+
import asyncio
|
2
|
+
import math
|
3
|
+
import re
|
4
|
+
from typing import Dict, List, Union
|
5
|
+
|
6
|
+
from PIL import Image
|
7
|
+
|
8
|
+
from sglang.srt.models.dots_ocr import DotsOCRForCausalLM
|
9
|
+
from sglang.srt.models.dots_vlm import DotsVLMForCausalLM
|
10
|
+
from sglang.srt.multimodal.processors.base_processor import (
|
11
|
+
BaseMultimodalProcessor,
|
12
|
+
MultimodalSpecialTokens,
|
13
|
+
)
|
14
|
+
from sglang.srt.multimodal.processors.qwen_vl import resize_image_async
|
15
|
+
|
16
|
+
|
17
|
+
class DotsVLMImageProcessor(BaseMultimodalProcessor):
|
18
|
+
models = [DotsVLMForCausalLM, DotsOCRForCausalLM]
|
19
|
+
|
20
|
+
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
|
21
|
+
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
|
22
|
+
# The single, pre-expanded image token.
|
23
|
+
self.IMAGE_TOKEN = "<|img|><|imgpad|><|endofimg|>"
|
24
|
+
# The regex that matches expanded image tokens.
|
25
|
+
self.IMAGE_TOKEN_REGEX = re.compile(r"<\|img\|>(?:<\|imgpad\|>)+<\|endofimg\|>")
|
26
|
+
|
27
|
+
assert len(_processor.tokenizer.encode("<|img|>")) == 1
|
28
|
+
self.im_start_id = _processor.tokenizer.encode("<|img|>")[0]
|
29
|
+
self.im_end_id = _processor.tokenizer.encode("<|endofimg|>")[0]
|
30
|
+
self.image_token_id = _processor.tokenizer.encode("<|imgpad|>")[0]
|
31
|
+
self.IM_TOKEN_ID = self.image_token_id
|
32
|
+
self.IM_START_ID = self.im_start_id
|
33
|
+
self.IM_END_ID = self.im_end_id
|
34
|
+
|
35
|
+
vision_config = hf_config.vision_config
|
36
|
+
patch_size = vision_config.patch_size
|
37
|
+
merge_size = vision_config.spatial_merge_size
|
38
|
+
|
39
|
+
self.IMAGE_FACTOR = patch_size * merge_size
|
40
|
+
self.MIN_PIXELS = _processor.image_processor.min_pixels
|
41
|
+
self.MAX_PIXELS = _processor.image_processor.max_pixels
|
42
|
+
self.MAX_RATIO = 200
|
43
|
+
self.mm_tokens = MultimodalSpecialTokens(
|
44
|
+
image_token=self.IMAGE_TOKEN,
|
45
|
+
image_token_id=self.image_token_id,
|
46
|
+
image_token_regex=self.IMAGE_TOKEN_REGEX,
|
47
|
+
).build(_processor)
|
48
|
+
|
49
|
+
async def process_mm_data_async(
|
50
|
+
self,
|
51
|
+
image_data: List[Union[str, bytes, Dict]],
|
52
|
+
input_text,
|
53
|
+
request_obj,
|
54
|
+
max_req_input_len,
|
55
|
+
*args,
|
56
|
+
**kwargs,
|
57
|
+
):
|
58
|
+
if isinstance(image_data, str):
|
59
|
+
image_data = [image_data]
|
60
|
+
|
61
|
+
if (
|
62
|
+
isinstance(image_data, list)
|
63
|
+
and image_data
|
64
|
+
and isinstance(image_data[0], list)
|
65
|
+
):
|
66
|
+
image_data = sum(image_data, [])
|
67
|
+
|
68
|
+
base_output = self.load_mm_data(
|
69
|
+
prompt=input_text,
|
70
|
+
image_data=image_data,
|
71
|
+
multimodal_tokens=self.mm_tokens,
|
72
|
+
)
|
73
|
+
|
74
|
+
# Qwen-specific: resize images if they are raw Image objects
|
75
|
+
if base_output.images and isinstance(base_output.images[0], Image.Image):
|
76
|
+
resize_tasks = [
|
77
|
+
resize_image_async(
|
78
|
+
image,
|
79
|
+
min_pixels=self.MIN_PIXELS,
|
80
|
+
max_pixels=self.MAX_PIXELS,
|
81
|
+
size_factor=self.IMAGE_FACTOR,
|
82
|
+
)
|
83
|
+
for image in base_output.images
|
84
|
+
]
|
85
|
+
base_output.images = await asyncio.gather(*resize_tasks)
|
86
|
+
combined_mm_item, input_ids, _ = self.process_and_combine_mm_data(
|
87
|
+
base_output, self.mm_tokens
|
88
|
+
)
|
89
|
+
if combined_mm_item is None:
|
90
|
+
return None
|
91
|
+
|
92
|
+
return {
|
93
|
+
"input_ids": input_ids.tolist(),
|
94
|
+
"mm_items": combined_mm_item,
|
95
|
+
"im_start_id": self.im_start_id,
|
96
|
+
"im_end_id": self.im_end_id,
|
97
|
+
"im_token_id": self.image_token_id,
|
98
|
+
}
|
@@ -2,7 +2,6 @@ import re
|
|
2
2
|
from typing import List, Union
|
3
3
|
|
4
4
|
from decord import VideoReader
|
5
|
-
from transformers.video_utils import VideoMetadata
|
6
5
|
|
7
6
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
8
7
|
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
|
@@ -66,17 +65,18 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
|
|
66
65
|
total_num_frames = len(vr)
|
67
66
|
duration = total_num_frames / video_fps if video_fps else 0
|
68
67
|
|
69
|
-
metadata = VideoMetadata(
|
70
|
-
total_num_frames=int(total_num_frames),
|
71
|
-
fps=float(video_fps),
|
72
|
-
duration=float(duration),
|
73
|
-
video_backend="decord",
|
74
|
-
)
|
75
|
-
|
76
68
|
# Extract all frames
|
77
69
|
indices = list(range(total_num_frames))
|
78
70
|
frames = vr.get_batch(indices).asnumpy()
|
79
|
-
|
71
|
+
|
72
|
+
# Return metadata as dict so transformers can properly create VideoMetadata objects
|
73
|
+
metadata = {
|
74
|
+
"total_num_frames": int(total_num_frames),
|
75
|
+
"fps": float(video_fps),
|
76
|
+
"duration": float(duration),
|
77
|
+
"video_backend": "decord",
|
78
|
+
"frames_indices": indices,
|
79
|
+
}
|
80
80
|
|
81
81
|
return frames, metadata
|
82
82
|
|