sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma.py
CHANGED
@@ -4,29 +4,25 @@
|
|
4
4
|
from typing import Optional, Tuple
|
5
5
|
|
6
6
|
import torch
|
7
|
-
from sglang.srt.layers.logits_processor import LogitsProcessor
|
8
|
-
from sglang.srt.layers.radix_attention import RadixAttention
|
9
7
|
from torch import nn
|
10
8
|
from transformers import PretrainedConfig
|
11
9
|
from vllm.config import LoRAConfig
|
12
|
-
from vllm.
|
10
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
13
11
|
from vllm.model_executor.layers.activation import GeluAndMul
|
14
12
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
15
13
|
from vllm.model_executor.layers.linear import (
|
16
|
-
LinearMethodBase,
|
17
14
|
MergedColumnParallelLinear,
|
18
15
|
QKVParallelLinear,
|
19
16
|
RowParallelLinear,
|
20
17
|
)
|
18
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
21
19
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
22
20
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
from
|
27
|
-
|
28
|
-
hf_model_weights_iterator,
|
29
|
-
)
|
21
|
+
|
22
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
23
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
24
|
+
from sglang.srt.managers.router.model_runner import InputMetadata
|
25
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
30
26
|
|
31
27
|
|
32
28
|
class GemmaMLP(nn.Module):
|
@@ -34,17 +30,20 @@ class GemmaMLP(nn.Module):
|
|
34
30
|
self,
|
35
31
|
hidden_size: int,
|
36
32
|
intermediate_size: int,
|
37
|
-
|
33
|
+
quant_config: Optional[QuantizationConfig] = None,
|
38
34
|
) -> None:
|
39
35
|
super().__init__()
|
40
36
|
self.gate_up_proj = MergedColumnParallelLinear(
|
41
37
|
hidden_size,
|
42
38
|
[intermediate_size] * 2,
|
43
39
|
bias=False,
|
44
|
-
|
40
|
+
quant_config=quant_config,
|
45
41
|
)
|
46
42
|
self.down_proj = RowParallelLinear(
|
47
|
-
intermediate_size,
|
43
|
+
intermediate_size,
|
44
|
+
hidden_size,
|
45
|
+
bias=False,
|
46
|
+
quant_config=quant_config,
|
48
47
|
)
|
49
48
|
self.act_fn = GeluAndMul()
|
50
49
|
|
@@ -65,7 +64,7 @@ class GemmaAttention(nn.Module):
|
|
65
64
|
layer_id: int = 0,
|
66
65
|
max_position_embeddings: int = 8192,
|
67
66
|
rope_theta: float = 10000,
|
68
|
-
|
67
|
+
quant_config: Optional[QuantizationConfig] = None,
|
69
68
|
) -> None:
|
70
69
|
super().__init__()
|
71
70
|
self.hidden_size = hidden_size
|
@@ -95,13 +94,13 @@ class GemmaAttention(nn.Module):
|
|
95
94
|
self.total_num_heads,
|
96
95
|
self.total_num_kv_heads,
|
97
96
|
bias=False,
|
98
|
-
|
97
|
+
quant_config=quant_config,
|
99
98
|
)
|
100
99
|
self.o_proj = RowParallelLinear(
|
101
100
|
self.total_num_heads * self.head_dim,
|
102
101
|
hidden_size,
|
103
102
|
bias=False,
|
104
|
-
|
103
|
+
quant_config=quant_config,
|
105
104
|
)
|
106
105
|
|
107
106
|
self.rotary_emb = get_rope(
|
@@ -138,7 +137,7 @@ class GemmaDecoderLayer(nn.Module):
|
|
138
137
|
self,
|
139
138
|
config: PretrainedConfig,
|
140
139
|
layer_id: int = 0,
|
141
|
-
|
140
|
+
quant_config: Optional[QuantizationConfig] = None,
|
142
141
|
) -> None:
|
143
142
|
super().__init__()
|
144
143
|
self.hidden_size = config.hidden_size
|
@@ -150,12 +149,12 @@ class GemmaDecoderLayer(nn.Module):
|
|
150
149
|
layer_id=layer_id,
|
151
150
|
max_position_embeddings=config.max_position_embeddings,
|
152
151
|
rope_theta=config.rope_theta,
|
153
|
-
|
152
|
+
quant_config=quant_config,
|
154
153
|
)
|
155
154
|
self.mlp = GemmaMLP(
|
156
155
|
hidden_size=self.hidden_size,
|
157
156
|
intermediate_size=config.intermediate_size,
|
158
|
-
|
157
|
+
quant_config=quant_config,
|
159
158
|
)
|
160
159
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
161
160
|
self.post_attention_layernorm = RMSNorm(
|
@@ -191,7 +190,7 @@ class GemmaModel(nn.Module):
|
|
191
190
|
def __init__(
|
192
191
|
self,
|
193
192
|
config: PretrainedConfig,
|
194
|
-
|
193
|
+
quant_config: Optional[QuantizationConfig] = None,
|
195
194
|
) -> None:
|
196
195
|
super().__init__()
|
197
196
|
self.config = config
|
@@ -202,7 +201,7 @@ class GemmaModel(nn.Module):
|
|
202
201
|
)
|
203
202
|
self.layers = nn.ModuleList(
|
204
203
|
[
|
205
|
-
GemmaDecoderLayer(config, i,
|
204
|
+
GemmaDecoderLayer(config, i, quant_config=quant_config)
|
206
205
|
for i in range(config.num_hidden_layers)
|
207
206
|
]
|
208
207
|
)
|
@@ -263,14 +262,14 @@ class GemmaForCausalLM(nn.Module):
|
|
263
262
|
def __init__(
|
264
263
|
self,
|
265
264
|
config: PretrainedConfig,
|
266
|
-
|
265
|
+
quant_config: Optional[QuantizationConfig] = None,
|
267
266
|
lora_config: Optional[LoRAConfig] = None,
|
268
267
|
) -> None:
|
269
268
|
del lora_config # Unused.
|
270
269
|
super().__init__()
|
271
270
|
self.config = config
|
272
|
-
self.
|
273
|
-
self.model = GemmaModel(config,
|
271
|
+
self.quant_config = quant_config
|
272
|
+
self.model = GemmaModel(config, quant_config=quant_config)
|
274
273
|
self.logits_processor = LogitsProcessor(config)
|
275
274
|
|
276
275
|
@torch.no_grad()
|
sglang/srt/models/llama2.py
CHANGED
@@ -1,34 +1,30 @@
|
|
1
1
|
# Adapted from
|
2
2
|
# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
|
3
3
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
4
|
-
from typing import Any, Dict,
|
4
|
+
from typing import Any, Dict, Optional, Tuple
|
5
5
|
|
6
6
|
import torch
|
7
|
-
from sglang.srt.layers.logits_processor import LogitsProcessor
|
8
|
-
from sglang.srt.layers.radix_attention import RadixAttention
|
9
|
-
from sglang.srt.managers.router.model_runner import InputMetadata
|
10
7
|
from torch import nn
|
11
8
|
from transformers import LlamaConfig
|
9
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
12
10
|
from vllm.model_executor.layers.activation import SiluAndMul
|
13
11
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
14
12
|
from vllm.model_executor.layers.linear import (
|
15
|
-
LinearMethodBase,
|
16
13
|
MergedColumnParallelLinear,
|
17
14
|
QKVParallelLinear,
|
18
15
|
RowParallelLinear,
|
19
16
|
)
|
17
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
20
18
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
21
19
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
22
20
|
ParallelLMHead,
|
23
21
|
VocabParallelEmbedding,
|
24
22
|
)
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
from
|
29
|
-
|
30
|
-
hf_model_weights_iterator,
|
31
|
-
)
|
23
|
+
|
24
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
25
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
26
|
+
from sglang.srt.managers.router.model_runner import InputMetadata
|
27
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
32
28
|
|
33
29
|
|
34
30
|
class LlamaMLP(nn.Module):
|
@@ -37,17 +33,20 @@ class LlamaMLP(nn.Module):
|
|
37
33
|
hidden_size: int,
|
38
34
|
intermediate_size: int,
|
39
35
|
hidden_act: str,
|
40
|
-
|
36
|
+
quant_config: Optional[QuantizationConfig] = None,
|
41
37
|
) -> None:
|
42
38
|
super().__init__()
|
43
39
|
self.gate_up_proj = MergedColumnParallelLinear(
|
44
40
|
hidden_size,
|
45
41
|
[intermediate_size] * 2,
|
46
42
|
bias=False,
|
47
|
-
|
43
|
+
quant_config=quant_config,
|
48
44
|
)
|
49
45
|
self.down_proj = RowParallelLinear(
|
50
|
-
intermediate_size,
|
46
|
+
intermediate_size,
|
47
|
+
hidden_size,
|
48
|
+
bias=False,
|
49
|
+
quant_config=quant_config,
|
51
50
|
)
|
52
51
|
if hidden_act != "silu":
|
53
52
|
raise ValueError(
|
@@ -73,7 +72,7 @@ class LlamaAttention(nn.Module):
|
|
73
72
|
rope_theta: float = 10000,
|
74
73
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
75
74
|
max_position_embeddings: int = 8192,
|
76
|
-
|
75
|
+
quant_config: Optional[QuantizationConfig] = None,
|
77
76
|
) -> None:
|
78
77
|
super().__init__()
|
79
78
|
self.hidden_size = hidden_size
|
@@ -104,13 +103,13 @@ class LlamaAttention(nn.Module):
|
|
104
103
|
self.total_num_heads,
|
105
104
|
self.total_num_kv_heads,
|
106
105
|
bias=False,
|
107
|
-
|
106
|
+
quant_config=quant_config,
|
108
107
|
)
|
109
108
|
self.o_proj = RowParallelLinear(
|
110
109
|
self.total_num_heads * self.head_dim,
|
111
110
|
hidden_size,
|
112
111
|
bias=False,
|
113
|
-
|
112
|
+
quant_config=quant_config,
|
114
113
|
)
|
115
114
|
|
116
115
|
self.rotary_emb = get_rope(
|
@@ -147,7 +146,7 @@ class LlamaDecoderLayer(nn.Module):
|
|
147
146
|
self,
|
148
147
|
config: LlamaConfig,
|
149
148
|
layer_id: int = 0,
|
150
|
-
|
149
|
+
quant_config: Optional[QuantizationConfig] = None,
|
151
150
|
) -> None:
|
152
151
|
super().__init__()
|
153
152
|
self.hidden_size = config.hidden_size
|
@@ -162,13 +161,13 @@ class LlamaDecoderLayer(nn.Module):
|
|
162
161
|
rope_theta=rope_theta,
|
163
162
|
rope_scaling=rope_scaling,
|
164
163
|
max_position_embeddings=max_position_embeddings,
|
165
|
-
|
164
|
+
quant_config=quant_config,
|
166
165
|
)
|
167
166
|
self.mlp = LlamaMLP(
|
168
167
|
hidden_size=self.hidden_size,
|
169
168
|
intermediate_size=config.intermediate_size,
|
170
169
|
hidden_act=config.hidden_act,
|
171
|
-
|
170
|
+
quant_config=quant_config,
|
172
171
|
)
|
173
172
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
174
173
|
self.post_attention_layernorm = RMSNorm(
|
@@ -204,7 +203,7 @@ class LlamaModel(nn.Module):
|
|
204
203
|
def __init__(
|
205
204
|
self,
|
206
205
|
config: LlamaConfig,
|
207
|
-
|
206
|
+
quant_config: Optional[QuantizationConfig] = None,
|
208
207
|
) -> None:
|
209
208
|
super().__init__()
|
210
209
|
self.config = config
|
@@ -216,7 +215,7 @@ class LlamaModel(nn.Module):
|
|
216
215
|
)
|
217
216
|
self.layers = nn.ModuleList(
|
218
217
|
[
|
219
|
-
LlamaDecoderLayer(config, i,
|
218
|
+
LlamaDecoderLayer(config, i, quant_config=quant_config)
|
220
219
|
for i in range(config.num_hidden_layers)
|
221
220
|
]
|
222
221
|
)
|
@@ -250,12 +249,12 @@ class LlamaForCausalLM(nn.Module):
|
|
250
249
|
def __init__(
|
251
250
|
self,
|
252
251
|
config: LlamaConfig,
|
253
|
-
|
252
|
+
quant_config: Optional[QuantizationConfig] = None,
|
254
253
|
) -> None:
|
255
254
|
super().__init__()
|
256
255
|
self.config = config
|
257
|
-
self.
|
258
|
-
self.model = LlamaModel(config,
|
256
|
+
self.quant_config = quant_config
|
257
|
+
self.model = LlamaModel(config, quant_config=quant_config)
|
259
258
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
260
259
|
self.logits_processor = LogitsProcessor(config)
|
261
260
|
|
sglang/srt/models/llava.py
CHANGED
@@ -4,6 +4,11 @@ from typing import List, Optional
|
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
import torch
|
7
|
+
from torch import nn
|
8
|
+
from transformers import CLIPVisionModel, LlavaConfig
|
9
|
+
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
10
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
11
|
+
|
7
12
|
from sglang.srt.managers.router.infer_batch import ForwardMode
|
8
13
|
from sglang.srt.managers.router.model_runner import InputMetadata
|
9
14
|
from sglang.srt.mm_utils import (
|
@@ -12,21 +17,14 @@ from sglang.srt.mm_utils import (
|
|
12
17
|
unpad_image_shape,
|
13
18
|
)
|
14
19
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
15
|
-
from
|
16
|
-
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
|
17
|
-
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
18
|
-
from vllm.model_executor.layers.linear import LinearMethodBase
|
19
|
-
from vllm.model_executor.weight_utils import (
|
20
|
-
default_weight_loader,
|
21
|
-
hf_model_weights_iterator,
|
22
|
-
)
|
20
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
23
21
|
|
24
22
|
|
25
23
|
class LlavaLlamaForCausalLM(nn.Module):
|
26
24
|
def __init__(
|
27
25
|
self,
|
28
26
|
config: LlavaConfig,
|
29
|
-
|
27
|
+
quant_config: Optional[QuantizationConfig] = None,
|
30
28
|
) -> None:
|
31
29
|
super().__init__()
|
32
30
|
self.config = config
|
@@ -34,7 +32,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
34
32
|
self.config.vision_config.hidden_size = config.mm_hidden_size
|
35
33
|
self.config.text_config.hidden_size = config.hidden_size
|
36
34
|
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
37
|
-
self.language_model = LlamaForCausalLM(config,
|
35
|
+
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
|
38
36
|
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
39
37
|
self.language_model.model.image_newline = nn.Parameter(
|
40
38
|
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
@@ -0,0 +1,307 @@
|
|
1
|
+
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from typing import List, Optional
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
from torch import nn
|
9
|
+
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
|
10
|
+
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
11
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
12
|
+
|
13
|
+
from sglang.srt.managers.router.infer_batch import ForwardMode
|
14
|
+
from sglang.srt.managers.router.model_runner import InputMetadata
|
15
|
+
from sglang.srt.mm_utils import (
|
16
|
+
get_anyres_image_grid_shape,
|
17
|
+
unpad_image,
|
18
|
+
unpad_image_shape,
|
19
|
+
)
|
20
|
+
from sglang.srt.models.llama2 import LlamaForCausalLM
|
21
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
22
|
+
|
23
|
+
|
24
|
+
class LlavaVidForCausalLM(nn.Module):
    """LLaVA video model: CLIP vision tower + projector + Llama language model.

    Each frame is encoded by the CLIP vision tower and average-pooled over the
    patch grid by ``self.resampler``. The pooling kernel and stride are both
    ``mm_spatial_pool_stride``. The pooled features are mapped into the
    language model's embedding space by ``multi_modal_projector``. They are
    then written over the placeholder tokens that ``pad_input_ids`` inserted
    into the prompt.

    The vision tower and the vision-related attributes (``image_size``,
    ``patch_size``, ``image_feature_len``, ...) are only set in
    ``load_weights``. ``pad_input_ids`` / ``forward`` read them, so
    ``load_weights`` must run first.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        # Created in load_weights().
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
        self.resampler = nn.AvgPool2d(
            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
        )
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        self.num_frames = getattr(self.config, "num_frames", 16)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        """Replace the image token in ``input_ids`` with ``image_feature_len`` pad ids.

        Only the fixed per-video feature length is supported. The anyres /
        spatial-unpad layouts used by the image models are not, so ``pt_shape``
        and ``image_size`` are accepted for interface compatibility but ignored.

        Args:
            input_ids: Prompt token ids containing ``config.image_token_index``
                (presumably a list, given the ``.index`` / slicing below).
            pad_value: Non-empty sequence of ids, repeated cyclically to fill
                the placeholder slots.

        Returns:
            ``(new_input_ids, offset)``. ``offset`` is the index where the image
            token was and where the placeholder run now starts.
        """
        new_image_feature_len = self.image_feature_len
        # Repeat pad_value often enough to cover new_image_feature_len slots.
        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # The single image token is dropped and replaced by the placeholders,
        # so the new length is old_len + new_image_feature_len - 1.
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode frames stacked along dim 0 into pooled, projected visual tokens.

        Returns:
            A tensor of shape ``(num_frames, pooled_tokens, D)``.
            ``pooled_tokens = (num_patches_per_side // mm_spatial_pool_stride) ** 2``
            and ``D`` is the projector's output size.

        Raises:
            ValueError: If ``vision_feature_select_strategy`` is not one of
                "default", "patch" or "full".
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: output_hidden_states=True keeps every layer's hidden states alive
        # although only one layer is used, so this is not memory efficient.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]

        strategy = self.vision_feature_select_strategy
        if strategy in ("default", "patch"):
            # Drop the leading CLS token; keep only the patch tokens.
            selected_image_feature = selected_image_feature[:, 1:]
        elif strategy != "full":
            raise ValueError(f"Unexpected select feature strategy: {strategy}")

        side = self.num_patches_per_side
        num_of_frames = selected_image_feature.shape[0]
        # (frames, side*side, C) -> (frames, C, side, side) for 2-D pooling.
        grid = selected_image_feature.view(num_of_frames, side, side, -1)
        grid = grid.permute(0, 3, 1, 2).contiguous()
        # Pool spatially, then go back to (frames, pooled_tokens, C).
        pooled = self.resampler(grid).flatten(2).transpose(1, 2).contiguous()

        return self.multi_modal_projector(pooled)

    def _encode_pixel_values(self, pixel_values: List[np.ndarray]) -> List[torch.Tensor]:
        """Encode each request's pixel array into one 2-D ``(tokens, dim)`` tensor."""
        device = self.vision_tower.device
        if pixel_values[0].ndim == 4:
            # Video: each request is (frames, C, H, W). Encode all frames in one
            # pass, then split the result back per request.
            concat_images = torch.tensor(
                np.concatenate(pixel_values, axis=0), device=device
            )
            split_sizes = [frames.shape[0] for frames in pixel_values]
            per_request = torch.split(
                self.encode_images(concat_images), split_sizes, dim=0
            )
        else:
            # One (C, H, W) image per request; the batch is encoded as frames.
            per_request = self.encode_images(
                torch.tensor(np.array(pixel_values), device=device)
            )
        # Merge (frames, tokens, dim) -> (frames * tokens, dim). Single-image
        # features are already (tokens, dim) and must stay 2-D; flattening them
        # would break the (pad_len, pad_dim) unpacking in forward().
        return [feat.flatten(0, 1) if feat.ndim == 3 else feat for feat in per_request]

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Run one EXTEND (prefill) or DECODE step.

        EXTEND mode works in three steps:
        1. Embed the text tokens.
        2. Select requests that have pixel data and whose extend segment starts
           at a position below ``image_feature_len``.
        3. Copy each selected request's encoded features over its placeholder
           embeddings, starting at ``extend_start_loc[i] + image_offsets[i]``.

        DECODE mode delegates straight to the language model. Any other
        forward mode returns ``None``.

        Args:
            input_ids: Token ids of all requests in the batch, indexed via
                ``input_metadata.extend_start_loc``.
            positions: Position of each token in ``input_ids``.
            input_metadata: Batch metadata (forward mode, batch size, extend
                start locations).
            pixel_values: Per-request pixel arrays, ``None`` for text-only
                requests. In EXTEND mode, must have ``batch_size`` entries.
            image_sizes: Unused by this model.
            image_offsets: Per-request index of the first placeholder token.

        Returns:
            Whatever the wrapped ``LlamaForCausalLM`` returns.
        """
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input only where the extend segment still starts
            # inside the placeholder range and the request carries pixels.
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to substract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                vision_indices = np.flatnonzero(need_vision).tolist()
                image_features = self._encode_pixel_values(
                    [pixel_values[i] for i in vision_indices]
                )

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                dim = input_embeds.shape[1]
                for i, feature in zip(vision_indices, image_features):
                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = feature.shape
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    begin = start_idx + image_offsets[i]
                    try:
                        input_embeds[begin : begin + pad_len] = feature
                    except RuntimeError as e:
                        # Best effort: report and keep the placeholder embeddings.
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load the CLIP vision tower, projector weights and language model.

        This also derives the vision attributes used by the other methods and
        patches ``CLIPVisionEmbeddings.forward``.

        Raises:
            ValueError: If ``mm_vision_select_feature`` is neither "patch" nor
                "cls_patch".
        """
        # load clip vision model by cfg['mm_vision_tower']:
        # huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        print(f"target_frames: {self.num_frames}")
        # Placeholder slots per video: frames x pooled patch grid.
        self.image_feature_len = self.num_frames * int(
            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
        )
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            # NOTE(review): encode_images() rejects "cls_patch", so this
            # configuration fails at the first forward pass — confirm intent.
            self.image_feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
            # Update the vision tower weights if we find them in the checkpoint
            # (it may be finetuned).
            "model.vision_tower.vision_tower": "vision_tower",
        }
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            # FIXME: why projector weights read two times?
            if "projector" not in name and "vision_tower" not in name:
                continue
            for weight_name, param_name in projector_weights.items():
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
            param = params_dict.get(name)
            if param is None:
                print(f"Warning: {name} not found in the model")
                continue
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(
            model_name_or_path, cache_dir, load_format, revision
        )

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        """Number of CLIP patches along one side of the (square) input image."""
        return self.image_size // self.patch_size
first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Drop-in ``forward`` for ``CLIPVisionEmbeddings`` that runs the patch conv on CPU.

    ``self`` is the embeddings module being patched. It must provide
    ``patch_embedding``, ``class_embedding``, ``position_embedding`` and
    ``position_ids``.

    Args:
        pixel_values: Images stacked along dim 0.

    Returns:
        Tensor of shape ``(batch, 1 + num_patches, embed_dim)``: the class
        token followed by the flattened patch tokens, plus position
        embeddings. It is placed on the device/dtype of
        ``self.class_embedding``.
    """
    global first_call
    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    # Check the layer itself, not only the module-level flag, so that every
    # patched instance gets moved, not just the first one that is called.
    weight = self.patch_embedding.weight
    if first_call or weight.device.type != "cpu" or weight.dtype != torch.float32:
        self.patch_embedding.cpu().float()
        first_call = False
    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values)

    # Move the patch tokens next to the class embedding. Matching it, instead
    # of hard-coding CUDA/fp16, keeps the concat below consistent for whatever
    # device/dtype the vision tower was loaded with.
    patch_embeds = patch_embeds.to(
        device=self.class_embedding.device, dtype=self.class_embedding.dtype
    )
    # Conv output (B, C, h, w) -> (B, h*w, C).
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings
def monkey_path_clip_vision_embed_forward():
    """Install ``clip_vision_embed_forward`` as ``CLIPVisionEmbeddings.forward``.

    The patch is applied to the transformers class itself. It therefore
    affects every CLIP vision embeddings instance in the process, not only
    this model's. The name is a historical misspelling of "monkey patch"; the
    correctly spelled alias defined next to it refers to this same function.
    """
    import transformers

    setattr(
        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
        "forward",
        clip_vision_embed_forward,
    )


# Correctly spelled alias; the original name is kept for existing callers.
monkey_patch_clip_vision_embed_forward = monkey_path_clip_vision_embed_forward

EntryClass = LlavaVidForCausalLM