sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama.py
CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )

         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@ class MllamaVisionEncoder(nn.Module):
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config,
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):


 class MllamaVisionModel(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -535,6 +550,7 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_local_key_value_heads,
             layer_id=layer_id,
             is_cross_attention=True,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -764,6 +780,27 @@ class MllamaForCausalLM(nn.Module):


 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
@@ -771,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -781,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
         self.image_size = config.vision_config.image_size

         self.vision_model = MllamaVisionModel(
-            config.vision_config,
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector =
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
@@ -958,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
         cross_attention_states = self.vision_model(
             batched_images, batched_ar_ids, batched_ar_mask
         )
-        cross_attention_states = self.multi_modal_projector(
+        cross_attention_states, _ = self.multi_modal_projector(
+            cross_attention_states
+        )

         bs, _, _, _, image_token_dim = cross_attention_states.shape
         cross_attention_states = cross_attention_states.view(
@@ -1012,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
sglang/srt/models/mllama4.py
CHANGED
@@ -1,13 +1,19 @@
-# TODO: add Aapted from vllm/mllama4.py
 from collections.abc import Iterable
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple

 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionModel
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector

 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
@@ -16,6 +22,7 @@ from sglang.srt.utils import add_prefix
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     def __init__(
@@ -28,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config

+        self.vision_model = Llama4VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM

@@ -39,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):

         self.logits_processor = LogitsProcessor(config.text_config)

+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+
+        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(
+        self,
+        items: List[MultimodalDataItem],
+    ) -> torch.Tensor:
+        pixel_values = (
+            torch.concat([item.pixel_values for item in items])
+            .to(next(self.vision_model.parameters()).device)
+            .type(next(self.vision_model.parameters()).dtype)
+        )
+
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+        image_features = image_outputs.last_hidden_state
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -47,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:

-
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs

     def permute_qk_weight_for_rotary(
         self,
@@ -96,18 +137,27 @@ class Llama4ForConditionalGeneration(nn.Module):

         num_experts = self.config.text_config.num_local_experts

-        for
-
-
-
-
-
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=num_experts,
+        )

-
+        for name, loaded_weight in weights:
+            if not "vision" in name:
+                name, loaded_weight = self.permute_qk_weight_for_rotary(
+                    name, loaded_weight
+                )

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
+                if "vision" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -115,31 +165,54 @@ class Llama4ForConditionalGeneration(nn.Module):
                 break
             else:
                 if ".experts" in name:
-
-
-
-
-                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                        shard_id_list = ["w1", "w3"]
-                    else:
-                        name_list = [
-                            name.replace(".experts.down_proj", ".experts.w2_weight")
-                        ]
-                        shard_id_list = ["w2"]
-                        loaded_weight_list = [loaded_weight]
-                    for name, loaded_weight, shard_id in zip(
-                        name_list, loaded_weight_list, shard_id_list
+                    # NOTE: llama4 fp8 has different weight format for experts
+                    if (
+                        "experts.gate_up_proj" not in name
+                        and "experts.down_proj" not in name
                     ):
-
-
-
+                        for mapping in expert_params_mapping:
+                            param_name, weight_name, expert_id, shard_id = mapping
+                            if weight_name not in name:
+                                continue
+                            name = name.replace(weight_name, param_name)
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
                             weight_loader(
                                 param,
-                                loaded_weight
+                                loaded_weight,
                                 name,
                                 shard_id=shard_id,
                                 expert_id=expert_id,
                             )
+                            break
+                    else:
+                        if ".gate_up_proj" in name:
+                            name_list = [
+                                name.replace(
+                                    ".experts.gate_up_proj", ".experts.w13_weight"
+                                )
+                            ] * 2
+                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+                            shard_id_list = ["w1", "w3"]
+                        else:
+                            name_list = [
+                                name.replace(".experts.down_proj", ".experts.w2_weight")
+                            ]
+                            shard_id_list = ["w2"]
+                            loaded_weight_list = [loaded_weight]
+                        for name, loaded_weight, shard_id in zip(
+                            name_list, loaded_weight_list, shard_id_list
+                        ):
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
+                            for expert_id in range(num_experts):
+                                weight_loader(
+                                    param,
+                                    loaded_weight[expert_id].T,
+                                    name,
+                                    shard_id=shard_id,
+                                    expert_id=expert_id,
+                                )
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
sglang/srt/models/olmo.py
CHANGED
sglang/srt/models/olmo2.py
CHANGED
sglang/srt/models/olmoe.py
CHANGED
sglang/srt/models/phi3_small.py
CHANGED
sglang/srt/models/qwen.py
CHANGED
sglang/srt/models/qwen2.py
CHANGED
@@ -154,6 +154,7 @@ class Qwen2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )

@@ -238,6 +239,7 @@ class Qwen2Model(nn.Module):
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
     ) -> None:
         super().__init__()
         self.config = config
@@ -249,9 +251,11 @@ class Qwen2Model(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
         )
+        # Use the provided decoder layer type or default to Qwen2DecoderLayer
+        decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
         self.layers = make_layers(
             config.num_hidden_layers,
-            lambda idx, prefix:
+            lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -30,12 +30,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionPatchEmbed,
+    Qwen2_5_VisionRotaryEmbedding,
+)

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
@@ -137,7 +141,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x


-class Qwen2_5_VisionPatchEmbed(nn.Module):
-
-    def __init__(
-        self,
-        patch_size: int = 14,
-        temporal_patch_size: int = 2,
-        in_chans: int = 3,
-        embed_dim: int = 1152,
-    ) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.embed_dim = embed_dim
-
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(
-            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        target_dtype = self.proj.weight.dtype
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
-        return x
-
-
 class Qwen2_5_VisionPatchMerger(nn.Module):

     def __init__(
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         return out


-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(
-            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen2_5_VisionTransformer(nn.Module):

     def __init__(
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-
+        in_channels: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-
+            in_channels=in_channels,
             embed_dim=hidden_size,
         )

@@ -363,7 +325,7 @@ class Qwen2_5_VisionTransformer(nn.Module):

     @property
     def dtype(self) -> torch.dtype:
-        return self.
+        return self.patch_embed.proj.weight.dtype

     @property
     def device(self) -> torch.device:
@@ -467,9 +429,28 @@ cached_get_processor = lru_cache(get_processor)


 class Qwen2_5_VLForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
-        config:
+        config: Qwen2_5_VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -479,9 +460,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE:
-            # quantization
-            quant_config=
+            # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )

@@ -500,6 +481,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
         )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -553,14 +535,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         otherwise it will be `(seq_len,).
         (Use input_metadata.mrope_positions to replace it)
         """
-        if
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions

         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
@@ -610,23 +592,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if "visual" in name and "qkv.weight" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")