sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +378 -160
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +10 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +136 -25
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +63 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +83 -80
- sglang/srt/entrypoints/grpc_server.py +430 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +195 -102
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +58 -6
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +33 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +20 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +10 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +24 -10
- sglang/srt/layers/attention/flashinfer_backend.py +258 -22
- sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
- sglang/srt/layers/attention/utils.py +89 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +12 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +64 -19
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +152 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
- sglang/srt/layers/moe/ep_moe/layer.py +154 -625
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
- sglang/srt/layers/moe/moe_runner/runner.py +6 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
- sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +7 -6
- sglang/srt/layers/moe/utils.py +20 -5
- sglang/srt/layers/quantization/__init__.py +5 -58
- sglang/srt/layers/quantization/awq.py +183 -9
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +27 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +152 -81
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +35 -68
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +23 -48
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +87 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +62 -9
- sglang/srt/layers/rotary_embedding.py +686 -17
- sglang/srt/layers/sampler.py +47 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +69 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +420 -514
- sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +375 -95
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +11 -2
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +517 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +71 -25
- sglang/srt/model_executor/model_runner.py +362 -270
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +418 -140
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +327 -382
- sglang/srt/models/glm4_moe_nextn.py +6 -16
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +32 -199
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +22 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3.py +34 -4
- sglang/srt/models/qwen3_moe.py +19 -37
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +7 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +2 -6
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +28 -2
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +846 -163
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +36 -31
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +272 -82
- sglang/srt/utils/hf_transformers_utils.py +44 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +463 -107
- sglang/test/test_deterministic_utils.py +74 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/models/vila.py +0 -306
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/configs/deepseekvl2.py
CHANGED

@@ -1,5 +1,4 @@
 import math
-import os
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
@@ -12,6 +11,8 @@ from transformers import (
     ProcessorMixin,
 )
 
+from sglang.srt.configs.deepseek_ocr import BASE_SIZE, IMAGE_SIZE, MAX_CROPS, MIN_CROPS
+
 
 def select_best_resolution(image_size, candidate_resolutions):
     # used for cropping
@@ -62,6 +63,7 @@ class DictOutput(object):
 class VLChatProcessorOutput(DictOutput):
     input_ids: torch.LongTensor
     target_ids: torch.LongTensor
+    images_crop: torch.LongTensor
     pixel_values: (
         torch.Tensor
     )  # rename from "images" to "pixel_values" for compatibility
@@ -105,6 +107,68 @@ class ImageTransform(object):
         return x
 
 
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(
+    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
+):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size,
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images, target_aspect_ratio
+
+
 class DeepseekVLV2Processor(ProcessorMixin):
     tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
     attributes = ["tokenizer"]
@@ -134,7 +198,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         self.image_std = image_std
         self.normalize = normalize
         self.downsample_ratio = downsample_ratio
-
+        self.base_size = BASE_SIZE
         self.image_transform = ImageTransform(
             mean=image_mean, std=image_std, normalize=normalize
         )
@@ -177,7 +241,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
             **kwargs,
         )
 
-    def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
+    def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
        """play the role of format_messages_v2 and get_images_info in the last version"""
         tokenized_data = []
         masked_tokenized_data = []  # labels
@@ -187,35 +251,34 @@ class DeepseekVLV2Processor(ProcessorMixin):
 
         image_index = 0
         image_token_cnt = messages.count(self.image_token)
-        tokenized_str, images, seq_mask, spatial_crop, num_image_tokens = self.tokenize_with_images(
+        (
+            input_ids,
+            images,
+            images_crop,
+            seq_mask,
+            spatial_crop,
+            num_image_tokens,
+            image_shapes,
+        ) = self.tokenize_with_images(
             messages,
             pil_images[image_index : image_index + image_token_cnt],
             bos=True,
             eos=True,
             cropping=len(pil_images) <= 2,
-            max_req_input_len=max_req_input_len,
         )
 
         image_index = image_token_cnt
-        tokenized_data += tokenized_str
-        if self.mask_prompt:
-            masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
-        else:
-            masked_tokenized_data += tokenized_str
         images_list += images
         images_seq_mask += seq_mask
-        images_spatial_crop += spatial_crop
-
-        assert len(tokenized_data) == len(
-            images_seq_mask
-        ), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
+        images_spatial_crop = spatial_crop
 
         return (
-            tokenized_data,
+            input_ids,
             masked_tokenized_data,
             images_list,
             images_seq_mask,
             images_spatial_crop,
+            images_crop,
         )
 
     @property
@@ -252,6 +315,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         inference_mode: bool = True,
         system_prompt: str = "",
         max_req_input_len: int = -1,
+        cropping: bool = True,
         **kwargs,
     ):
         """
@@ -275,47 +339,22 @@
             - num_image_tokens (List[int]): the number of image tokens
         """
 
-        assert (
-            prompt is None or conversations is None
-        ), "prompt and conversations cannot be used at the same time."
-
+        prompt = conversations or prompt
         (
-            tokenized_str,
+            input_ids,
             masked_tokenized_str,
             images_list,
             images_seq_mask,
             images_spatial_crop,
-        ) = self.format_messages_v2(conversations, images, max_req_input_len)
+            images_crop,
+        ) = self.format_messages_v2(prompt, images, max_req_input_len)
 
-        assert (
-            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
-        ), (
-            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
-            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
-        )
-
-        input_ids = torch.LongTensor(tokenized_str)
         target_ids = torch.LongTensor(masked_tokenized_str)
-        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
-
-        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
-        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
-            self.ignore_id
-        )
-        input_ids[input_ids < 0] = self.pad_id
-
-        if inference_mode:
-            assert input_ids[-1] == self.eos_id
-            input_ids = input_ids[:-1]
-            target_ids = target_ids[:-1]
-            images_seq_mask = images_seq_mask[:-1]
 
         if len(images_list) == 0:
             images = torch.zeros((1, 3, self.image_size, self.image_size))
-            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
         else:
             images = torch.stack(images_list, dim=0)
-            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
 
         images_spatial_crop = torch.stack(
             [images_spatial_crop], dim=0
@@ -324,6 +363,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
         prepare = VLChatProcessorOutput(
             input_ids=input_ids,
             target_ids=target_ids,
+            images_crop=images_crop,
             pixel_values=images,
             images_seq_mask=images_seq_mask,
             images_spatial_crop=images_spatial_crop,
@@ -341,10 +381,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
         inference_mode: bool = True,
         system_prompt: str = "",
         max_req_input_len: int = -1,
+        text: list[str] = None,
         **kwargs,
     ):
+        assert text is None or isinstance(text, list)
+        if text is not None:
+            text = text[0]
         prepare = self.process_one(
-            prompt=prompt,
+            prompt=prompt or text,
             conversations=conversations,
             images=images,
             apply_sft_format=apply_sft_format,
@@ -369,85 +413,83 @@ class DeepseekVLV2Processor(ProcessorMixin):
         bos: bool = True,
         eos: bool = True,
         cropping: bool = True,
-        max_req_input_len: int = -1,
     ):
         """Tokenize text with <image> tags."""
-
+
+        conversation = conversation
+        assert conversation.count(self.image_token) == len(images)
         text_splits = conversation.split(self.image_token)
+        images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
+            [],
+            [],
+            [],
+            [],
+        )
+        image_shapes = []
+        num_image_tokens = []
         tokenized_str = []
         for text_sep, image in zip(text_splits, images):
             """encode text_sep"""
             tokenized_sep = self.encode(text_sep, bos=False, eos=False)
+
             tokenized_str += tokenized_sep
             images_seq_mask += [False] * len(tokenized_sep)
 
-            """select best resolution for anyres"""
-            if cropping:
-                best_width, best_height = select_best_resolution(
-                    image.size, self.candidate_resolutions
-                )
+            image_shapes.append(image.size)
+
+            if image.size[0] <= 640 and image.size[1] <= 640:
+                crop_ratio = [1, 1]
             else:
-                best_width, best_height = self.image_size, self.image_size
-
+                if cropping:
+                    images_crop_raw, crop_ratio = dynamic_preprocess(
+                        image, image_size=IMAGE_SIZE
+                    )
+                else:
+                    crop_ratio = [1, 1]
 
             """process the global view"""
+            if self.image_size <= 640 and not cropping:
+                image = image.resize((self.image_size, self.image_size))
+
             global_view = ImageOps.pad(
                 image,
-                (self.image_size, self.image_size),
+                (self.base_size, self.base_size),
                 color=tuple(int(x * 255) for x in self.image_transform.mean),
             )
             images_list.append(self.image_transform(global_view))
 
-            """process the local views"""
-            local_view = ImageOps.pad(
-                image,
-                (best_width, best_height),
-                color=tuple(int(x * 255) for x in self.image_transform.mean),
-            )
-            for i in range(0, best_height, self.image_size):
-                for j in range(0, best_width, self.image_size):
-                    images_list.append(
-                        self.image_transform(
-                            local_view.crop(
-                                (j, i, j + self.image_size, i + self.image_size)
-                            )
-                        )
-                    )
-
-            """record height / width crop num"""
-            num_width_tiles, num_height_tiles = (
-                best_width // self.image_size,
-                best_height // self.image_size,
-            )
+            num_width_tiles, num_height_tiles = crop_ratio
             images_spatial_crop.append([num_width_tiles, num_height_tiles])
 
+            if num_width_tiles > 1 or num_height_tiles > 1:
+                for i in range(len(images_crop_raw)):
+                    images_crop_list.append(self.image_transform(images_crop_raw[i]))
+
             """add image tokens"""
-            h = w = math.ceil(
+            num_queries = math.ceil(
                 (self.image_size // self.patch_size) / self.downsample_ratio
             )
-            # global views tokens h * (w + 1), 1 is for line separator
-            tokenized_image = [self.image_token_id] * h * (w + 1)
-            # add a separator between global and local views
-            tokenized_image += [self.image_token_id]
-            # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
-            tokenized_image += (
-                [self.image_token_id]
-                * (num_height_tiles * h)
-                * (num_width_tiles * w + 1)
+            num_queries_base = math.ceil(
+                (self.base_size // self.patch_size) / self.downsample_ratio
             )
 
+            tokenized_image = (
+                [self.image_token_id] * num_queries_base + [self.image_token_id]
+            ) * num_queries_base
+            tokenized_image += [self.image_token_id]
+            if num_width_tiles > 1 or num_height_tiles > 1:
+                tokenized_image += (
+                    [self.image_token_id] * (num_queries * num_width_tiles)
+                    + [self.image_token_id]
+                ) * (num_queries * num_height_tiles)
             tokenized_str += tokenized_image
+
             images_seq_mask += [True] * len(tokenized_image)
-            num_image_tokens.append(len(tokenized_image))
+            num_image_tokens.append(len(tokenized_image))
 
         """process the last text split"""
         tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
-
-        if max_req_input_len > -1:
-            if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
-                rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
-                tokenized_str = tokenized_str[:rest]
-                images_seq_mask = images_seq_mask[:rest]
+
         tokenized_str += tokenized_sep
         images_seq_mask += [False] * len(tokenized_sep)
 
@@ -463,7 +505,64 @@ class DeepseekVLV2Processor(ProcessorMixin):
             images_seq_mask
         ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
 
-        return tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens
+        masked_tokenized_str = []
+        for token_index in tokenized_str:
+            if token_index != self.image_token_id:
+                masked_tokenized_str.append(token_index)
+            else:
+                masked_tokenized_str.append(self.ignore_id)
+
+        assert (
+            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
+        ), (
+            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
+        )
+        input_ids = torch.LongTensor(tokenized_str)
+        target_ids = torch.LongTensor(masked_tokenized_str)
+        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
+        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
+            self.ignore_id
+        )
+        input_ids[input_ids < 0] = self.pad_id
+
+        inference_mode = True
+
+        if inference_mode:
+            # Remove the ending eos token
+            assert input_ids[-1] == self.eos_id
+            input_ids = input_ids[:-1]
+            target_ids = target_ids[:-1]
+            images_seq_mask = images_seq_mask[:-1]
+
+        if len(images_list) == 0:
+            pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
+            images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
+            images_crop = torch.zeros(
+                (1, 3, self.image_size, self.image_size)
+            ).unsqueeze(0)
+        else:
+            pixel_values = torch.stack(images_list, dim=0)
+            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+            if images_crop_list:
+                images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
+            else:
+                images_crop = torch.zeros(
+                    (1, 3, self.image_size, self.image_size)
+                ).unsqueeze(0)
+
+        input_ids = input_ids.unsqueeze(0)
+        return (
+            input_ids,
+            pixel_values,
+            images_crop,
+            images_seq_mask,
+            images_spatial_crop,
+            num_image_tokens,
+            image_shapes,
+        )
 
 
 class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
@@ -548,7 +647,6 @@ class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
 
 
 class DeepseekV2Config(PretrainedConfig):
-
     model_type = "deepseek_v2"
     keys_to_ignore_at_inference = ["past_key_values"]
 
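The heart of this change is the new image-token layout in tokenize_with_images: the global view contributes num_queries_base rows of num_queries_base placeholder tokens, each row closed by one separator token, followed by one separator between the global and local views; when dynamic_preprocess tiles the image, the local crops add a similar grid scaled by the tile counts. Below is a minimal sketch of that arithmetic; the helper name image_token_count is ours, for illustration, and is not part of the package:

import math

def image_token_count(base_size, image_size, patch_size, downsample_ratio,
                      num_width_tiles, num_height_tiles):
    # queries per side for the local (tile) and global (base) views
    num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
    num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
    # global view: num_queries_base rows, each of num_queries_base tokens
    # plus one row separator; then one separator before the local views
    count = (num_queries_base + 1) * num_queries_base + 1
    # local views: emitted only when dynamic_preprocess produced > 1 tile
    if num_width_tiles > 1 or num_height_tiles > 1:
        count += (num_queries * num_width_tiles + 1) * (
            num_queries * num_height_tiles
        )
    return count

This mirrors the added lines above token for token: every list multiplication in the diff appends one extra image_token_id per row as a line separator.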
sglang/srt/configs/dots_vlm.py
CHANGED

@@ -1,10 +1,5 @@
-from
-
-from transformers import AutoProcessor, LlamaTokenizerFast, PretrainedConfig
-from transformers.feature_extraction_utils import BatchFeature
-from transformers.image_utils import ImageInput
-from transformers.processing_utils import ProcessingKwargs, Unpack
-from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers import AutoProcessor, PretrainedConfig
+from transformers.processing_utils import ProcessingKwargs
 
 try:
     from transformers import Qwen2_5_VLProcessor
sglang/srt/configs/falcon_h1.py
CHANGED

@@ -14,21 +14,12 @@
 # limitations under the License.
 """Falcon-H1 model configuration"""
 
-import enum
-import os
 
-import numpy as np
-import torch
 from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging
 
-from sglang.srt.
-from sglang.srt.layers.
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_size,
-    get_tensor_model_parallel_world_size,
-)
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
+from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size
 
 logger = logging.get_logger(__name__)
 
@@ -214,7 +205,7 @@ class FalconH1Config(PretrainedConfig):
             self.rope_scaling = None
         self.rope_scaling = rope_scaling
         self.projectors_bias = projectors_bias
-        mamba_intermediate = (
+        self.mamba_intermediate = mamba_intermediate = (
             mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
         )
 
@@ -294,18 +285,6 @@ class FalconH1Config(PretrainedConfig):
     def layers_block_type(self):
         return ["falcon_h1" for i in range(self.num_hidden_layers)]
 
-    @property
-    def mamba_cache_per_req(self):
-        conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
-            self.hybrid_gdn_params
-        )
-        mamba_layers_len = len(mamba_layers)
-
-        return (
-            int(np.prod(conv_state_shape)) * conv_dtype.itemsize
-            + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
-        ) * mamba_layers_len
-
     @property
     def full_attention_layer_ids(self):
         # For Falcon-H1, we do have attention on all layers
@@ -317,44 +296,14 @@ class FalconH1Config(PretrainedConfig):
         return range(self.num_hidden_layers)
 
     @property
-    def hybrid_gdn_params(self):
-
-
-
-
-
-
-
-
-            self.mamba_n_groups, world_size
-        )
-        n_groups += extra_groups
-
-        conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
-
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.mamba_d_conv - 1,
-        )
-
-        # we TP-ize on the heads dimension
-        temporal_state_shape = (
-            self.mamba_d_state,
-            self.mamba_d_head,
-            divide(self.mamba_n_heads, world_size),
-        )
-        conv_dtype = torch.bfloat16
-        dtype_map = {
-            "float32": torch.float32,
-            "bfloat16": torch.bfloat16,
-        }
-        ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
-        mamba_layers = self.linear_layer_ids
-
-        return (
-            conv_state_shape,
-            temporal_state_shape,
-            conv_dtype,
-            ssm_dtype,
-            mamba_layers,
+    def mamba2_cache_params(self):
+        shape = Mamba2StateShape.create(
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            intermediate_size=self.mamba_intermediate,
+            n_groups=self.mamba_n_groups,
+            num_heads=self.mamba_n_heads,
+            head_dim=self.mamba_d_head,
+            state_size=self.mamba_d_state,
+            conv_kernel=self.mamba_d_conv,
         )
+        return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
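The removed hybrid_gdn_params / mamba_cache_per_req pair computed the per-request Mamba state budget inline in the config; the new code delegates that to Mamba2StateShape / Mamba2CacheParams in the new sglang/srt/configs/mamba_utils.py (see the file list above). Below is a minimal sketch of the byte count the old property produced, lifted from the removed lines; the standalone function name and the example shapes are ours, for illustration only:

import numpy as np
import torch

def mamba_cache_bytes_per_req(conv_state_shape, temporal_state_shape,
                              conv_dtype, ssm_dtype, num_mamba_layers):
    # one conv state plus one SSM (temporal) state per linear-attention
    # layer, per request
    conv_bytes = int(np.prod(conv_state_shape)) * conv_dtype.itemsize
    ssm_bytes = int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
    return (conv_bytes + ssm_bytes) * num_mamba_layers

# hypothetical shapes: conv state (conv_dim, d_conv - 1) and SSM state
# (d_state, d_head, n_heads // tp), both bfloat16, across 36 mamba layers
print(mamba_cache_bytes_per_req((4352, 3), (256, 64, 32),
                                torch.bfloat16, torch.bfloat16, 36))

The refactor keeps the same inputs (TP world size, group/head counts, head and state dims, conv kernel) but moves the sharding and dtype logic out of each model config, which is why falcon_h1.py, qwen3_next.py, and nemotron_h.py all shrink in this release.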
sglang/srt/configs/load_config.py
CHANGED

@@ -1,10 +1,12 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
 import enum
-import json
 import logging
 from dataclasses import dataclass, field
 from typing import List, Optional, Union
 
+import orjson
+
+from sglang.srt.configs.modelopt_config import ModelOptConfig
 from sglang.srt.utils import is_hip
 
 logger = logging.getLogger(__name__)
@@ -50,6 +52,11 @@ class LoadConfig:
         decryption_key_file: If set, decrypts the output files with a password read
             from this file (after PBKDF2).
         decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
+
+        # ModelOpt-specific loading options
+        modelopt_checkpoint_restore_path: Optional[str] = None
+        modelopt_checkpoint_save_path: Optional[str] = None
+        modelopt_export_path: Optional[str] = None
     """
 
     load_format: Union[str, LoadFormat] = LoadFormat.AUTO
@@ -63,10 +70,18 @@ class LoadConfig:
     remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
     remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
 
+    # ModelOpt-specific loading options
+    modelopt_checkpoint_restore_path: Optional[str] = None
+    modelopt_checkpoint_save_path: Optional[str] = None
+    modelopt_export_path: Optional[str] = None
+
+    # ModelOpt configuration object
+    modelopt_config: Optional[ModelOptConfig] = None
+
     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
         if isinstance(model_loader_extra_config, str):
-            self.model_loader_extra_config = json.loads(model_loader_extra_config)
+            self.model_loader_extra_config = orjson.loads(model_loader_extra_config)
         self._verify_load_format()
 
         if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
@@ -77,6 +92,14 @@ class LoadConfig:
         else:
             self.ignore_patterns = ["original/**/*"]
 
+        # Create ModelOptConfig if not provided
+        if self.modelopt_config is None:
+            self.modelopt_config = ModelOptConfig(
+                checkpoint_restore_path=self.modelopt_checkpoint_restore_path,
+                checkpoint_save_path=self.modelopt_checkpoint_save_path,
+                export_path=self.modelopt_export_path,
+            )
+
     def _verify_load_format(self) -> None:
         if not isinstance(self.load_format, str):
             return