sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- sglang/__init__.py +3 -1
- sglang/api.py +7 -7
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +158 -11
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +12 -2
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +28 -3
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +8 -2
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +3 -1
- sglang/srt/hf_transformers_utils.py +130 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +582 -0
- sglang/srt/layers/logits_processor.py +65 -32
- sglang/srt/layers/radix_attention.py +41 -7
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
- sglang/srt/managers/{router → controller}/model_runner.py +262 -158
- sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
- sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
- sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
- sglang/srt/managers/detokenizer_manager.py +42 -46
- sglang/srt/managers/io_struct.py +22 -12
- sglang/srt/managers/tokenizer_manager.py +151 -87
- sglang/srt/model_config.py +83 -5
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +12 -15
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +26 -15
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +86 -19
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +282 -103
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +150 -95
- sglang/srt/openai_protocol.py +11 -2
- sglang/srt/server.py +124 -48
- sglang/srt/server_args.py +128 -48
- sglang/srt/utils.py +234 -67
- sglang/test/test_programs.py +65 -3
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +23 -4
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama2.py
CHANGED
@@ -1,12 +1,18 @@
 # Adapted from
-# https://github.com/vllm-project/vllm/blob/
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple
+
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+import tqdm
 from torch import nn
 from transformers import LlamaConfig
-from vllm.
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -20,11 +26,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class LlamaMLP(nn.Module):
@@ -71,6 +77,7 @@ class LlamaAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -118,6 +125,7 @@ class LlamaAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -152,6 +160,13 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
@@ -160,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
@@ -250,6 +266,7 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -270,13 +287,7 @@ class LlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -286,9 +297,9 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        if get_tensor_model_parallel_rank() == 0:
+            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
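The most consequential change in this file is the weight-loading interface: load_weights no longer takes a model path plus cache_dir, load_format, and revision, and no longer iterates hf_model_weights_iterator itself; it now consumes an iterable of (name, tensor) pairs supplied by the caller. Below is a minimal sketch of driving the new interface from a plain state dict; the model construction and checkpoint filename are hypothetical, and in the real server vLLM's model loader streams the pairs from checkpoint shards instead.

from typing import Iterable, Tuple

import torch


def iter_named_tensors(state_dict: dict) -> Iterable[Tuple[str, torch.Tensor]]:
    # Yield (parameter_name, tensor) pairs, the format load_weights now expects.
    for name, tensor in state_dict.items():
        yield name, tensor


# Hypothetical usage: `model` would be a LlamaForCausalLM built by the runtime.
# state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# model.load_weights(iter_named_tensors(state_dict))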
sglang/srt/models/llama_classification.py
ADDED
@@ -0,0 +1,104 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+import tqdm
+from torch import nn
+from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.models.llama2 import LlamaModel
+
+
+class LlamaForClassification(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
+
+        self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
+        self.eos_token_id = config.eos_token_id
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        is_eos_token = input_ids == self.eos_token_id
+        hidden_states = hidden_states[is_eos_token]
+        scores = self.classification_head(hidden_states)
+
+        if scores.shape[0] != input_metadata.batch_size:
+            print("Warning: the EOS tokens are missing in some sentences.")
+            scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
+
+        return LogitProcessorOutput(
+            next_token_logits=scores,
+            next_token_logprobs=scores,
+            normalized_prompt_logprobs=scores,
+            prefill_token_logprobs=torch.ones_like(input_ids),
+            prefill_top_logprobs=None,
+            decode_top_logprobs=None,
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        if get_tensor_model_parallel_rank() == 0:
+            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if "lm_head" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+EntryClass = LlamaForClassification
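LlamaForClassification scores a sequence by selecting the hidden state at each EOS position and passing it through a linear head. A toy, self-contained illustration of that gathering step follows; the EOS id, tensor shapes, and output size are invented for the example.

import torch

eos_token_id = 2                              # assumed EOS id for this sketch
input_ids = torch.tensor([5, 7, 2, 9, 4, 2])  # flattened token ids for a batch
hidden_states = torch.randn(6, 16)            # one hidden vector per token
classification_head = torch.nn.Linear(16, 3)  # 3 stands in for classification_out_size

is_eos_token = input_ids == eos_token_id      # boolean mask over token positions
scores = classification_head(hidden_states[is_eos_token])
print(scores.shape)                           # torch.Size([2, 3]), one row per EOS token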
sglang/srt/models/llava.py
CHANGED
@@ -1,23 +1,32 @@
 """Inference-only LLaVa model compatible with HuggingFace weights."""
 
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
 from torch import nn
-from transformers import
+from transformers import (
+    CLIPVisionConfig,
+    CLIPVisionModel,
+    LlavaConfig,
+    MistralConfig,
+    Qwen2Config,
+)
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.
-from sglang.srt.managers.
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from sglang.srt.
+from sglang.srt.models.mistral import MistralForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
 
 class LlavaLlamaForCausalLM(nn.Module):
@@ -25,6 +34,7 @@ class LlavaLlamaForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -233,13 +243,7 @@ class LlavaLlamaForCausalLM(nn.Module):
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.language_model(input_ids, positions, input_metadata)
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
@@ -272,9 +276,8 @@ class LlavaLlamaForCausalLM(nn.Module):
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
@@ -285,9 +288,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)
 
         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)
 
         monkey_path_clip_vision_embed_forward()
 
@@ -296,6 +297,72 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.image_size // self.patch_size
 
 
+class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 151646
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = MistralConfig(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 32000
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
 first_call = True
 
 
@@ -328,4 +395,4 @@ def monkey_path_clip_vision_embed_forward():
     )
 
 
-EntryClass = LlavaLlamaForCausalLM
+EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
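Note that EntryClass changes from a single class to a list, so this module now exports the Llama, Qwen2, and Mistral LLaVA variants together. How sglang consumes EntryClass is not shown in this diff; the sketch below only illustrates, under that assumption, how a registry could accept either shape.

def collect_entry_classes(entry_class):
    # Assumption: the model registry maps architecture names to classes and must
    # now tolerate EntryClass being either one class or a list of classes.
    classes = entry_class if isinstance(entry_class, list) else [entry_class]
    return {cls.__name__: cls for cls in classes}


# e.g. collect_entry_classes(EntryClass) would register all three LLaVA variants
# defined in this file under their class names.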
sglang/srt/models/llavavid.py
CHANGED
@@ -1,24 +1,24 @@
 """Inference-only LLaVa video model compatible with HuggingFace weights."""
 
-import
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel,
+from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.
-from sglang.srt.managers.
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlavaVidForCausalLM(nn.Module):
@@ -26,6 +26,7 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -65,7 +66,6 @@ class LlavaVidForCausalLM(nn.Module):
         pad_ids = pad_value * (
             (new_image_feature_len + len(pad_value)) // len(pad_value)
         )
-        # print(input_ids)
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
         new_input_ids = (
@@ -200,13 +200,7 @@ class LlavaVidForCausalLM(nn.Module):
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.language_model(input_ids, positions, input_metadata)
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
@@ -244,9 +238,8 @@ class LlavaVidForCausalLM(nn.Module):
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
@@ -261,9 +254,7 @@ class LlavaVidForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)
 
         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)
 
         monkey_path_clip_vision_embed_forward()
 