sglang-0.4.6.post3-py3-none-any.whl → sglang-0.4.6.post5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
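The table above lists, for every member of the wheel, how many lines were added and removed between the two releases. A rough way to regenerate that kind of per-file summary locally, sketched with only the Python standard library (the wheel paths are placeholders for wherever the two files were downloaded):

import difflib
import zipfile

OLD_WHEEL = "sglang-0.4.6.post3-py3-none-any.whl"  # placeholder path
NEW_WHEEL = "sglang-0.4.6.post5-py3-none-any.whl"  # placeholder path


def read_python_members(path):
    # Map each .py member of the wheel to its decoded lines.
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines()
            for name in zf.namelist()
            if name.endswith(".py")
        }


old_files = read_python_members(OLD_WHEEL)
new_files = read_python_members(NEW_WHEEL)

for name in sorted(set(old_files) | set(new_files)):
    diff = difflib.unified_diff(
        old_files.get(name, []), new_files.get(name, []), lineterm=""
    )
    added = removed = 0
    for line in diff:
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    if added or removed:
        print(f"{name} +{added} -{removed}")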
sglang/srt/models/mimo_mtp.py
ADDED
@@ -0,0 +1,220 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/17433/files and deepseek_nextn.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.mimo import MiMoForCausalLM
+from sglang.srt.models.qwen2 import (
+    Qwen2Attention,
+    Qwen2DecoderLayer,
+    Qwen2MLP,
+    Qwen2Model,
+)
+from sglang.srt.utils import add_prefix
+
+
+class MiMoMultiTokenPredictorLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        prefix: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_proj = nn.Linear(
+            config.hidden_size * 2, config.hidden_size, bias=False
+        )
+        self.mtp_block = Qwen2DecoderLayer(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        # masking inputs at position 0, as not needed by MTP
+        hidden_states[positions == 0] = 0
+
+        hidden_states = self.input_proj(
+            torch.cat(
+                (
+                    self.hidden_layernorm(forward_batch.spec_info.hidden_states),
+                    self.token_layernorm(hidden_states),
+                ),
+                dim=-1,
+            )
+        )
+
+        hidden_states, residual = self.mtp_block(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            residual=None,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class MiMoMTP(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+
+        self.model = MiMoMultiTokenPredictorLayer(
+            config,
+            prefix,
+            quant_config,
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+            name = self.map_model_name_to_mtp_param_name(name)
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mtp_block" not in name:
+                    break
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if "mtp_block" not in name and (
+                    "embed_tokens" not in name
+                    and "lm_head" not in name
+                    and "token_layernorm" not in name
+                    and "hidden_layernorm" not in name
+                    and "input_proj" not in name
+                    and "final_layernorm" not in name
+                ):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def map_model_name_to_mtp_param_name(self, name: str) -> str:
+        import re
+
+        name_without_prefix = [
+            "token_layernorm",
+            "hidden_layernorm",
+            "input_proj",
+            "final_layernorm",
+        ]
+        pattern = r"model.mtp_layers.(\d+)."
+        group = re.match(pattern, name)
+        if group is not None:
+            for sub_name in name_without_prefix:
+                if sub_name in name:
+                    name = name.replace(group.group(), "model.")
+                    return name
+            name = name.replace(group.group(), "model.mtp_block.")
+        return name
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+
+EntryClass = MiMoMTP
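The new MiMoMultiTokenPredictorLayer above fuses two inputs before running its single Qwen2 decoder block: the hidden states handed over by the target model (carried in forward_batch.spec_info) and the embeddings of the draft tokens, each RMS-normalized and then projected from 2 * hidden_size back to hidden_size. A self-contained sketch of just that fusion step, with toy shapes and a hand-rolled RMS norm standing in for sglang's RMSNorm:

import torch
from torch import nn


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Minimal RMS normalization, standing in for sglang.srt.layers.layernorm.RMSNorm.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


hidden_size = 64  # toy size; the real value comes from the model config
input_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)

prev_hidden = torch.randn(4, hidden_size)   # hidden states from the target model
token_embeds = torch.randn(4, hidden_size)  # embeddings of the tokens being drafted

fused = input_proj(
    torch.cat((rms_norm(prev_hidden), rms_norm(token_embeds)), dim=-1)
)
print(fused.shape)  # torch.Size([4, 64])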
sglang/srt/models/minicpmo.py
CHANGED
@@ -1520,12 +1520,15 @@ class MiniCPMO(MiniCPMBaseModel):
         slice_start_id: int = mm_input.slice_start_id
         slice_end_id: int = mm_input.slice_end_id
 
-
+        data_token_pairs = [
             (im_start_id, im_end_id),
             (slice_start_id, slice_end_id),
             (mm_input.audio_start_id, mm_input.audio_end_id),
         ]
-
+        data_start_token_ids = [im_start_id, mm_input.audio_start_id]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(
+            data_token_pairs=data_token_pairs, data_start_token_ids=data_start_token_ids
+        )
 
         return pattern.pad_input_tokens(input_ids, mm_input)
 
@@ -1823,22 +1826,12 @@ class MiniCPMO(MiniCPMBaseModel):
         **kwargs: Any,
     ) -> torch.Tensor:
 
-        mm_input = forward_batch.merge_mm_inputs()
-        placeholder_token_ids = (
-            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
-            if forward_batch.contains_mm_inputs()
-            else []
-        )
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_tokens={
-                Modality.IMAGE: placeholder_token_ids,
-                Modality.AUDIO: placeholder_token_ids,
-            },
             positions=positions,
         )
         return hidden_states
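The MiniCPM-O change above now passes an explicit data_start_token_ids list alongside the (start, end) token-id pairs when constructing MultiModalityDataPaddingPatternTokenPairs. The underlying idea is that multimodal data in the token stream is bracketed by special start/end ids; a rough standalone illustration of locating such bracketed spans (my own helper, not the sglang implementation):

from typing import List, Tuple


def find_data_spans(
    input_ids: List[int], token_pairs: List[Tuple[int, int]]
) -> List[Tuple[int, int]]:
    # Return (start_index, end_index) pairs bracketed by any (start_id, end_id) pair.
    end_for_start = {start: end for start, end in token_pairs}
    spans = []
    i = 0
    while i < len(input_ids):
        tok = input_ids[i]
        if tok in end_for_start:
            end_id = end_for_start[tok]
            j = i + 1
            while j < len(input_ids) and input_ids[j] != end_id:
                j += 1
            spans.append((i, j))
            i = j + 1
        else:
            i += 1
    return spans


# With hypothetical ids <im_start>=100 and <im_end>=101:
print(find_data_spans([1, 100, 7, 7, 101, 2], [(100, 101)]))  # [(1, 4)]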
sglang/srt/models/mistral.py
CHANGED
@@ -13,6 +13,12 @@
 # ==============================================================================
 """Inference-only Mistral model."""
 
+from typing import List, Union
+
+import torch
+from transformers.models.mistral3.modeling_mistral3 import Mistral3MultiModalProjector
+
+from sglang.srt.managers.schedule_batch import MultimodalDataItem
 from sglang.srt.models.llama import LlamaForCausalLM
 
 
@@ -20,4 +26,68 @@ class MistralForCausalLM(LlamaForCausalLM):
     pass
 
 
-
+class Mistral3ForConditionalGeneration:
+    MULTIMODAL_PROJECTOR_TYPE = Mistral3MultiModalProjector
+
+    def __init__(self, **kwargs):
+        # lazy load inner class
+        # to bypass circular import
+        from sglang.srt.models.llava import LlavaForConditionalGeneration
+
+        # override config: mistral's projector adds patchmerger that doesn't require padding
+        kwargs["config"].vision_config.pad_image_border = False
+
+        self.inner = LlavaForConditionalGeneration(**kwargs)
+        self.inner.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(
+            kwargs["config"]
+        )
+        self.inner.get_image_feature = self.get_image_feature
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Extract features from image inputs.
+
+        Args:
+            items: List of MultimodalDataItem objects containing image data
+                Note that an item can be either "image" or "multi-images"
+
+        Returns:
+            torch.Tensor: features from image inputs, concatenated
+        """
+        features = []
+        for item in items:
+            # in each item, we assume pixel_values is always batched
+            pixel_values, image_sizes = item.pixel_values, item.image_sizes
+            image_outputs = self.vision_tower(
+                pixel_values, image_sizes, output_hidden_states=True
+            )
+            selected_image_feature = image_outputs.hidden_states[
+                self.vision_feature_layer
+            ]
+
+            if self.vision_feature_select_strategy in ["default", "patch"]:
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif self.vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(
+                    f"Unexpected select feature: {self.vision_feature_select_strategy}"
+                )
+            features.append(
+                self.multi_modal_projector(
+                    selected_image_feature.squeeze(0), image_sizes
+                )
+            )
+        ret = torch.cat(features, dim=0)
+        return ret
+
+    def __getattr__(self, name):
+        return getattr(self.inner, name)
+
+    def __hasattr__(self, name):
+        return hasattr(self.inner, name)
+
+    def __call__(self, *args, **kwargs):
+        return self.inner(*args, **kwargs)
+
+
+EntryClass = [MistralForCausalLM, Mistral3ForConditionalGeneration]
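Mistral3ForConditionalGeneration above is a plain wrapper rather than an nn.Module subclass: it lazily builds a LlavaForConditionalGeneration, patches its projector and get_image_feature, and forwards every other attribute lookup and call to the inner model through __getattr__ and __call__. A stripped-down sketch of that delegation pattern with toy classes (none of the names below are sglang's):

class InnerModel:
    def __init__(self, scale: int) -> None:
        self.scale = scale

    def run(self, x: int) -> int:
        return x * self.scale


class Wrapper:
    """Constructs an inner object and delegates everything else to it."""

    def __init__(self, **kwargs) -> None:
        self.inner = InnerModel(**kwargs)

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so self.inner is found directly.
        return getattr(self.inner, name)

    def __call__(self, *args, **kwargs):
        return self.inner.run(*args, **kwargs)


w = Wrapper(scale=3)
print(w.scale, w(5))  # 3 15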
sglang/srt/models/mixtral.py
CHANGED
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
 
-
+import logging
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import MixtralConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)
 
 
 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
 
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
         )
-
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def forward(
         self,
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-
-
-
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-
-
-
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states
 
 
@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
 
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                     if name is None:
                         continue
 
-
-
-
-
-
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = MixtralForCausalLM
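The Mixtral changes above wire the model for pipeline parallelism: each rank builds only its own slice of decoder layers through make_layers (given pp_rank and pp_size), intermediate ranks hand hidden_states and residual to the next rank via PPProxyTensors, and only the last rank applies the final norm and the logits processor. As a back-of-the-envelope illustration of the layer split itself, a simple partition helper (my own, not sglang's make_layers) could look like this:

def partition_layers(num_layers: int, pp_rank: int, pp_size: int) -> tuple:
    # Return the [start, end) layer range owned by one pipeline rank,
    # giving earlier ranks one extra layer when the split is uneven.
    base, remainder = divmod(num_layers, pp_size)
    start = pp_rank * base + min(pp_rank, remainder)
    end = start + base + (1 if pp_rank < remainder else 0)
    return start, end


# e.g. 32 decoder layers spread over 3 pipeline ranks
for rank in range(3):
    print(rank, partition_layers(32, rank, 3))
# 0 (0, 11)
# 1 (11, 22)
# 2 (22, 32)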
sglang/srt/models/mllama.py
CHANGED
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
             prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
-        self.capture_mode = False
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pixel_values = torch.cat(
@@ -865,7 +864,6 @@
             pixel_values = torch.cat(
                 [item.pixel_values for item in mm_input.mm_items], dim=0
             )
-            # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
             max_num_images = max(max_num_images, pixel_values.shape[1])
 
             max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
@@ -970,6 +968,8 @@
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
             self._batch_image_inputs(forward_batch)
         )
@@ -978,7 +978,7 @@
         cross_attention_mask = None
         cross_attention_states = None
 
-        if
+        if get_is_capture_mode():
            # NOTE: when doing cuda graph capture, we do not want to skip cross attention
            # Make is a constant value to avoid cuda graph capture issue
            skip_cross_attention = False