sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama4.py
CHANGED
@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers

@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()

         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
@@ -113,7 +118,25 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
@@ -121,12 +144,35 @@ class Llama4MoE(nn.Module):
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-
+        return shared_out, routed_out

-
-
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)

-
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream


 class Llama4Attention(nn.Module):
@@ -152,7 +198,6 @@ class Llama4Attention(nn.Module):
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope

-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -296,7 +341,7 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()

@@ -359,7 +404,7 @@ class Llama4DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -380,11 +425,11 @@ class Llama4DecoderLayer(nn.Module):
         )

         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)

-        # TODO(ch-wan):
+        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
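Note: the new Llama4MoE small-batch path (_forward_core_shared_routed_overlap) overlaps the shared expert with the router and routed experts by running the latter on a lazily created secondary stream, then re-synchronizing before the two outputs are summed. Below is a minimal, self-contained sketch of the same dual-stream pattern. It is not SGLang's implementation: the module and its linear layers are hypothetical stand-ins, and it assumes a CUDA device (the real code stays device-agnostic via torch.get_device_module()). The two wait_stream calls are what make the overlap safe: the side stream waits for the producer of hidden_states, and the main stream waits for routed_out before consuming it.

import torch
import torch.nn as nn

_alt_stream = None  # lazily created side stream, mirroring the diff above


def _get_or_create_alt_stream():
    global _alt_stream
    if _alt_stream is None:
        _alt_stream = torch.cuda.Stream()
    return _alt_stream


class ToyOverlapMoE(nn.Module):
    # Hypothetical stand-in: a shared expert plus a routed path, overlapped on two CUDA streams.
    def __init__(self, hidden=64, num_experts=4):
        super().__init__()
        self.shared_expert = nn.Linear(hidden, hidden)
        self.router = nn.Linear(hidden, num_experts)
        self.routed_experts = nn.Linear(hidden, hidden)  # placeholder for the fused MoE kernels

    def forward(self, hidden_states):
        alt_stream = _get_or_create_alt_stream()
        # The side stream must wait for whatever produced hidden_states on the main stream.
        alt_stream.wait_stream(torch.cuda.current_stream())

        shared_out = self.shared_expert(hidden_states)  # runs on the main stream

        with torch.cuda.stream(alt_stream):  # router + routed experts run on the side stream
            router_weights = torch.softmax(self.router(hidden_states), dim=-1)
            routed_out = self.routed_experts(hidden_states) * router_weights.amax(dim=-1, keepdim=True)

        # The main stream must wait for the side stream before consuming routed_out.
        torch.cuda.current_stream().wait_stream(alt_stream)
        return shared_out + routed_out


if torch.cuda.is_available():
    moe = ToyOverlapMoE().cuda()
    print(moe(torch.randn(2, 64, device="cuda")).shape)  # torch.Size([2, 64])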
sglang/srt/models/llava.py
CHANGED
@@ -15,7 +15,8 @@

 import math
 import re
-from
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -28,10 +29,18 @@ from transformers import (
     Qwen2Config,
     SiglipVisionModel,
 )
+from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector

+# leave till last and symbol only in case circular import
+import sglang.srt.models as sgl_models
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import general_mm_embed_routine
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix, flatten_nested_list
+from sglang.srt.utils import add_prefix, flatten_nested_list, logger


 class LlavaBaseForCausalLM(nn.Module):
@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs.image_offsets = offset_list
         return input_ids

-    def encode_images(
+    def encode_images(
+        self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> torch.Tensor:
+        """
+        encode images by vision tower and multimodal projector
+        Args:
+            pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
+        Returns:
+            torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
+        """
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
         # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.

@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         )


-
+class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
+    """
+    An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
+    It follows the structure of (vision_tower, multi_modal_projector, language_model)
+
+    Once a model config is loaded, text_config and vision_config will be extracted, and
+    LlavaForConditionalGeneration will load the language_model and vision_tower models
+    according to config.
+    """
+
+    MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        if hasattr(self.vision_tower, "pad_input_ids"):
+            return self.vision_tower.pad_input_ids(input_ids, image_inputs)
+        else:
+            return super().pad_input_ids(input_ids, image_inputs)
+
+    def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
+        """
+        Get the SGLang model implementation class according to config.
+
+        Args:
+            config: The config object of the model.
+            auto_model_type: The type of the auto model.
+
+        Returns:
+            The SGLang model implementation class.
+        """
+        config_cls_name = config.__class__.__name__
+        arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
+        if arch := arch_name_mapping.get(config_cls_name):
+            if isinstance(arch, tuple):
+                arch = arch[0]
+                logger.warning(
+                    f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
+                )
+            try:
+                return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
+            except Exception as e:
+                raise ValueError(
+                    f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
+                )
+        else:
+            raise ValueError(
+                f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
+            )
+
+    @lru_cache
+    def _config_cls_name_to_arch_name_mapping(
+        self, auto_model_type: Type[AutoModel]
+    ) -> Dict[str, str]:
+        mapping = {}
+        for config_cls, archs in auto_model_type._model_mapping.items():
+            if isinstance(archs, tuple):
+                mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
+            else:
+                mapping[config_cls.__name__] = archs.__name__
+        return mapping
+
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert hasattr(config, "text_config")
+        assert hasattr(config, "vision_config")
+        self.config = config
+        self.text_config = config.text_config
+        self.vision_config = config.vision_config
+
+        if not hasattr(self.config, "vocab_size"):
+            self.config.vocab_size = self.config.text_config.vocab_size
+        if not hasattr(self.config, "image_aspect_ratio"):
+            self.config.image_aspect_ratio = "anyres"
+        if not hasattr(self.config, "image_grid_pinpoints"):
+            # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
+            # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
+            self.config.image_grid_pinpoints = [
+                [96, 96],
+                [224, 224],
+                [384, 384],
+                [512, 512],
+                [768, 768],
+                [1024, 1024],
+            ]
+        if not hasattr(self.config, "mm_patch_merge_type"):
+            self.config.mm_patch_merge_type = "flat"
+        if not hasattr(self.config, "image_token_index"):
+            self.config.image_token_index = 10
+        if not hasattr(self.config, "projector_hidden_act"):
+            self.config.projector_hidden_act = "gelu"
+
+        self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
+        self.vision_feature_select_strategy = getattr(
+            config, "vision_feature_select_strategy", "full"
+        )
+        self.image_size = self.config.vision_config.image_size
+        self.patch_size = self.config.vision_config.patch_size
+
+        self.mm_patch_merge_type = config.mm_patch_merge_type
+        self.image_aspect_ratio = config.image_aspect_ratio
+        self.image_grid_pinpoints = config.image_grid_pinpoints
+
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+
+        self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
+
+        language_model_cls = self._get_sgl_model_cls(
+            config.text_config, AutoModelForCausalLM
+        )
+        vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
+        self.language_model = language_model_cls(
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+        self.vision_tower = vision_model_cls(
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_tower", prefix),
+        )
+
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(selected_image_feature.squeeze(0))
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            get_embedding=get_embedding,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            placeholder_tokens=None,  # using mm_item.pad_value
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for LlavaForConditionalGeneration.
+
+        Unlike the base class implementation, this one doesn't need to handle
+        weight name remapping as the weights are already properly structured with
+        'language_model' and 'vision_tower' prefixes in the safetensors files.
+        """
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(
+                f"Unexpected select feature: {self.vision_feature_select_strategy}"
+            )
+
+        # Create dictionaries for direct parameter loading
+        params_dict = dict(self.named_parameters())
+
+        # Load weights directly without remapping
+        for name, loaded_weight in weights:
+            for part in ("language_model", "vision_tower"):
+                if name.startswith(part):
+                    name = name[len(part + ".") :]
+                    getattr(self, part).load_weights([(name, loaded_weight)])
+                    break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM,
+    LlavaForConditionalGeneration,
+]
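Note: the new LlavaForConditionalGeneration adaptor picks SGLang implementations for its text and vision submodules by mapping each Hugging Face config class name to an architecture name through the transformers auto-model mappings, then resolving that architecture in SGLang's ModelRegistry. The sketch below illustrates the lookup idea only; the local registry dict is a hypothetical stand-in for ModelRegistry, and it reuses the private _model_mapping attribute the same way the diff above does.

from typing import Dict, Tuple, Type

from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM

# Hypothetical stand-in for SGLang's ModelRegistry: architecture name -> implementation class.
LOCAL_REGISTRY: Dict[str, type] = {}


def config_to_arch_names(auto_model_type: Type[AutoModel]) -> Dict[str, Tuple[str, ...]]:
    # Build {config class name -> architecture names} from transformers' auto-mapping,
    # mirroring _config_cls_name_to_arch_name_mapping in the diff above.
    mapping = {}
    for config_cls, archs in auto_model_type._model_mapping.items():
        names = tuple(a.__name__ for a in archs) if isinstance(archs, tuple) else (archs.__name__,)
        mapping[config_cls.__name__] = names
    return mapping


def resolve_impl(sub_config, auto_model_type: Type[AutoModel] = AutoModel) -> type:
    # Pick the first architecture transformers associates with this config class,
    # then look it up in our own registry (raising if either step fails).
    arch_names = config_to_arch_names(auto_model_type).get(sub_config.__class__.__name__)
    if not arch_names:
        raise ValueError(f"No architecture known for {sub_config.__class__.__name__}")
    return LOCAL_REGISTRY[arch_names[0]]


# Usage sketch: split a multimodal config and resolve each half separately, e.g.
#   cfg = AutoConfig.from_pretrained("mistral-community/pixtral-12b")
#   text_impl = resolve_impl(cfg.text_config, AutoModelForCausalLM)
#   vision_impl = resolve_impl(cfg.vision_config, AutoModel)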
sglang/srt/models/minicpmv.py
CHANGED
@@ -197,7 +197,7 @@ class Idefics2EncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=config.attention_dropout,
-
+            qkv_backend="sdpa",
             softmax_in_single_precision=True,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
sglang/srt/models/mixtral.py
CHANGED
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""

-
+import logging
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)


 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()

-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
         )
-
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-
-
-
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-
-
-
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states


@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )

+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
             if name is None:
                 continue

-
-
-
-
-
+            if name in params_dict.keys():
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+            else:
+                logger.warning(f"Parameter {name} not found in params_dict")


 EntryClass = MixtralForCausalLM
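Note: the Mixtral changes wire in pipeline parallelism: the first rank embeds tokens, later ranks receive hidden_states and residual through PPProxyTensors, non-last ranks return proxy tensors instead of logits, and load_weights now skips parameters for layers outside [start_layer, end_layer). The toy sketch below illustrates only the rank-boundary handoff; the classes are hypothetical, a plain dataclass stands in for PPProxyTensors, and the residual stream and real inter-rank communication are omitted.

from dataclasses import dataclass, field
from typing import Dict, Optional

import torch
import torch.nn as nn


@dataclass
class ProxyTensors:
    # Hypothetical stand-in for PPProxyTensors: named tensors handed between pipeline ranks.
    tensors: Dict[str, torch.Tensor] = field(default_factory=dict)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]


class ToyPipelineStage(nn.Module):
    # One pipeline rank that owns layers [start_layer, end_layer) of a toy decoder stack.
    def __init__(self, hidden, num_layers, start_layer, end_layer, is_first, is_last):
        super().__init__()
        self.start_layer, self.end_layer = start_layer, end_layer
        self.is_first, self.is_last = is_first, is_last
        self.embed = nn.Embedding(100, hidden) if is_first else None
        self.layers = nn.ModuleList(
            nn.Linear(hidden, hidden) if start_layer <= i < end_layer else nn.Identity()
            for i in range(num_layers)
        )
        self.norm = nn.LayerNorm(hidden) if is_last else None

    def forward(self, input_ids, proxy: Optional[ProxyTensors] = None):
        if self.is_first:
            hidden = self.embed(input_ids)
        else:
            assert proxy is not None
            hidden = proxy["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            hidden = self.layers[i](hidden)
        if not self.is_last:
            # Non-last ranks hand their activations to the next rank instead of producing logits.
            return ProxyTensors({"hidden_states": hidden})
        return self.norm(hidden)


# Two-stage usage sketch in a single process (no real inter-rank communication):
ids = torch.randint(0, 100, (1, 4))
stage0 = ToyPipelineStage(8, 4, 0, 2, is_first=True, is_last=False)
stage1 = ToyPipelineStage(8, 4, 2, 4, is_first=False, is_last=True)
print(stage1(ids, proxy=stage0(ids)).shape)  # torch.Size([1, 4, 8])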
sglang/srt/models/mllama.py
CHANGED
@@ -203,7 +203,7 @@ class MllamaVisionEncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=0.0,
-
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
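Note: both the minicpmv.py and mllama.py edits pass qkv_backend="sdpa" explicitly to the shared vision self-attention module, so these vision encoders route attention through PyTorch's scaled_dot_product_attention. A minimal illustration of what such an SDPA call computes (standalone PyTorch only, not SGLang's vision attention API):

import torch
import torch.nn.functional as F

# q, k, v: [batch, num_heads, seq_len, head_dim]
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Dispatches to a fused kernel (flash / memory-efficient / math) when one is available.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 8, 16, 64])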
|