sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
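The listing above already shows the largest structural change in this release: the `sglang/srt/managers/router/` package is removed and replaced by `sglang/srt/managers/controller/` (`radix_cache.py` moves across, while `infer_batch.py`, `model_runner.py`, `tp_worker.py`, and related modules are rewritten there). The snippet below only restates the import-path consequence that the per-file diffs on this page rely on; it is an illustration, not a migration guide, and running it requires sglang 0.1.21 with its torch/vllm dependencies installed.

```python
# Import paths as used by the 0.1.21 code in the diffs below.
from sglang.srt.managers.controller.infer_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata

# The 0.1.14 equivalent, removed in this release:
#   from sglang.srt.managers.router.model_runner import InputMetadata
print(ForwardMode, InputMetadata)
```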
sglang/srt/models/llama2.py
CHANGED
```diff
@@ -1,34 +1,36 @@
 # Adapted from
-# https://github.com/vllm-project/vllm/blob/
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-
+
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+import tqdm
 from torch import nn
 from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.
-
-
-from
-
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class LlamaMLP(nn.Module):
@@ -37,17 +39,20 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size,
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -72,8 +77,9 @@ class LlamaAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,13 +110,13 @@ class LlamaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -119,6 +125,7 @@ class LlamaAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -147,12 +154,19 @@ class LlamaDecoderLayer(nn.Module):
         self,
         config: LlamaConfig,
         layer_id: int = 0,
-
+        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling[
+                "original_max_position_embeddings"
+            ] = config.original_max_position_embeddings
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
@@ -161,14 +175,15 @@ class LlamaDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
-
+            quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -204,7 +219,7 @@ class LlamaModel(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -216,7 +231,7 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i,
+                LlamaDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -250,12 +265,13 @@ class LlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.
-        self.model = LlamaModel(config,
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
@@ -271,13 +287,7 @@ class LlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -287,9 +297,9 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-
-
-
+        if get_tensor_model_parallel_rank() == 0:
+            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
```
sglang/srt/models/llama_classification.py
ADDED

```diff
@@ -0,0 +1,107 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+import tqdm
+from torch import nn
+from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.models.llama2 import LlamaModel
+
+
+class LlamaForClassification(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
+
+        self.classification_head = nn.Linear(
+            config.hidden_size, config.classification_out_size
+        )
+        self.eos_token_id = config.eos_token_id
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        is_eos_token = input_ids == self.eos_token_id
+        hidden_states = hidden_states[is_eos_token]
+        scores = self.classification_head(hidden_states)
+
+        if scores.shape[0] != input_metadata.batch_size:
+            print("Warning: the EOS tokens are missing in some sentences.")
+            scores = torch.ones(
+                (input_metadata.batch_size, self.config.classification_out_size)
+            ).to(input_ids.device)
+
+        return LogitProcessorOutput(
+            next_token_logits=scores,
+            next_token_logprobs=scores,
+            normalized_prompt_logprobs=scores,
+            prefill_token_logprobs=torch.ones_like(input_ids),
+            prefill_top_logprobs=None,
+            decode_top_logprobs=None,
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        if get_tensor_model_parallel_rank() == 0:
+            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if "lm_head" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = LlamaForClassification
```
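`llama_classification.py` is a new entry point that reuses `LlamaModel` as a backbone and adds a linear `classification_head`. Instead of decoding, its `forward` selects the hidden state at each sequence's EOS token, scores it, and returns the result through `LogitProcessorOutput` so it fits the existing logits plumbing. The snippet below reproduces just the EOS-pooling step with dummy tensors (all shapes and token ids are made up); it is not the sglang class itself.

```python
# Hedged sketch of the EOS-token pooling used by LlamaForClassification.
import torch

eos_token_id = 2
# Two packed sequences of four tokens each, flattened into one token dimension.
input_ids = torch.tensor([5, 7, 9, 2, 6, 8, 0, 2])
hidden_states = torch.randn(8, 16)            # one hidden vector per token
classification_head = torch.nn.Linear(16, 3)  # toy head with 3 output classes

is_eos_token = input_ids == eos_token_id      # True at each sequence's EOS position
pooled = hidden_states[is_eos_token]          # one pooled vector per sequence
scores = classification_head(pooled)
print(pooled.shape, scores.shape)             # torch.Size([2, 16]) torch.Size([2, 3])
```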
sglang/srt/models/llava.py
CHANGED
@@ -1,32 +1,40 @@
|
|
1
1
|
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
2
2
|
|
3
|
-
from typing import List, Optional
|
3
|
+
from typing import Iterable, List, Optional, Tuple
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
import torch
|
7
|
-
from
|
8
|
-
from
|
7
|
+
from torch import nn
|
8
|
+
from transformers import (
|
9
|
+
CLIPVisionConfig,
|
10
|
+
CLIPVisionModel,
|
11
|
+
LlavaConfig,
|
12
|
+
MistralConfig,
|
13
|
+
Qwen2Config,
|
14
|
+
)
|
15
|
+
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
16
|
+
from vllm.config import CacheConfig
|
17
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
18
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
19
|
+
|
20
|
+
from sglang.srt.managers.controller.infer_batch import ForwardMode
|
21
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
9
22
|
from sglang.srt.mm_utils import (
|
10
23
|
get_anyres_image_grid_shape,
|
11
24
|
unpad_image,
|
12
25
|
unpad_image_shape,
|
13
26
|
)
|
14
27
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
15
|
-
from
|
16
|
-
from
|
17
|
-
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
18
|
-
from vllm.model_executor.layers.linear import LinearMethodBase
|
19
|
-
from vllm.model_executor.weight_utils import (
|
20
|
-
default_weight_loader,
|
21
|
-
hf_model_weights_iterator,
|
22
|
-
)
|
28
|
+
from sglang.srt.models.mistral import MistralForCausalLM
|
29
|
+
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
23
30
|
|
24
31
|
|
25
32
|
class LlavaLlamaForCausalLM(nn.Module):
|
26
33
|
def __init__(
|
27
34
|
self,
|
28
35
|
config: LlavaConfig,
|
29
|
-
|
36
|
+
quant_config: Optional[QuantizationConfig] = None,
|
37
|
+
cache_config: Optional[CacheConfig] = None,
|
30
38
|
) -> None:
|
31
39
|
super().__init__()
|
32
40
|
self.config = config
|
@@ -34,7 +42,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
34
42
|
self.config.vision_config.hidden_size = config.mm_hidden_size
|
35
43
|
self.config.text_config.hidden_size = config.hidden_size
|
36
44
|
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
37
|
-
self.language_model = LlamaForCausalLM(config,
|
45
|
+
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
|
38
46
|
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
39
47
|
self.language_model.model.image_newline = nn.Parameter(
|
40
48
|
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
@@ -235,13 +243,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
235
243
|
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
236
244
|
return self.language_model(input_ids, positions, input_metadata)
|
237
245
|
|
238
|
-
def load_weights(
|
239
|
-
self,
|
240
|
-
model_name_or_path: str,
|
241
|
-
cache_dir: Optional[str] = None,
|
242
|
-
load_format: str = "auto",
|
243
|
-
revision: Optional[str] = None,
|
244
|
-
):
|
246
|
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
245
247
|
# load clip vision model by cfg['mm_vision_tower']:
|
246
248
|
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
247
249
|
vision_path = self.config.mm_vision_tower
|
@@ -274,9 +276,8 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
274
276
|
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
275
277
|
}
|
276
278
|
params_dict = dict(self.named_parameters())
|
277
|
-
|
278
|
-
|
279
|
-
):
|
279
|
+
weights = list(weights)
|
280
|
+
for name, loaded_weight in weights:
|
280
281
|
# FIXME: why projector weights read two times?
|
281
282
|
if "projector" in name or "vision_tower" in name:
|
282
283
|
for weight_name, param_name in projector_weights.items():
|
@@ -287,9 +288,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
287
288
|
weight_loader(param, loaded_weight)
|
288
289
|
|
289
290
|
# load language model
|
290
|
-
self.language_model.load_weights(
|
291
|
-
model_name_or_path, cache_dir, load_format, revision
|
292
|
-
)
|
291
|
+
self.language_model.load_weights(weights)
|
293
292
|
|
294
293
|
monkey_path_clip_vision_embed_forward()
|
295
294
|
|
@@ -298,6 +297,72 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
298
297
|
return self.image_size // self.patch_size
|
299
298
|
|
300
299
|
|
300
|
+
class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
|
301
|
+
def __init__(
|
302
|
+
self,
|
303
|
+
config: LlavaConfig,
|
304
|
+
quant_config: Optional[QuantizationConfig] = None,
|
305
|
+
cache_config: Optional[CacheConfig] = None,
|
306
|
+
) -> None:
|
307
|
+
super().__init__(config, quant_config=quant_config, cache_config=cache_config)
|
308
|
+
self.config = config
|
309
|
+
self.vision_tower = None
|
310
|
+
if getattr(self.config, "vision_config", None) is None:
|
311
|
+
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
|
312
|
+
|
313
|
+
if getattr(self.config, "text_config", None) is None:
|
314
|
+
self.config.text_config = Qwen2Config(self.config._name_or_path)
|
315
|
+
|
316
|
+
self.config.vision_config.hidden_size = config.mm_hidden_size
|
317
|
+
self.config.text_config.hidden_size = config.hidden_size
|
318
|
+
|
319
|
+
if getattr(self.config, "projector_hidden_act", None) is None:
|
320
|
+
self.config.projector_hidden_act = "gelu"
|
321
|
+
|
322
|
+
if getattr(self.config, "image_token_index", None) is None:
|
323
|
+
self.config.image_token_index = 151646
|
324
|
+
|
325
|
+
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
326
|
+
self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
|
327
|
+
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
328
|
+
self.language_model.model.image_newline = nn.Parameter(
|
329
|
+
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
330
|
+
)
|
331
|
+
|
332
|
+
|
333
|
+
class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
|
334
|
+
def __init__(
|
335
|
+
self,
|
336
|
+
config: LlavaConfig,
|
337
|
+
quant_config: Optional[QuantizationConfig] = None,
|
338
|
+
cache_config: Optional[CacheConfig] = None,
|
339
|
+
) -> None:
|
340
|
+
super().__init__(config, quant_config=quant_config, cache_config=cache_config)
|
341
|
+
self.config = config
|
342
|
+
self.vision_tower = None
|
343
|
+
if getattr(self.config, "vision_config", None) is None:
|
344
|
+
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
|
345
|
+
|
346
|
+
if getattr(self.config, "text_config", None) is None:
|
347
|
+
self.config.text_config = MistralConfig(self.config._name_or_path)
|
348
|
+
|
349
|
+
self.config.vision_config.hidden_size = config.mm_hidden_size
|
350
|
+
self.config.text_config.hidden_size = config.hidden_size
|
351
|
+
|
352
|
+
if getattr(self.config, "projector_hidden_act", None) is None:
|
353
|
+
self.config.projector_hidden_act = "gelu"
|
354
|
+
|
355
|
+
if getattr(self.config, "image_token_index", None) is None:
|
356
|
+
self.config.image_token_index = 32000
|
357
|
+
|
358
|
+
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
359
|
+
self.language_model = MistralForCausalLM(config, quant_config=quant_config)
|
360
|
+
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
361
|
+
self.language_model.model.image_newline = nn.Parameter(
|
362
|
+
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
363
|
+
)
|
364
|
+
|
365
|
+
|
301
366
|
first_call = True
|
302
367
|
|
303
368
|
|
@@ -330,4 +395,4 @@ def monkey_path_clip_vision_embed_forward():
|
|
330
395
|
)
|
331
396
|
|
332
397
|
|
333
|
-
EntryClass = LlavaLlamaForCausalLM
|
398
|
+
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
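The `llava.py` changes mirror those in `llama2.py` (the `quant_config` parameter and the iterable-based `load_weights`) and add two sibling classes, `LlavaQwenForCausalLM` and `LlavaMistralForCausalLM`, which swap in a Qwen2 or Mistral language backbone while reusing the projector logic. Note that `EntryClass` changes from a single class to a list, so anything that scans model files for entry points has to accept both forms. The sketch below shows one way that normalization could look; `collect_entry_classes` is a hypothetical helper and the empty classes are stand-ins, not sglang's actual registry code or models.

```python
# Hedged sketch: normalize EntryClass, which may be a single class or a list.
from typing import Any, Dict, List, Type


def collect_entry_classes(entry: Any) -> List[Type]:
    # Accept either `EntryClass = SomeModel` or `EntryClass = [ModelA, ModelB]`.
    return list(entry) if isinstance(entry, (list, tuple)) else [entry]


class LlavaLlamaForCausalLM: ...
class LlavaQwenForCausalLM: ...
class LlavaMistralForCausalLM: ...


EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]

registry: Dict[str, Type] = {
    cls.__name__: cls for cls in collect_entry_classes(EntryClass)
}
print(sorted(registry))
# ['LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM', 'LlavaQwenForCausalLM']
```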