sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
CHANGED
@@ -16,13 +16,15 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""

-
+import logging
+from typing import Iterable, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers import MixtralConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -38,14 +40,17 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
+
+logger = logging.getLogger(__name__)


 class MixtralMoE(nn.Module):
@@ -257,24 +262,32 @@ class MixtralModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()

-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: MixtralDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="layers",
+            return_tuple=True,
         )
-
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,
@@ -282,18 +295,35 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-
-
-
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-
-
-
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
         return hidden_states


@@ -306,6 +336,7 @@ class MixtralForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = MixtralModel(
@@ -322,12 +353,31 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )

+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -348,6 +398,17 @@ class MixtralForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -398,11 +459,14 @@ class MixtralForCausalLM(nn.Module):
                 if name is None:
                     continue

-
-
-
-
-
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")


 EntryClass = MixtralForCausalLM
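The Mixtral change above threads pipeline parallelism through the model: make_layers gives each pipeline rank a contiguous [start_layer, end_layer) slice of decoder layers, ranks other than the last hand hidden_states and residual to the next stage via PPProxyTensors, and load_weights skips weights for layers outside the local slice. As a rough illustration of the layer-partitioning idea only (a standalone sketch, not sglang's make_layers; the helper name and the even-split policy below are assumptions):

    # Sketch: contiguous layer partitioning for pipeline parallelism.
    # Not sglang code; partition_layers and its split policy are assumed for illustration.
    def partition_layers(num_layers: int, pp_rank: int, pp_size: int) -> tuple:
        """Return the [start_layer, end_layer) range owned by this pipeline rank."""
        base = num_layers // pp_size
        remainder = num_layers % pp_size
        # Earlier ranks absorb the remainder, one extra layer each.
        start = pp_rank * base + min(pp_rank, remainder)
        end = start + base + (1 if pp_rank < remainder else 0)
        return start, end

    if __name__ == "__main__":
        # e.g. 32 Mixtral decoder layers over 4 pipeline ranks -> 8 layers per rank
        for rank in range(4):
            print(rank, partition_layers(32, rank, 4))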
sglang/srt/models/pixtral.py
ADDED
@@ -0,0 +1,467 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Using mistral-community/pixtral-12b as reference.
+"""
+
+import logging
+import math
+from typing import Iterable, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PixtralVisionConfig, PretrainedConfig
+from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
+from transformers.models.pixtral.modeling_pixtral import (
+    generate_block_attention_mask as _get_pixtral_attention_mask,
+)
+from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+
+class PixtralHFMLP(nn.Module):
+    """MLP for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert config.intermediate_size is not None
+
+        # Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size, config.intermediate_size],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+
+        self.down_proj = RowParallelLinear(
+            input_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up_output, _ = self.gate_up_proj(x)
+
+        # Apply SiLU activation and multiply
+        gate_up = self.act_fn(gate_up_output)
+
+        # Project back to hidden size
+        out, _ = self.down_proj(gate_up)
+        return out
+
+
+class PixtralHFTransformerBlock(nn.Module):
+    """Transformer block for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+        # Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
+        self.attention = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=0.0,
+            use_context_forward=False,
+            softmax_in_single_precision=False,
+            flatten_batch=False,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.feed_forward = PixtralHFMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
+
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> torch.Tensor:
+        # Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+
+        # Apply attention norm - normalize along the last dimension
+        attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through attention layer
+        attention_output = self.attention(
+            attn_normalized,
+            attention_mask=attention_mask,
+            cu_seqlens=None,
+            position_embeddings=position_embeddings,
+        )
+
+        # Apply first residual connection
+        hidden_states = hidden_states + attention_output
+
+        # Apply feed-forward norm - normalize along the last dimension
+        ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through feed-forward layer
+        # First reshape to 2D for the feed-forward network, then reshape back
+        ffn_output = self.feed_forward(ffn_normalized)
+
+        # Apply second residual connection
+        output = hidden_states + ffn_output
+
+        return output
+
+
+class PixtralHFTransformer(nn.Module):
+    """Transformer for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        num_hidden_layers = config.num_hidden_layers
+        if num_hidden_layers_override is not None:
+            num_hidden_layers = num_hidden_layers_override
+
+        self.layers = nn.ModuleList(
+            [
+                PixtralHFTransformerBlock(
+                    config=config,
+                    layer_id=layer_idx,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Forward pass through transformer layers.
+
+        Args:
+            x: Input tensor
+            attention_mask: Optional attention mask
+            position_embeddings: Optional position embeddings for rotary attention
+            return_all_hidden_states: Whether to return all hidden states
+
+        Returns:
+            Either the final hidden state, or a list of all hidden states if
+            return_all_hidden_states is True
+        """
+        # For HF model compatibility, always start with the input
+        hidden_states = x
+        all_hidden_states = [hidden_states] if return_all_hidden_states else None
+
+        for i, layer in enumerate(self.layers):
+            hidden_states = layer(hidden_states, attention_mask, position_embeddings)
+            if return_all_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if return_all_hidden_states:
+            return all_hidden_states
+        return hidden_states
+
+
+def resolve_visual_encoder_outputs(
+    outputs: Union[torch.Tensor, List[torch.Tensor]],
+    feature_sample_layers: Optional[List[int]],
+    post_norm: Optional[nn.Module],
+    num_hidden_layers: int,
+) -> torch.Tensor:
+    """Resolve outputs from visual encoder based on feature_sample_layers."""
+    if feature_sample_layers is None:
+        # Just use the last layer's output
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+        if post_norm is not None:
+            outputs = post_norm(outputs)
+        return outputs
+
+    # Handle the case where we want to use specific layers
+    if not isinstance(outputs, list):
+        raise ValueError(
+            "Expected outputs to be a list when feature_sample_layers is provided"
+        )
+
+    # Validate layer indices
+    for layer_idx in feature_sample_layers:
+        if layer_idx < 0 or layer_idx > num_hidden_layers:
+            raise ValueError(
+                f"Feature sample layer index {layer_idx} is out of range "
+                f"[0, {num_hidden_layers}]"
+            )
+
+    # Collect outputs from specified layers
+    selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
+
+    # Combine the outputs
+    combined_outputs = torch.cat(selected_outputs, dim=-1)
+
+    if post_norm is not None:
+        combined_outputs = post_norm(combined_outputs)
+
+    return combined_outputs
+
+
+class PixtralHFVisionModel(nn.Module):
+    """Hugging Face Pixtral Vision Model implemented using SGLang components."""
+
+    DEFAULT_IMAGE_TOKEN_ID = 10
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+
+        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
+
+        self.transformer = PixtralHFTransformer(
+            config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.transformer",
+        )
+
+        # Check that num_hidden_layers is valid
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.transformer.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.transformer.layers)} "
+                "layers."
+            )
+
+        # Initialize patch position embedding
+        self.image_token_id = image_token_id
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
+            [self.image_token_id]
+        )
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+        output_hidden_states: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> Union[torch.Tensor, tuple]:
+        """
+        Args:
+            pixel_values: [batch_size, C, H, W], padded if multiple images
+            image_sizes: list of (H, W) for each image in the batch
+            output_hidden_states: Whether to return all hidden states.
+            feature_sample_layers: Layer indices whose features should be
+                concatenated and used as the visual encoder output. If none
+                are provided, the last layer is used.
+
+        Returns:
+            A tuple containing:
+                - hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
+                - hidden_states tuple (optional): All hidden states if output_hidden_states=True
+        """
+        # batch patch images
+        embeds_orig = self.patch_conv(
+            pixel_values.to(device=self.device, dtype=self.dtype)
+        )
+        # crop the embeddings
+        embeds_2d = [
+            embed[..., : h // self.patch_size, : w // self.patch_size]
+            for embed, (h, w) in zip(embeds_orig, image_sizes)
+        ]
+
+        # flatten to sequence
+        embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
+        embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            embeds_2d,
+            max_width=self.image_size // self.patch_size,
+        ).to(self.device)
+
+        # The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
+        # These tensors are used by apply_rotary_pos_emb in the transformer blocks
+        position_embedding = self.patch_positional_embedding(
+            embeds_featurized, position_ids
+        )
+        attention_mask = _get_pixtral_attention_mask(
+            [p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
+        )
+
+        return_all_hidden_states = (
+            output_hidden_states or feature_sample_layers is not None
+        )
+
+        transformer_outputs = self.transformer(
+            embeds_featurized,  # add batch dimension
+            attention_mask,
+            position_embedding,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        # Store all hidden states if requested
+        all_hidden_states = None
+        if isinstance(transformer_outputs, list):
+            all_hidden_states = transformer_outputs
+            # Use the last layer by default if feature_sample_layers is not specified
+            if feature_sample_layers is None:
+                out = transformer_outputs[-1]
+            else:
+                # Resolve outputs based on feature sample layers
+                out = resolve_visual_encoder_outputs(
+                    transformer_outputs,
+                    feature_sample_layers,
+                    None,
+                    self.config.num_hidden_layers,
+                )
+        else:
+            out = transformer_outputs
+
+        # Format return to be compatible with HuggingFace vision models
+        if output_hidden_states:
+            return type(
+                "VisualOutput",
+                (),
+                {
+                    "last_hidden_state": out,
+                    "hidden_states": all_hidden_states,
+                },
+            )
+        else:
+            return out
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        """Load weights from a HuggingFace checkpoint with proper parameter mapping."""
+        params_dict = dict(self.named_parameters())
+
+        # for (param, weight, shard_id): load weight into param as param's shard_id part
+        stacked_params_mapping = [
+            (".attention.qkv_proj", ".attention.q_proj", "q"),
+            (".attention.qkv_proj", ".attention.k_proj", "k"),
+            (".attention.qkv_proj", ".attention.v_proj", "v"),
+            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+        ]
+
+        # Process each weight
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name part with the combined parameter name
+                    transformed_name = name.replace(weight_name, param_name)
+                    if transformed_name in params_dict:
+                        param = params_dict[transformed_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                if ".attention.o_proj" in name:
+                    alt_name = name.replace(".attention.o_proj", ".attention.proj")
+                    if alt_name in params_dict:
+                        name = alt_name
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+class PixtralVisionModel(PixtralHFVisionModel):
+    pass
+
+
+# Register the model classes for external access
+EntryClass = [PixtralVisionModel]
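The new Pixtral vision model above remaps Hugging Face checkpoint names onto SGLang's fused parameters in load_weights: separate q_proj/k_proj/v_proj tensors are loaded as shards of a single qkv_proj, gate_proj/up_proj become shards of gate_up_proj, and the checkpoint's o_proj is renamed to VisionAttention's proj. A minimal standalone sketch of that name-routing step (the route_weight_name helper below is hypothetical, not part of sglang; only the mapping tuples come from the diff):

    # Sketch of the checkpoint-name remapping used when loading Pixtral vision weights.
    # The mapping tuples mirror stacked_params_mapping in the diff above; the helper
    # function itself is hypothetical.
    STACKED_PARAMS_MAPPING = [
        # (fused param fragment, checkpoint fragment, shard id)
        (".attention.qkv_proj", ".attention.q_proj", "q"),
        (".attention.qkv_proj", ".attention.k_proj", "k"),
        (".attention.qkv_proj", ".attention.v_proj", "v"),
        (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
        (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
    ]

    def route_weight_name(name):
        """Return (target_param_name, shard_id); shard_id is None for unfused weights."""
        for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
            if weight_name in name:
                return name.replace(weight_name, param_name), shard_id
        # The checkpoint's o_proj corresponds to VisionAttention's `proj`.
        if ".attention.o_proj" in name:
            return name.replace(".attention.o_proj", ".attention.proj"), None
        return name, None

    print(route_weight_name("transformer.layers.0.attention.k_proj.weight"))
    # ('transformer.layers.0.attention.qkv_proj.weight', 'k')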
sglang/srt/models/roberta.py
CHANGED
@@ -57,7 +57,7 @@ class RobertaEmbedding(nn.Module):
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)

-        #
+        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py

         pos_list = []
         token_list = []