sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
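The common thread through the model diffs below is an interface change: model forward methods (and the attention, logits-processor, and pooler calls inside them) now receive a ForwardBatch from sglang.srt.model_executor.forward_batch_info instead of the input_metadata argument used in 0.3.1.post3. A minimal sketch of the new calling convention, using a stand-in dataclass and toy model rather than the real sglang types:

from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ForwardBatch:
    # Stand-in for sglang.srt.model_executor.forward_batch_info.ForwardBatch;
    # only the two fields used in this sketch are modeled.
    batch_size: int
    extend_start_loc: torch.Tensor  # start offset of each request in the flattened token batch


class TinyModel(nn.Module):
    """Illustrates the calling convention only, not the real sglang model code."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,  # threaded through every layer in 0.3.3
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        # A real model would pass forward_batch on to attention and the logits processor.
        return self.lm_head(hidden_states)


batch = ForwardBatch(batch_size=2, extend_start_loc=torch.tensor([0, 3]))
logits = TinyModel()(torch.tensor([1, 2, 3, 4, 5]), torch.arange(5), batch)
print(logits.shape)  # torch.Size([5, 16])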
sglang/srt/models/llama.py
CHANGED
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class LlamaMLP(nn.Module):
@@ -162,12 +162,12 @@ class LlamaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -221,7 +221,7 @@ class LlamaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -233,7 +233,7 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -270,7 +270,7 @@ class LlamaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -283,7 +283,7 @@ class LlamaModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -310,15 +310,16 @@ class LlamaForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
             return self.config.hidden_size, self.config.hidden_size
         elif module_name in ["kv_proj"]:
@@ -399,10 +400,21 @@ class LlamaForCausalLM(nn.Module):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
+            # Skip loading kv_scale from ckpts towards new design.
+            if name.endswith(".kv_scale") and name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

+        if (
+            hasattr(self.config, "tie_word_embeddings")
+            and self.config.tie_word_embeddings
+        ):
+            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
+            param = self.lm_head.weight
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, self.model.embed_tokens.weight)
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))

sglang/srt/models/llama_classification.py
CHANGED
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -50,18 +50,18 @@ class LlamaForClassification(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         is_eos_token = input_ids == self.eos_token_id
         hidden_states = hidden_states[is_eos_token]
         scores = self.classification_head(hidden_states)

-        if scores.shape[0] != input_metadata.batch_size:
+        if scores.shape[0] != forward_batch.batch_size:
             print("Warning: the EOS tokens are missing in some sentences.")
             scores = torch.ones(
-                (input_metadata.batch_size, self.config.classification_out_size)
+                (forward_batch.batch_size, self.config.classification_out_size)
             ).to(input_ids.device)

         logits_output = LogitsProcessorOutput(
sglang/srt/models/llama_embedding.py
CHANGED
@@ -6,7 +6,7 @@ from transformers import LlamaConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.models.llama import LlamaModel


@@ -26,15 +26,15 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
         assert (
             get_embedding
         ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.pooler(hidden_states, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.pooler(hidden_states, forward_batch)

     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
sglang/srt/models/llama_reward.py
ADDED
@@ -0,0 +1,142 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
+
+
+class LlamaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.torchao_config = None
+        self.quant_config = quant_config
+        self.num_labels = config.num_labels
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> EmbeddingPoolerOutput:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        scores = self.score(hidden_states)
+
+        return self.pooler(scores, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+
+
+class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
+    class Weights(torch.nn.Module):
+        def __init__(self, hidden_size, num_label):
+            super().__init__()
+            self.fc = torch.nn.Sequential(
+                torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
+                torch.nn.SELU(),
+                torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
+                torch.nn.SELU(),
+                torch.nn.Linear(hidden_size, num_label // 2, dtype=torch.float16),
+            )
+
+        def forward(self, x):
+            return self.fc(x.to(torch.float16))
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config, cache_config)
+        self.weights = self.Weights(config.hidden_size, self.num_labels)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForSequenceClassification is only used for embedding"
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        weights = self.weights(hidden_states)
+
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+        pooled_weights = self.pooler(weights, forward_batch).embeddings
+
+        rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
+            -1, self.num_labels // 2
+        )
+        scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
+        return EmbeddingPoolerOutput(scores)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+
+
+EntryClass = [
+    LlamaForSequenceClassification,
+    LlamaForSequenceClassificationWithNormal_Weights,
+]
sglang/srt/models/llava.py
CHANGED
@@ -35,25 +35,22 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+
         # hardcode for spatial_unpad + anyres
         image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
@@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_w = int(new_w // times)
             new_image_feature_len += new_h * (new_w + 1)

-            pad_ids = pad_value * (
-                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            pad_ids = pad_values * (
+                (new_image_feature_len + len(pad_values)) // len(pad_values)
             )
             # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
@@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
-
+
+        image_inputs.image_offsets = offset_list
+        return input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -131,33 +130,40 @@ class LlavaBaseForCausalLM(nn.Module):
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-
-
+        image_inputs = forward_batch.image_inputs
+
+        if forward_batch.forward_mode.is_extend():
+            bs = forward_batch.batch_size
            # Got List[List[str]] extend it to List[str]
            # The length of the List should be equal to batch size
            modalities_list = []
-
-
-
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.modalities is not None:
+                    modalities_list.extend(im.modalities)
+                if im and im.image_offsets is not None:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)

            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

-
-
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
-            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= np.array(max_image_offset)

            if need_vision.any():
-                pixel_values = [
-
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_sizes = [
+                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                ########## Encode Image ########

@@ -342,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module):
            image_features = new_image_features

            # Fill in the placeholder for the image
-            extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu =
+            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
            pt = 0
            for i in range(bs):
                if not need_vision[i]:
@@ -373,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module):
                    pt += 1

            return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
+                input_ids, positions, forward_batch, input_embeds=input_embeds
            )
-        elif
-            return self.language_model(input_ids, positions, input_metadata)
+        elif forward_batch.forward_mode.is_decode():
+            return self.language_model(input_ids, positions, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Load clip vision model by cfg['mm_vision_tower']:
sglang/srt/models/llavavid.py
CHANGED
@@ -26,7 +26,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -54,17 +55,12 @@ class LlavaVidForCausalLM(nn.Module):
            torch.empty(config.text_config.hidden_size, dtype=torch.float16)
        )

-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        pad_values = image_inputs.pad_values
        new_image_feature_len = self.image_feature_len

-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        pad_ids = pad_values * (
+            (new_image_feature_len + len(pad_values)) // len(pad_values)
        )
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because we need to remove image_token_id
@@ -73,7 +69,8 @@ class LlavaVidForCausalLM(nn.Module):
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
-
+        image_inputs.image_offsets = [offset]
+        return new_input_ids

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -111,26 +108,32 @@ class LlavaVidForCausalLM(nn.Module):
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
-        input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
+        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
-
-
+        image_inputs = forward_batch.image_inputs
+        if forward_batch.forward_mode.is_extend():
+            bs = forward_batch.batch_size

            # Embed text inputs
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Whether the requests need vision inputs
-            max_image_offset =
-
-
-
-
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.image_offsets:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= np.array(max_image_offset)

            if need_vision.any():
-                pixel_values = [
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                ########## Encode Image ########

@@ -166,8 +169,8 @@ class LlavaVidForCausalLM(nn.Module):
            image_features = new_image_features

            # Fill in the placeholder for the image
-            extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu =
+            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
            pt = 0
            for i in range(bs):
                if not need_vision[i]:
@@ -197,10 +200,10 @@ class LlavaVidForCausalLM(nn.Module):
                    pt += 1

            return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
+                input_ids, positions, forward_batch, input_embeds=input_embeds
            )
-        elif
-            return self.language_model(input_ids, positions, input_metadata)
+        elif forward_batch.forward_mode.is_decode():
+            return self.language_model(input_ids, positions, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Load clip vision model by cfg['mm_vision_tower']:
sglang/srt/models/minicpm.py
CHANGED
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class MiniCPMMLP(nn.Module):
@@ -148,7 +148,7 @@ class MiniCPMAttention(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -156,7 +156,7 @@ class MiniCPMAttention(nn.Module):
        q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        q, k = q.to(orig_dtype), k.to(orig_dtype)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output

@@ -199,7 +199,7 @@ class MiniCPMDecoderLayer(nn.Module):
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
@@ -208,7 +208,7 @@ class MiniCPMDecoderLayer(nn.Module):
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
        )
        hidden_states = residual + hidden_states * (
            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -252,7 +252,7 @@ class MiniCPMModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
@@ -266,7 +266,7 @@ class MiniCPMModel(nn.Module):
            hidden_states, residual = layer(
                positions,
                hidden_states,
-                input_metadata,
+                forward_batch,
                residual,
            )
        hidden_states = self.norm(hidden_states)
@@ -303,19 +303,19 @@ class MiniCPMForCausalLM(nn.Module):
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is not None:
            input_embeds = input_embeds * self.config.scale_emb
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head_weight = self.model.embed_tokens.weight
        else:
            lm_head_weight = self.lm_head.weight
        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, input_metadata
+            input_ids, hidden_states, lm_head_weight, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):