sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
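A note on the module moves listed above (for example `sglang/srt/utils.py → sglang/srt/utils/common.py` and `sglang/srt/hf_transformers_utils.py → sglang/srt/utils/hf_transformers_utils.py`): a minimal sketch of the resulting import-path change, mirroring the edits visible in the `gemma3_mm.py` and `gemma3n_mm.py` diffs below. Whether the old flat paths remain importable as aliases is not shown in this diff.

```python
# 0.5.2rc2 (old flat module layout):
# from sglang.srt.hf_transformers_utils import get_processor

# 0.5.3.post1 (helpers now live under the sglang.srt.utils package):
from sglang.srt.utils.hf_transformers_utils import get_processor

# Still valid in 0.5.3.post1, as used by the new falcon_h1.py shown below:
from sglang.srt.utils import add_prefix, is_cuda, make_layers
```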
sglang/srt/models/falcon_h1.py
ADDED
@@ -0,0 +1,578 @@
+import enum
+import logging
+from typing import Any, Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.configs.falcon_h1 import FalconH1Config
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
+
+logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
+
+
+class FalconH1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        layer_id: int,
+        mlp_multipliers: List[float],
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+        self.layer_id = layer_id
+
+        self.intermediate_size = intermediate_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.gate_multiplier, self.down_multiplier = mlp_multipliers
+
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
+        gate_up, _ = self.gate_up_proj(x)
+        gate_up[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier
+
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
+        x = x * self.down_multiplier
+        return x
+
+
+class FalconH1HybridAttentionDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconH1Config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_tp_size = get_attention_tp_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % self.attn_tp_size == 0
+        self.num_heads = self.total_num_heads // self.attn_tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= self.attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.layer_id = layer_id
+
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            rope_scaling=self.rope_scaling,
+            base=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+            is_neox_style=True,
+            dtype=torch.get_default_dtype(),  # see impl of get_rope
+        )
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=False,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=f"{prefix}.attn",
+        )
+
+        self.d_ssm = (
+            int(config.mamba_expand * config.hidden_size)
+            if config.mamba_d_ssm is None
+            else config.mamba_d_ssm
+        )
+
+        self.mamba = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
+            hidden_size=config.hidden_size,
+            use_conv_bias=config.mamba_conv_bias,
+            use_bias=config.mamba_proj_bias,
+            n_groups=config.mamba_n_groups,
+            rms_norm_eps=config.rms_norm_eps,
+            activation=config.hidden_act,
+            use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+        )
+
+        # FalconH1 all layers are sparse and have no nextn now
+        self.is_layer_sparse = False
+        is_previous_layer_sparse = False
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        self.feed_forward = FalconH1MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            layer_id=layer_id,
+            mlp_multipliers=config.mlp_multipliers,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.pre_ff_layernorm,
+            allow_reduce_scatter=True,
+        )
+
+        self.alt_stream = alt_stream
+        self.key_multiplier = config.key_multiplier
+
+        self.ssm_out_multiplier = config.ssm_out_multiplier
+        self.ssm_in_multiplier = config.ssm_in_multiplier
+
+        self.attention_in_multiplier = config.attention_in_multiplier
+        self.attn_out_multiplier = config.attention_out_multiplier
+
+        self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
+        self.zxbcdt_multipliers = config.ssm_multipliers
+        self._init_mup_vector()
+
+    def _init_mup_vector(self):
+        """
+        Non learnable per-block scaling vector composed of element-wise
+        multipliersapplied to each separate contiguous block of the output
+        of the linear projection (in_proj) before further processing
+        (gating, convolution, SSM):
+
+        - Z block: [0 : d_ssm] → zxbcdt_multipliers[0]
+        - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1]
+        - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2]
+        - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
+            → zxbcdt_multipliers[3]
+        - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4]
+
+        where:
+        - d_ssm: Dimension of state-space model latent
+        - G: Number of groups (n_groups)
+        - S: SSM state size per group
+        - All indices are divided by tp_size to support tensor parallelism
+        """
+        vector_shape = (
+            2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads
+        ) // self.tp_size
+        mup_vector = torch.ones(1, vector_shape)
+        # Z vector 0 -> d_ssm
+        mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0]
+        # X vector d_ssm -> 2 * d_ssm
+        mup_vector[
+            :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size)
+        ] *= self.zxbcdt_multipliers[1]
+        # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
+        mup_vector[
+            :,
+            (2 * self.d_ssm)
+            // self.tp_size : (2 * self.d_ssm + self.groups_time_state_size)
+            // self.tp_size,
+        ] *= self.zxbcdt_multipliers[2]
+        # C vector 2 * d_ssm + (n_group * d_state)
+        # -> 2 * d_ssm + 2 * (n_group * d_state)
+        mup_vector[
+            :,
+            (2 * self.d_ssm + self.groups_time_state_size)
+            // self.tp_size : (2 * self.d_ssm + 2 * self.groups_time_state_size)
+            // self.tp_size,
+        ] *= self.zxbcdt_multipliers[3]
+        # dt vector 2 * d_ssm + 2 * (n_group * d_state)
+        # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
+        mup_vector[
+            :,
+            (2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :,
+        ] *= self.zxbcdt_multipliers[4]
+
+        self.register_buffer("mup_vector", mup_vector, persistent=False)
+
+    def self_attention(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        k = k * self.key_multiplier
+        q, k = self.rotary_emb(positions, q, k)
+
+        attn_output = self.attn(q, k, v, forward_batch)
+
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ):
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if not forward_batch.forward_mode.is_idle():
+            # Attention block
+            attention_hidden_states = self.self_attention(
+                positions=positions,
+                hidden_states=hidden_states * self.attention_in_multiplier,
+                forward_batch=forward_batch,
+            )
+            attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
+
+            attn_backend = forward_batch.attn_backend
+            assert isinstance(attn_backend, HybridLinearAttnBackend)
+            assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
+            # Mamba block
+            mamba_hidden_states = torch.empty_like(hidden_states)
+            attn_backend.linear_attn_backend.forward(
+                self.mamba,
+                hidden_states * self.ssm_in_multiplier,
+                mamba_hidden_states,
+                layer_id=self.layer_id,
+                mup_vector=self.mup_vector,
+            )
+            mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
+
+            hidden_states = attention_hidden_states + mamba_hidden_states
+
+        # Fully Connected
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "falcon_h1": FalconH1HybridAttentionDecoderLayer,
+}
+
+
+class FalconH1Model(nn.Module):
+    def __init__(
+        self,
+        config: FalconH1Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.embedding_multiplier = config.embedding_multiplier
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            enable_tp=not is_dp_attention_enabled(),
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+            return layer_class(
+                config,
+                idx,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
+            )
+
+        self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+        )
+
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.infer_count = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        # mamba_cache_params: MambaCacheParams,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # pass a sequence index tensor, that is required for
+        # proper continuous batching computation including
+        # chunked prefill
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds * self.embedding_multiplier
+        else:
+            hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                layer_id=i,
+                positions=positions,
+                hidden_states=hidden_states,
+                residual=residual,
+                forward_batch=forward_batch,
+            )
+
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                hidden_states = self.final_layernorm(hidden_states)
+            else:
+                hidden_states, _ = self.final_layernorm(hidden_states, residual)
+
+        return hidden_states
+
+
+class HybridLayerType(enum.Enum):
+    full_attention = "attention"
+    swa_attention = "swa_attention"
+    linear_attention = "linear_attention"
+    mamba2 = "mamba"
+
+
+class FalconH1ForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: FalconH1Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.pp_group = get_pp_group()
+        assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+        self.quant_config = quant_config
+        self.model = FalconH1Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            )
+        self.lm_head = self.lm_head.float()
+        self.lm_head_multiplier = config.lm_head_multiplier
+        self.logits_processor = LogitsProcessor(
+            config, logit_scale=self.lm_head_multiplier
+        )
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+    ) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                # if is_pp_missing_parameter(name, self):
+                # continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader")
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # if is_pp_missing_parameter(name, self):
+                # continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+
+                weight_loader(param, loaded_weight)
+
+            loaded_params.add(name)
+        return loaded_params
+
+
+EntryClass = FalconH1ForCausalLM
sglang/srt/models/gemma3_causal.py
CHANGED
@@ -20,7 +20,6 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import (
     ROPE_INIT_FUNCTIONS,
-    AutoModel,
     Gemma3TextConfig,
     PretrainedConfig,
     PreTrainedModel,
@@ -761,4 +760,3 @@ class Gemma3ForCausalLM(PreTrainedModel):
 
 
 EntryClass = Gemma3ForCausalLM
-AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
 
 import logging
+import re
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
@@ -23,7 +24,6 @@ import torch
 from torch import nn
 from transformers import Gemma3Config, PreTrainedModel
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -44,6 +44,7 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
 from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     embedding_modules = {}
     embedding_padding_modules = []
     supports_lora = True
+    # Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
 
     def __init__(
         self,
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         self.config = config
         self.quant_config = quant_config
 
+        # For LoRA compatibility: expose text_config attributes at top level
+        # This allows LoRA code to work without special multimodal handling
+        if not hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = config.text_config.num_hidden_layers
+        if not hasattr(config, "hidden_size"):
+            config.hidden_size = config.text_config.hidden_size
+
         self.vision_tower = SiglipVisionModel(
             config=config.vision_config,
             quant_config=quant_config,
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
 
         return hs
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision tower and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def tie_weights(self):
         return self.language_model.tie_weights()
 
sglang/srt/models/gemma3n_mm.py
CHANGED
@@ -14,7 +14,6 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import AutoModel
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -38,6 +37,7 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
 from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))
 
-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (
sglang/srt/models/glm4_moe.py
CHANGED
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Inference-only GLM-4.5 model compatible with HuggingFace weights"""
+"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
 
 import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
@@ -785,9 +785,9 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             or self.config.architectures[0] != architecture
             or self.config.n_shared_experts != 1
         ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+            disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+            disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
|