sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,173 @@
+# coding=utf-8
+# Adapted from Qwen2.5-VL SGLang implementation
+
+import logging
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.activations import ACT2FN
+
+from sglang.srt.configs import DotsOCRConfig
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
+
+logger = logging.getLogger(__name__)
+
+
+class DotsOCRForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: DotsOCRConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        # Initialize vision transformer
+        self.visual = DotsVisionTransformer(
+            config.vision_config,
+        )
+
+        # Initialize language model
+        self.model = Qwen2ForCausalLM(config, quant_config)
+
+        # Initialize LM head
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract pixel values and grid information (following reference pattern)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thw = torch.concat(
+            [item.image_grid_thw for item in items], dim=0
+        ).to(self.visual.device)
+
+        # Add dimension checks like in reference code
+        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
+        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
+
+        # Process through vision tower
+        image_embeds = self.visual(pixel_values, image_grid_thw)
+
+        # Ensure consistent dtype for FlashInfer compatibility
+        # Force bfloat16 to match model's expected dtype
+        if hasattr(self.model, "embed_tokens"):
+            target_dtype = self.model.embed_tokens.weight.dtype
+            if image_embeds.dtype != target_dtype:
+                image_embeds = image_embeds.to(target_dtype)
+
+        return image_embeds
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            multimodal_model=self,
+            language_model=self.model,
+        )
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model, separating vision and language weights"""
+        weights = list(weights)
+
+        # Separate vision tower weights and language model weights
+        vision_weights = []
+        language_weights = []
+
+        for name, loaded_weight in weights:
+            if name.startswith("vision_tower."):
+                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                vision_weights.append((vision_name, loaded_weight))
+            else:
+                # All other weights go to language model
+                language_weights.append((name, loaded_weight))
+
+        # Load vision tower weights
+        vision_state_dict = dict(vision_weights)
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for name, loaded_weight in vision_state_dict.items():
+            name = name.replace("vision_tower", "visual")
+            if name not in params_dict:
+                raise ValueError(f"Weight {name} not found in params_dict")
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+            weight_loader(param, loaded_weight)
+
+        if language_weights:
+            self.model.load_weights(language_weights)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+
+EntryClass = [DotsOCRForCausalLM]
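The least obvious piece of the `dots_ocr.py` hunk above is the dummy-head padding applied when the vision tower's head count does not divide the tensor-parallel world size. The standalone sketch below reproduces just the fused-qkv branch of `_pad_vit_attn_dummy_heads` with made-up sizes (6 real heads, 2 dummy heads, head_dim 8); none of the numbers come from the actual dots-ocr config.

import torch

# Reproduces the qkv branch of _pad_vit_attn_dummy_heads in isolation.
# All sizes here are illustrative only.
num_heads, num_dummy_heads, head_dim, in_dim = 6, 2, 8, 48

# Fused qkv weight as stored in the checkpoint: [3 * num_heads * head_dim, in_dim].
qkv_weight = torch.randn(3 * num_heads * head_dim, in_dim)

wq, wk, wv = qkv_weight.chunk(3, dim=0)
dummy_shape = [num_dummy_heads, head_dim, in_dim]

def pad(x: torch.Tensor) -> torch.Tensor:
    # Append num_dummy_heads zero head-slices to one projection, then re-flatten.
    return torch.cat(
        [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
    ).flatten(0, 1)

padded = torch.cat([pad(wq), pad(wk), pad(wv)], dim=0)

# Each projection grows from num_heads to num_heads + num_dummy_heads head slices.
assert padded.shape == (3 * (num_heads + num_dummy_heads) * head_dim, in_dim)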
sglang/srt/models/dots_vlm.py
ADDED
@@ -0,0 +1,174 @@
+# Copyright 2025 The RedNote HiLab team.
+# Copyright 2025 The SGLang team.
+#
+# This code is based on the DeepseekVL2ForCausalLM and DotsVisionTransformer
+# implementation in this library.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Dots-VL model compatible with HuggingFace weights."""
+
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.configs.dots_vlm import DotsVLMConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
+
+from .dots_vlm_vit import DotsVisionTransformer
+
+
+class DotsVLMForCausalLM(nn.Module):
+    """DotsVLM model for sglang inference"""
+
+    def __init__(
+        self, config: DotsVLMConfig, quant_config: Optional[QuantizationConfig] = None
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.image_token_id = config.im_span_id
+        self.video_token_id = config.video_span_id
+
+        self.language_model = DeepseekV2ForCausalLM(
+            config.language_config, quant_config
+        )
+
+        # Initialize vision tower (matching transformers naming for weight compatibility)
+        self.vision_tower = DotsVisionTransformer(config.vision_config)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model, separating vision and language weights"""
+        weights = list(weights)
+
+        # Separate vision tower weights and language model weights
+        vision_weights = []
+        language_weights = []
+
+        for name, loaded_weight in weights:
+            if name.startswith("vision_tower."):
+                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                vision_weights.append((vision_name, loaded_weight))
+            else:
+                # All other weights go to language model
+                language_weights.append((name, loaded_weight))
+
+        # Load vision tower weights
+        vision_state_dict = dict(vision_weights)
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in vision_state_dict.items():
+            if name not in params_dict:
+                raise ValueError(f"Weight {name} not found in params_dict")
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+            weight_loader(param, loaded_weight)
+
+        # Load language model weights
+        if language_weights:
+            self.language_model.load_weights(language_weights)
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return DeepseekV2ForCausalLM.get_model_config_for_expert_location(config)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        """Pad input_ids with multimodal tokens"""
+        # Get image token ID for padding pattern
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        padded_input_ids = pattern.pad_input_tokens(input_ids, mm_inputs)
+        return padded_input_ids
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract pixel values and grid information (following reference pattern)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.vision_tower.dtype
+        )
+        image_grid_thw = torch.concat(
+            [item.image_grid_thw for item in items], dim=0
+        ).to(self.vision_tower.device)
+
+        # Add dimension checks like in reference code
+        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
+        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
+
+        # Process through vision tower
+        image_embeds = self.vision_tower(pixel_values, image_grid_thw)
+
+        # Ensure consistent dtype for FlashInfer compatibility
+        # Force bfloat16 to match model's expected dtype
+        if image_embeds.dtype != torch.bfloat16 and hasattr(
+            self.language_model.model, "embed_tokens"
+        ):
+            target_dtype = self.language_model.model.embed_tokens.weight.dtype
+            image_embeds = image_embeds.to(target_dtype)
+
+        return image_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            multimodal_model=self,
+            language_model=self.language_model,
+        )
+        return hidden_states
+
+
+EntryClass = [DotsVLMForCausalLM]
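Both `load_weights` implementations above route checkpoint entries by prefix and rename the fused `attn.qkv.` projection to `attn.qkv_proj.` (dots_ocr additionally renames `vision_tower` to `visual` to match its module attribute). A minimal sketch of that routing; the checkpoint keys here are hypothetical, only the prefix check and renames come from the code above.

# Hypothetical checkpoint keys, used only to illustrate the routing rules.
checkpoint_keys = [
    "vision_tower.blocks.0.attn.qkv.weight",  # vision tower, fused qkv -> qkv_proj
    "vision_tower.merger.mlp.0.weight",       # vision tower, name kept as-is
    "model.layers.0.mlp.gate_proj.weight",    # everything else goes to the language model
]

vision_weights, language_weights = [], []
for name in checkpoint_keys:
    if name.startswith("vision_tower."):
        vision_weights.append(name.replace("attn.qkv.", "attn.qkv_proj."))
    else:
        language_weights.append(name)

print(vision_weights)
# ['vision_tower.blocks.0.attn.qkv_proj.weight', 'vision_tower.merger.mlp.0.weight']
print(language_weights)
# ['model.layers.0.mlp.gate_proj.weight']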
sglang/srt/models/dots_vlm_vit.py
ADDED
@@ -0,0 +1,337 @@
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import LayerNorm
+from transformers.modeling_utils import PreTrainedModel
+
+from sglang.srt.configs.dots_vlm import DotsVisionConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seq = torch.arange(
+            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+        )
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
+
+
+class PatchMerger(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        context_dim: int,
+        spatial_merge_size: int = 2,
+        pre_norm="layernorm",
+        init_merger_std=None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.pre_norm = pre_norm
+        if self.pre_norm == "layernorm":
+            self.ln_q = LayerNorm(context_dim, eps=1e-6)
+        elif self.pre_norm == "rmsnorm":
+            self.ln_q = RMSNorm(context_dim, eps=1e-6)
+        else:
+            logger.warning(f"no norm in patch merger: {self.pre_norm}")
+
+        self.mlp = nn.Sequential(
+            nn.Linear(self.hidden_size, self.hidden_size),
+            nn.GELU(),
+            nn.Linear(self.hidden_size, dim),
+        )
+
+        if init_merger_std is not None:
+            nn.init.normal_(self.mlp[0].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[0].bias)
+            nn.init.normal_(self.mlp[2].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[2].bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pre_norm:
+            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+        else:
+            x = self.mlp(x.view(-1, self.hidden_size))
+        return x
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(dim))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def extra_repr(self) -> str:
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+
+class DotsSwiGLUFFN(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        hidden_features = config.intermediate_size
+        in_features = config.embed_dim
+        bias = config.use_bias
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
+        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.fc1(x)) * self.fc3(x)
+        x = self.fc2(x)
+        return x
+
+
+class DotsPatchEmbed(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.num_channels = config.num_channels
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.proj = nn.Conv2d(
+            config.num_channels,
+            config.embed_dim,
+            kernel_size=(config.patch_size, config.patch_size),
+            stride=(config.patch_size, config.patch_size),
+        )
+        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        x = x.view(
+            -1,
+            self.num_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        )[:, :, 0]
+        x = self.proj(x).view(-1, self.embed_dim)
+        x = self.norm(x)
+        return x
+
+
+class DotsViTPreprocessor(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.patch_h = config.patch_size
+        self.patch_w = config.patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.patchifier = DotsPatchEmbed(config, quant_config)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        tokens = self.patchifier(x, grid_thw)
+        return tokens
+
+
+class DotsVisionBlock(nn.Module):
+    def __init__(
+        self,
+        config: DotsVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        attn_implementation: str = "flash_attention_2",
+    ):
+        super().__init__()
+        if attn_implementation == "flash_attention_2":
+            qkv_backend = "fa3"
+            softmax_in_single_precision = False
+        else:
+            raise RuntimeError("Unimplemented")
+        self.attn = VisionAttention(
+            embed_dim=config.embed_dim,
+            num_heads=config.num_attention_heads,
+            projection_size=config.embed_dim,
+            use_qkv_parallel=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+            num_dummy_heads=config.num_dummy_heads,
+            qkv_bias=config.use_bias,
+            proj_bias=config.use_bias,
+        )
+        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+        self.mlp = DotsSwiGLUFFN(config, quant_config)
+        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            position_embeddings=rotary_pos_emb,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class DotsVisionTransformer(PreTrainedModel):
+    def __init__(
+        self,
+        config: DotsVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config)
+        self.config = config
+        self._update_vision_config()
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.patch_embed = DotsViTPreprocessor(config, quant_config)
+        self._init_weights(self.patch_embed.patchifier.proj)
+
+        head_dim = config.embed_dim // config.num_attention_heads
+
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        _num_hidden_layers = config.num_hidden_layers
+        self.blocks = nn.ModuleList(
+            [
+                DotsVisionBlock(
+                    config, quant_config, f"blocks.{i}", config.attn_implementation
+                )
+                for i in range(_num_hidden_layers)
+            ]
+        )
+
+        if self.config.post_norm:
+            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+        self.merger = PatchMerger(
+            dim=config.hidden_size,
+            context_dim=config.embed_dim,
+            spatial_merge_size=config.spatial_merge_size,
+            init_merger_std=self.config.init_merger_std,
+            quant_config=quant_config,
+        )
+
+        self.gradient_checkpointing = False
+
+    def _update_vision_config(self):
+        """update vision config to support tp"""
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        num_heads = self.config.num_attention_heads
+        head_dim = self.config.embed_dim // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % world_size != 0:
+            num_dummy_heads = (
+                (num_heads + world_size) // world_size
+            ) * world_size - num_heads
+
+        setattr(self.config, "head_dim", head_dim)
+        setattr(self.config, "num_dummy_heads", num_dummy_heads)
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.fc2.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.fc2.weight.device
+
+    def get_pos_ids_by_grid(self, grid_thw):
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+        return pos_ids
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = self.get_pos_ids_by_grid(grid_thw)
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def calc_cos_sin(self, rotary_pos_emb):
+        cos = rotary_pos_emb.cos()
+        sin = rotary_pos_emb.sin()
+        cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+        sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+        rotary_pos_emb = (cos, sin)
+        return rotary_pos_emb
+
+    def forward(
+        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True
+    ) -> torch.Tensor:
+        if bf16:
+            hidden_states = hidden_states.bfloat16()
+        hidden_states = self.patch_embed(hidden_states, grid_thw)
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.calc_cos_sin(rotary_pos_emb)
+
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(
+            dim=0,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for blk in self.blocks:
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            )
+
+        if self.config.post_norm:
+            hidden_states = self.post_trunk_norm(hidden_states)
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states
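`DotsVisionTransformer.forward` flattens all images into one token sequence and hands the attention backend cumulative sequence lengths (`cu_seqlens`). A small worked example with two hypothetical patch grids (chosen arbitrarily, not from any real config) shows what that tensor looks like.

import torch
import torch.nn.functional as F

# grid_thw rows are (t, h, w) in patches; the two grids here are arbitrary examples.
grid_thw = torch.tensor([[1, 4, 6], [1, 8, 8]])

cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

# The first image contributes 4*6 = 24 patch tokens, the second 8*8 = 64,
# so the sequence boundaries are [0, 24, 88].
print(cu_seqlens)  # tensor([ 0, 24, 88], dtype=torch.int32)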
sglang/srt/models/ernie4.py
CHANGED
@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,