sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================

-"""Inference-only GLM-4.5 NextN Speculative Decoding."""
+"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
 import logging
 from typing import Iterable, Optional, Tuple

@@ -48,7 +48,7 @@ class Glm4MoeModelNextN(nn.Module):
         super().__init__()
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
-                "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
+                "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 / GLM-4.6 model."
             )
             quant_config = None

sglang/srt/models/glm4v.py
CHANGED
@@ -7,7 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig

-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.layernorm import RMSNorm
@@ -28,6 +27,7 @@ from sglang.srt.models.qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration,
 )
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor

 logger = logging.getLogger(__name__)

@@ -93,9 +93,8 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             quant_config=quant_config,
             prefix=prefix,
             num_dummy_heads=config.num_dummy_heads,
+            rms_norm_eps=config.rms_norm_eps,
         )
-        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.mlp = Glm4vVisionMLP(
             config.hidden_size,
@@ -498,6 +497,9 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = torch.cat(
             [item.feature.squeeze(0) for item in items], dim=0
sglang/srt/models/glm4v_moe.py
CHANGED
@@ -10,7 +10,6 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -22,6 +21,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+from sglang.srt.utils.hf_transformers_utils import get_processor

 _is_cuda = is_cuda()

@@ -74,6 +74,9 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def determine_num_fused_shared_experts(
         self, architecture: str = "Glm4MoeForCausalLM"
     ):
sglang/srt/models/gpt_oss.py
CHANGED
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -121,7 +125,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )

         self.top_k = config.num_experts_per_tok
-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans


-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
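The helpers removed above are not dropped; the file list shows a new sglang/srt/models/utils.py (+55), and gpt_oss.py now imports create_fused_set_kv_buffer_arg / enable_fused_set_kv_buffer from it. A minimal sketch of those shared helpers, reconstructed from the private versions deleted in this hunk; FusedSetKVBufferArgLike below is a stand-in for the real argument container, whose import is not visible in the diff:

```python
# Sketch only, reconstructed from the helpers removed from gpt_oss.py above;
# not the verbatim contents of sglang/srt/models/utils.py.
from typing import NamedTuple

import torch

from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()


class FusedSetKVBufferArgLike(NamedTuple):
    # Stand-in for FusedSetKVBufferArg (import path not shown in the hunk above).
    value: torch.Tensor
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor
    k_scale: object
    v_scale: object
    cache_loc: torch.Tensor


def enable_fused_set_kv_buffer(forward_batch: ForwardBatch) -> bool:
    """Use the fused set_kv_buffer path only on CUDA with a bfloat16 KV cache."""
    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


def create_fused_set_kv_buffer_arg(value, layer, forward_batch: ForwardBatch):
    """Bundle a layer's K/V buffers, scales, and cache locations for the fused kernel."""
    pool = forward_batch.token_to_kv_pool
    k_buffer = pool.get_key_buffer(layer.layer_id)
    v_buffer = pool.get_value_buffer(layer.layer_id)
    return FusedSetKVBufferArgLike(
        value=value,
        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
        k_scale=layer.k_scale,
        v_scale=layer.v_scale,
        cache_loc=forward_batch.out_cache_loc,
    )
```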
sglang/srt/models/grok.py
CHANGED
@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.moe.topk import TopK
@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
             custom_routing_function=custom_routing_function,
         )

-
-        if get_moe_expert_parallel_world_size() > 1:
-            MoEImpl = EPMoE
-        else:
-            MoEImpl = FusedMoE
-        kwargs["reduce_results"] = reduce_results
-        kwargs["use_presharded_weights"] = use_presharded_weights
-        kwargs["inplace"] = inplace
-        kwargs["no_combine"] = no_combine
-
-        self.experts = MoEImpl(
+        self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             layer_id=layer_id,
@@ -195,7 +184,10 @@ class Grok1MoE(nn.Module):
             params_dtype=params_dtype,
             quant_config=quant_config,
             activation="gelu",
-
+            reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
+            inplace=inplace,
+            no_combine=no_combine,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
sglang/srt/models/kimi_vl_moonvit.py
CHANGED
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers.activations import ACT2FN,
+from transformers.activations import ACT2FN, GELUTanh
 from transformers.modeling_utils import PreTrainedModel

 try:
@@ -614,7 +614,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
                 "num_heads": config.num_attention_heads,
                 "hidden_dim": config.hidden_size,
                 "mlp_dim": config.intermediate_size,
-                "activation":
+                "activation": GELUTanh(),
                 "attn_bias": True,
                 "attn_implementation": config._attn_implementation,
             },
sglang/srt/models/llama.py
CHANGED
@@ -385,6 +385,10 @@ class LlamaModel(nn.Module):
                     "Self attention has no KV cache scaling " "factor attribute!"
                 )

+    def get_input_embeddings(self) -> nn.Embedding:
+        """Get input embeddings from the model."""
+        return self.embed_tokens
+

 class LlamaForCausalLM(nn.Module):
     # BitandBytes specific attributes
sglang/srt/models/llama4.py
CHANGED
@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
             return self.config.num_local_experts > 0
         return (layer_id + 1) % self.config.interleave_moe_layer_step == 0

+    def get_intermediate_size(self) -> int:
+        if isinstance(self.feed_forward, Llama4MoE):
+            return self.config.intermediate_size
+        else:
+            return self.config.intermediate_size_mlp
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     def get_input_embeddings(self):
         return self.model.embed_tokens

+    def get_layers(self):
+        return self.model.layers
+
     def _init_model(
         self,
         config: Llama4TextConfig,
sglang/srt/models/llama_eagle3.py
CHANGED
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+
+        self.is_mrope_enabled = (
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling is not None
+            and "mrope_section" in config.rope_scaling
+        )
+        # fix rope_scaling for qwen2.5-vl
+        if self.is_mrope_enabled:
+            config.rope_scaling["rope_type"] = "default"
+
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -144,6 +154,9 @@ class LlamaModel(nn.Module):
         else:
             embeds = input_embeds

+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
         hidden_states = forward_batch.spec_info.hidden_states
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)
sglang/srt/models/longcat_flash.py
CHANGED
@@ -131,7 +131,7 @@ elif _is_hip:
         awq_dequantize_triton as awq_dequantize,
     )
 else:
-
+    pass

 logger = logging.getLogger(__name__)

@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native

-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping =
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
CHANGED
@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
|
|
36
36
|
RowParallelLinear,
|
37
37
|
)
|
38
38
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
39
|
-
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
40
39
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
41
40
|
from sglang.srt.layers.moe.topk import TopK
|
42
41
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
|
|
94
93
|
renormalize=True,
|
95
94
|
)
|
96
95
|
|
97
|
-
|
98
|
-
self.experts = MoEImpl(
|
96
|
+
self.experts = FusedMoE(
|
99
97
|
num_experts=num_experts,
|
100
98
|
top_k=top_k,
|
101
99
|
layer_id=layer_id,
|
sglang/srt/models/mllama4.py
CHANGED
@@ -2,6 +2,7 @@ import json as json_lib
|
|
2
2
|
import logging
|
3
3
|
import math
|
4
4
|
import os
|
5
|
+
import re
|
5
6
|
from collections.abc import Iterable
|
6
7
|
from typing import List, Optional, Set, Tuple
|
7
8
|
|
@@ -291,7 +292,7 @@ class Llama4UnfoldConvolution(nn.Module):
|
|
291
292
|
|
292
293
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
293
294
|
hidden_states = self.unfold(hidden_states)
|
294
|
-
hidden_states = hidden_states.permute(0, 2, 1)
|
295
|
+
hidden_states = hidden_states.permute(0, 2, 1).contiguous()
|
295
296
|
hidden_states, _ = self.linear(hidden_states)
|
296
297
|
return hidden_states
|
297
298
|
|
@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|
422
423
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
423
424
|
}
|
424
425
|
|
426
|
+
# Pattern to match language model layers only (skip vision_model and multi_modal_projector)
|
427
|
+
lora_pattern = re.compile(
|
428
|
+
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
|
429
|
+
)
|
430
|
+
|
425
431
|
def __init__(
|
426
432
|
self,
|
427
433
|
config: Llama4Config,
|
@@ -446,9 +452,20 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|
446
452
|
)
|
447
453
|
|
448
454
|
if self.has_vision:
|
455
|
+
# TODO: make this more general
|
456
|
+
ignore_quant_layers = getattr(config, "quantization_config", {}).get(
|
457
|
+
"ignore", {}
|
458
|
+
)
|
459
|
+
if (
|
460
|
+
"model.layers.vision_model*" in ignore_quant_layers
|
461
|
+
and "model.layers.multi_modal_projector*" in ignore_quant_layers
|
462
|
+
):
|
463
|
+
vision_quant_config = None
|
464
|
+
else:
|
465
|
+
vision_quant_config = quant_config
|
449
466
|
self.vision_model = Llama4VisionModel(
|
450
467
|
config.vision_config,
|
451
|
-
quant_config=
|
468
|
+
quant_config=vision_quant_config,
|
452
469
|
prefix=add_prefix("vision_model", prefix),
|
453
470
|
)
|
454
471
|
|
@@ -544,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|
544
561
|
|
545
562
|
return projected_vision_flat
|
546
563
|
|
564
|
+
def should_apply_lora(self, module_name: str) -> bool:
|
565
|
+
"""Skip vision model and multi_modal_projector for LoRA."""
|
566
|
+
return bool(self.lora_pattern.match(module_name))
|
567
|
+
|
547
568
|
def forward(
|
548
569
|
self,
|
549
570
|
input_ids: torch.Tensor,
|
@@ -560,7 +581,7 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|
560
581
|
forward_batch=forward_batch,
|
561
582
|
language_model=self.language_model,
|
562
583
|
data_embedding_funcs={
|
563
|
-
Modality.IMAGE:
|
584
|
+
Modality.IMAGE: image_embedding_func,
|
564
585
|
},
|
565
586
|
positions=positions,
|
566
587
|
)
|
@@ -689,7 +710,7 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|
689
710
|
"""Handle scale parameter remapping. Returns True if handled."""
|
690
711
|
if "scale" in name and "expert" not in name:
|
691
712
|
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
|
692
|
-
return remapped_name
|
713
|
+
return remapped_name != name
|
693
714
|
return False
|
694
715
|
|
695
716
|
def _handle_stacked_params(
|
@@ -961,5 +982,30 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|
961
982
|
def set_embed(self, embed):
|
962
983
|
return self.language_model.set_embed(embed)
|
963
984
|
|
985
|
+
def get_hidden_dim(self, module_name, layer_idx):
|
986
|
+
# return input_dim, output_dim
|
987
|
+
if module_name == "qkv_proj":
|
988
|
+
return (
|
989
|
+
self.config.hidden_size,
|
990
|
+
self.config.head_dim
|
991
|
+
* (
|
992
|
+
self.config.num_attention_heads
|
993
|
+
+ self.config.num_key_value_heads * 2
|
994
|
+
),
|
995
|
+
)
|
996
|
+
elif module_name == "o_proj":
|
997
|
+
return (
|
998
|
+
self.config.head_dim * self.config.num_attention_heads,
|
999
|
+
self.config.hidden_size,
|
1000
|
+
)
|
1001
|
+
elif module_name == "gate_up_proj":
|
1002
|
+
return self.config.hidden_size, self.config.intermediate_size * 2
|
1003
|
+
elif module_name == "down_proj":
|
1004
|
+
decoder_layer = self.language_model.get_layers()[layer_idx]
|
1005
|
+
intermediate_size = decoder_layer.get_intermediate_size()
|
1006
|
+
return intermediate_size, self.config.hidden_size
|
1007
|
+
else:
|
1008
|
+
raise NotImplementedError()
|
1009
|
+
|
964
1010
|
|
965
1011
|
EntryClass = Llama4ForConditionalGeneration
|
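Together with get_layers() and get_intermediate_size() added to llama4.py above, the new should_apply_lora / get_hidden_dim hooks give the LoRA machinery enough to size adapters and skip the vision tower. A hypothetical caller, for illustration only; lora_weight_shapes and rank are not part of the diff:

```python
# Hypothetical usage: only should_apply_lora / get_hidden_dim / get_layers /
# get_intermediate_size come from the diff above; everything else is illustrative.
def lora_weight_shapes(model, module_name: str, layer_idx: int, rank: int):
    parent = "self_attn" if module_name in ("qkv_proj", "o_proj") else "mlp"
    full_name = f"language_model.model.layers.{layer_idx}.{parent}.{module_name}"
    if not model.should_apply_lora(full_name):
        return None  # vision_model and multi_modal_projector are excluded by lora_pattern

    input_dim, output_dim = model.get_hidden_dim(module_name, layer_idx)
    # down_proj consults the per-layer MoE/MLP intermediate size via get_layers()
    return (rank, input_dim), (output_dim, rank)  # LoRA A and B shapes
```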