sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe.py
CHANGED
@@ -153,7 +153,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        should_allreduce_fusion=False,
+        gemm_output_zero_allocator: BumpAllocator = None,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

@@ -501,6 +507,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
@@ -543,6 +550,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -666,6 +674,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
sglang/srt/models/glm4v.py
CHANGED
@@ -93,9 +93,8 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             quant_config=quant_config,
             prefix=prefix,
             num_dummy_heads=config.num_dummy_heads,
+            rms_norm_eps=config.rms_norm_eps,
         )
-        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.mlp = Glm4vVisionMLP(
             config.hidden_size,
@@ -498,6 +497,9 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = torch.cat(
             [item.feature.squeeze(0) for item in items], dim=0
sglang/srt/models/gpt_oss.py
CHANGED
@@ -58,7 +58,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -71,6 +71,7 @@ from sglang.srt.utils import (
     add_prefix,
     is_cuda,
     is_flashinfer_available,
+    is_sm100_supported,
     make_layers,
 )

@@ -192,8 +193,9 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans


-def _enable_fused_set_kv_buffer():
-    return _is_cuda
+def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
+    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
+    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


 # TODO maybe move to a model-common utils
@@ -340,7 +342,7 @@ class GptOssAttention(nn.Module):
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer()
+                if _enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -354,7 +356,7 @@
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(),
+            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
@@ -1028,10 +1030,6 @@ class GptOssForCausalLM(nn.Module):
         )

         params_dict = dict(self.named_parameters())
-        params_checker = {k: False for k, v in params_dict.items()}
-
-        for other_loaded_param_name in other_loaded_param_names:
-            params_checker[other_loaded_param_name] = True

         for name, loaded_weight in weights:
             loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
@@ -1068,7 +1066,6 @@
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
-                params_checker[name] = True
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -1091,7 +1088,6 @@
                             name,
                             shard_id=shard_id,
                         )
-                        params_checker[name] = True
                         break
                     else:
                         if name.endswith(".bias") and name not in params_dict:
@@ -1110,17 +1106,9 @@
                             param, "weight_loader", default_weight_loader
                         )
                         weight_loader(param, loaded_weight)
-                        params_checker[name] = True
                     else:
                         logger.warning(f"Parameter {name} not found in params_dict")

-        not_loaded_params = [k for k, v in params_checker.items() if not v]
-        if tp_rank == 0:
-            if len(not_loaded_params) > 0:
-                raise Exception(f"Not all parameters loaded: {not_loaded_params}")
-            else:
-                logging.info("All parameters loaded successfully.")
-
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight

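
The behavioral change above is that _enable_fused_set_kv_buffer now receives the forward batch and checks the KV-cache dtype at runtime instead of gating on CUDA alone; the weight loader also drops its params_checker bookkeeping. A self-contained sketch of the gating pattern follows; FakeKVPool and FakeForwardBatch are invented for illustration, and only the decision logic mirrors the diff.

from dataclasses import dataclass

import torch

_is_cuda = torch.cuda.is_available()


@dataclass
class FakeKVPool:
    dtype: torch.dtype


@dataclass
class FakeForwardBatch:
    token_to_kv_pool: FakeKVPool


def _enable_fused_set_kv_buffer(forward_batch: FakeForwardBatch) -> bool:
    """Fuse the KV-cache write into RoPE only on CUDA with a bf16 KV cache."""
    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


def run_attention(forward_batch: FakeForwardBatch) -> str:
    if _enable_fused_set_kv_buffer(forward_batch):
        # The fused path writes K/V during rotary embedding, so the attention
        # call is told not to save the KV cache a second time.
        return "fused write, save_kv_cache=False"
    return "separate write, save_kv_cache=True"


if __name__ == "__main__":
    print(run_attention(FakeForwardBatch(FakeKVPool(torch.bfloat16))))
    print(run_attention(FakeForwardBatch(FakeKVPool(torch.float16))))
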
sglang/srt/models/internvl.py
CHANGED
@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
+from sglang.srt.models.gpt_oss import GptOssForCausalLM
 from sglang.srt.models.internlm2 import InternLM2ForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger

@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.llm_config, quant_config=quant_config
             )
+        elif config.llm_config.architectures[0] == "GptOssForCausalLM":
+            self.language_model = GptOssForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
+        elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.llm_config.architectures[0]} is not implemented."
@@ -577,6 +587,15 @@
                 ckpt_up_proj_name="up_proj",
                 num_experts=self.config.num_experts,
             )
+        elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]

         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
@@ -661,6 +680,15 @@

             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
+        # Skip params that are created by quantization wrappers and are not expected in the ckpt
+        _quant_only_fragments = (
+            "weight_scale",  # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
+        )
+        unloaded_params = {
+            n
+            for n in unloaded_params
+            if not any(frag in n for frag in _quant_only_fragments)
+        }
         if unloaded_params:
             raise RuntimeError(
                 f"Some weights are not initialized from checkpoints: {unloaded_params}"
sglang/srt/models/llama4.py
CHANGED
@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
             return self.config.num_local_experts > 0
         return (layer_id + 1) % self.config.interleave_moe_layer_step == 0

+    def get_intermediate_size(self) -> int:
+        if isinstance(self.feed_forward, Llama4MoE):
+            return self.config.intermediate_size
+        else:
+            return self.config.intermediate_size_mlp
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     def get_input_embeddings(self):
         return self.model.embed_tokens

+    def get_layers(self):
+        return self.model.layers
+
     def _init_model(
         self,
         config: Llama4TextConfig,
sglang/srt/models/llama_eagle3.py
CHANGED
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+
+        self.is_mrope_enabled = (
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling is not None
+            and "mrope_section" in config.rope_scaling
+        )
+        # fix rope_scaling for qwen2.5-vl
+        if self.is_mrope_enabled:
+            config.rope_scaling["rope_type"] = "default"
+
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -144,6 +154,9 @@ class LlamaModel(nn.Module):
         else:
             embeds = input_embeds

+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
         hidden_states = forward_batch.spec_info.hidden_states
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)
@@ -185,9 +198,13 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         )
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        self.load_lm_head_from_target = False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
+            if config.draft_vocab_size is None:
+                self.load_lm_head_from_target = True
+                config.draft_vocab_size = config.vocab_size
             self.lm_head = ParallelLMHead(
                 config.draft_vocab_size,
                 config.hidden_size,
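
llama_eagle3.py picks up two behaviors: the draft model reads forward_batch.mrope_positions when the config enables mRoPE (the Qwen2.5-VL case), and when draft_vocab_size is unset it falls back to the full vocabulary and remembers that the LM head should be loaded from the target model. A minimal sketch of the vocab-size fallback, with a simplified stand-in config:

from dataclasses import dataclass
from typing import Optional


@dataclass
class DraftConfig:
    vocab_size: int
    draft_vocab_size: Optional[int] = None
    tie_word_embeddings: bool = False


def resolve_draft_lm_head(config: DraftConfig) -> tuple:
    """Return (lm_head_vocab_size, load_lm_head_from_target)."""
    load_lm_head_from_target = False
    if not config.tie_word_embeddings and config.draft_vocab_size is None:
        # No reduced draft vocab shipped with the EAGLE3 head: reuse the full
        # vocab and copy the LM head weights from the target model instead.
        load_lm_head_from_target = True
        config.draft_vocab_size = config.vocab_size
    return config.draft_vocab_size or config.vocab_size, load_lm_head_from_target


if __name__ == "__main__":
    print(resolve_draft_lm_head(DraftConfig(vocab_size=128256)))
    print(resolve_draft_lm_head(DraftConfig(vocab_size=128256, draft_vocab_size=32000)))
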