sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama2.py
CHANGED
@@ -1,11 +1,17 @@
|
|
1
1
|
# Adapted from
|
2
|
-
# https://github.com/vllm-project/vllm/blob/
|
2
|
+
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
|
3
3
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
4
|
-
from typing import Any, Dict, Optional, Tuple
|
4
|
+
from typing import Any, Dict, Optional, Tuple, Iterable
|
5
5
|
|
6
6
|
import torch
|
7
|
+
import tqdm
|
7
8
|
from torch import nn
|
8
9
|
from transformers import LlamaConfig
|
10
|
+
from vllm.config import CacheConfig
|
11
|
+
from vllm.distributed import (
|
12
|
+
get_tensor_model_parallel_rank,
|
13
|
+
get_tensor_model_parallel_world_size
|
14
|
+
)
|
9
15
|
from vllm.model_executor.layers.activation import SiluAndMul
|
10
16
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
11
17
|
from vllm.model_executor.layers.linear import (
|
@@ -13,24 +19,17 @@ from vllm.model_executor.layers.linear import (
|
|
13
19
|
QKVParallelLinear,
|
14
20
|
RowParallelLinear,
|
15
21
|
)
|
16
|
-
from vllm.model_executor.layers.quantization.base_config import
|
17
|
-
QuantizationConfig)
|
22
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
18
23
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
19
24
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
20
25
|
ParallelLMHead,
|
21
26
|
VocabParallelEmbedding,
|
22
27
|
)
|
23
|
-
from vllm.
|
24
|
-
get_tensor_model_parallel_world_size,
|
25
|
-
)
|
26
|
-
from sglang.srt.weight_utils import (
|
27
|
-
default_weight_loader,
|
28
|
-
hf_model_weights_iterator,
|
29
|
-
)
|
28
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
30
29
|
|
31
30
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
32
31
|
from sglang.srt.layers.radix_attention import RadixAttention
|
33
|
-
from sglang.srt.managers.
|
32
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
34
33
|
|
35
34
|
|
36
35
|
class LlamaMLP(nn.Module):
|
@@ -49,7 +48,10 @@ class LlamaMLP(nn.Module):
|
|
49
48
|
quant_config=quant_config,
|
50
49
|
)
|
51
50
|
self.down_proj = RowParallelLinear(
|
52
|
-
intermediate_size,
|
51
|
+
intermediate_size,
|
52
|
+
hidden_size,
|
53
|
+
bias=False,
|
54
|
+
quant_config=quant_config,
|
53
55
|
)
|
54
56
|
if hidden_act != "silu":
|
55
57
|
raise ValueError(
|
@@ -155,6 +157,10 @@ class LlamaDecoderLayer(nn.Module):
|
|
155
157
|
self.hidden_size = config.hidden_size
|
156
158
|
rope_theta = getattr(config, "rope_theta", 10000)
|
157
159
|
rope_scaling = getattr(config, "rope_scaling", None)
|
160
|
+
if rope_scaling is not None and getattr(
|
161
|
+
config, "original_max_position_embeddings", None):
|
162
|
+
rope_scaling["original_max_position_embeddings"] = (
|
163
|
+
config.original_max_position_embeddings)
|
158
164
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
159
165
|
self.self_attn = LlamaAttention(
|
160
166
|
hidden_size=self.hidden_size,
|
@@ -253,6 +259,7 @@ class LlamaForCausalLM(nn.Module):
|
|
253
259
|
self,
|
254
260
|
config: LlamaConfig,
|
255
261
|
quant_config: Optional[QuantizationConfig] = None,
|
262
|
+
cache_config: Optional[CacheConfig] = None,
|
256
263
|
) -> None:
|
257
264
|
super().__init__()
|
258
265
|
self.config = config
|
@@ -273,13 +280,7 @@ class LlamaForCausalLM(nn.Module):
|
|
273
280
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
274
281
|
)
|
275
282
|
|
276
|
-
def load_weights(
|
277
|
-
self,
|
278
|
-
model_name_or_path: str,
|
279
|
-
cache_dir: Optional[str] = None,
|
280
|
-
load_format: str = "auto",
|
281
|
-
revision: Optional[str] = None,
|
282
|
-
):
|
283
|
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
283
284
|
stacked_params_mapping = [
|
284
285
|
# (param_name, shard_name, shard_id)
|
285
286
|
("qkv_proj", "q_proj", "q"),
|
@@ -289,9 +290,9 @@ class LlamaForCausalLM(nn.Module):
|
|
289
290
|
("gate_up_proj", "up_proj", 1),
|
290
291
|
]
|
291
292
|
params_dict = dict(self.named_parameters())
|
292
|
-
|
293
|
-
|
294
|
-
|
293
|
+
if get_tensor_model_parallel_rank() == 0:
|
294
|
+
weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
|
295
|
+
for name, loaded_weight in weights:
|
295
296
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
296
297
|
continue
|
297
298
|
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
sglang/srt/models/llava.py
CHANGED
@@ -1,27 +1,26 @@
|
|
1
1
|
"""Inference-only LLaVa model compatible with HuggingFace weights."""
|
2
2
|
|
3
|
-
from typing import List, Optional
|
3
|
+
from typing import List, Iterable, Optional, Tuple
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
import torch
|
7
7
|
from torch import nn
|
8
|
-
from transformers import CLIPVisionModel, LlavaConfig
|
8
|
+
from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
|
9
9
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
10
|
-
from vllm.
|
11
|
-
|
12
|
-
from
|
13
|
-
default_weight_loader,
|
14
|
-
hf_model_weights_iterator,
|
15
|
-
)
|
10
|
+
from vllm.config import CacheConfig
|
11
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
12
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
16
13
|
|
17
|
-
from sglang.srt.managers.
|
18
|
-
from sglang.srt.managers.
|
14
|
+
from sglang.srt.managers.controller.infer_batch import ForwardMode
|
15
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
19
16
|
from sglang.srt.mm_utils import (
|
20
17
|
get_anyres_image_grid_shape,
|
21
18
|
unpad_image,
|
22
19
|
unpad_image_shape,
|
23
20
|
)
|
24
21
|
from sglang.srt.models.llama2 import LlamaForCausalLM
|
22
|
+
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
23
|
+
from sglang.srt.models.mistral import MistralForCausalLM
|
25
24
|
|
26
25
|
|
27
26
|
class LlavaLlamaForCausalLM(nn.Module):
|
@@ -29,6 +28,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
29
28
|
self,
|
30
29
|
config: LlavaConfig,
|
31
30
|
quant_config: Optional[QuantizationConfig] = None,
|
31
|
+
cache_config: Optional[CacheConfig] = None,
|
32
32
|
) -> None:
|
33
33
|
super().__init__()
|
34
34
|
self.config = config
|
@@ -237,13 +237,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
237
237
|
elif input_metadata.forward_mode == ForwardMode.DECODE:
|
238
238
|
return self.language_model(input_ids, positions, input_metadata)
|
239
239
|
|
240
|
-
def load_weights(
|
241
|
-
self,
|
242
|
-
model_name_or_path: str,
|
243
|
-
cache_dir: Optional[str] = None,
|
244
|
-
load_format: str = "auto",
|
245
|
-
revision: Optional[str] = None,
|
246
|
-
):
|
240
|
+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
247
241
|
# load clip vision model by cfg['mm_vision_tower']:
|
248
242
|
# huggingface_name or path_of_clip_relative_to_llava_model_dir
|
249
243
|
vision_path = self.config.mm_vision_tower
|
@@ -276,9 +270,8 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
276
270
|
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
277
271
|
}
|
278
272
|
params_dict = dict(self.named_parameters())
|
279
|
-
|
280
|
-
|
281
|
-
):
|
273
|
+
weights = list(weights)
|
274
|
+
for name, loaded_weight in weights:
|
282
275
|
# FIXME: why projector weights read two times?
|
283
276
|
if "projector" in name or "vision_tower" in name:
|
284
277
|
for weight_name, param_name in projector_weights.items():
|
@@ -289,9 +282,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
289
282
|
weight_loader(param, loaded_weight)
|
290
283
|
|
291
284
|
# load language model
|
292
|
-
self.language_model.load_weights(
|
293
|
-
model_name_or_path, cache_dir, load_format, revision
|
294
|
-
)
|
285
|
+
self.language_model.load_weights(weights)
|
295
286
|
|
296
287
|
monkey_path_clip_vision_embed_forward()
|
297
288
|
|
@@ -300,9 +291,74 @@ class LlavaLlamaForCausalLM(nn.Module):
|
|
300
291
|
return self.image_size // self.patch_size
|
301
292
|
|
302
293
|
|
303
|
-
|
294
|
+
class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
    """LLaVA variant whose language model is ``Qwen2ForCausalLM``.

    Fills in LLaVA config fields the checkpoint may not define (vision/text
    sub-configs, projector activation, image token id), then builds the
    multimodal projector and the Qwen2 language model.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        # NOTE(review): the parent constructor presumably builds its own
        # language model, which is replaced below — confirm against
        # LlavaLlamaForCausalLM.__init__.
        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
        self.config = config
        # The vision tower is attached later (see load_weights).
        self.vision_tower = None

        # Build default sub-configs when missing. The first positional
        # parameter of these config classes is a size (hidden_size /
        # vocab_size), not a model path, so the path must not be passed
        # positionally. hidden_size is overwritten from the top-level config
        # right below.
        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig()

        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = Qwen2Config()

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"

        if getattr(self.config, "image_token_index", None) is None:
            # Fallback image placeholder id used when the config has none.
            self.config.image_token_index = 151646

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )
|
304
325
|
|
305
326
|
|
327
|
+
class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
    """LLaVA variant whose language model is ``MistralForCausalLM``.

    Fills in LLaVA config fields the checkpoint may not define (vision/text
    sub-configs, projector activation, image token id), then builds the
    multimodal projector and the Mistral language model.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        # NOTE(review): the parent constructor presumably builds its own
        # language model, which is replaced below — confirm against
        # LlavaLlamaForCausalLM.__init__.
        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
        self.config = config
        # The vision tower is attached later (see load_weights).
        self.vision_tower = None

        # Build default sub-configs when missing. The first positional
        # parameter of these config classes is a size (hidden_size /
        # vocab_size), not a model path, so the path must not be passed
        # positionally. hidden_size is overwritten from the top-level config
        # right below.
        if getattr(self.config, "vision_config", None) is None:
            self.config.vision_config = CLIPVisionConfig()

        if getattr(self.config, "text_config", None) is None:
            self.config.text_config = MistralConfig()

        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size

        if getattr(self.config, "projector_hidden_act", None) is None:
            self.config.projector_hidden_act = "gelu"

        if getattr(self.config, "image_token_index", None) is None:
            # Fallback image placeholder id used when the config has none.
            self.config.image_token_index = 32000

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )
|
358
|
+
|
359
|
+
|
360
|
+
# Module-level one-shot flag; presumably declared `global` and flipped by
# clip_vision_embed_forward (below) after its first call — confirm.
first_call = True
|
361
|
+
|
306
362
|
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
307
363
|
batch_size = pixel_values.shape[0]
|
308
364
|
|
@@ -332,4 +388,8 @@ def monkey_path_clip_vision_embed_forward():
|
|
332
388
|
)
|
333
389
|
|
334
390
|
|
335
|
-
EntryClass =
|
391
|
+
# Model classes provided by this module (one per supported LLaVA backbone).
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
@@ -0,0 +1,298 @@
|
|
1
|
+
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
|
2
|
+
|
3
|
+
from typing import List, Iterable, Optional, Tuple
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import torch
|
7
|
+
from torch import nn
|
8
|
+
from transformers import CLIPVisionModel, LlavaConfig
|
9
|
+
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
10
|
+
from vllm.config import CacheConfig
|
11
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
12
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
13
|
+
|
14
|
+
from sglang.srt.managers.controller.infer_batch import ForwardMode
|
15
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
16
|
+
from sglang.srt.mm_utils import (
|
17
|
+
get_anyres_image_grid_shape,
|
18
|
+
unpad_image,
|
19
|
+
unpad_image_shape,
|
20
|
+
)
|
21
|
+
from sglang.srt.models.llama2 import LlamaForCausalLM
|
22
|
+
|
23
|
+
|
24
|
+
class LlavaVidForCausalLM(nn.Module):
    """LLaVA-style video model built from four parts.

    The parts are a CLIP vision tower, an average-pooling resampler, a
    multimodal projector and a Llama language model. Each frame of a video is
    encoded by the vision tower. Each frame's patch grid is average-pooled by
    ``mm_spatial_pool_stride`` and projected. All frames of a request are
    concatenated, and the result overwrites that request's image placeholder
    tokens in the prompt embeddings.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,  # accepted but unused here
    ) -> None:
        super().__init__()
        self.config = config
        # The CLIP tower is created in load_weights().
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
        # Non-overlapping pooling (kernel == stride) over each frame's patch grid.
        self.resampler = nn.AvgPool2d(
            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
        )
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        self.num_frames = getattr(self.config, "num_frames", 16)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        """Replace the image token with ``image_feature_len`` placeholder ids.

        Args:
            input_ids: Prompt token ids (a list) that contains
                ``config.image_token_index``.
            pad_value: Non-empty list of ids, tiled to fill the placeholder span.
            pt_shape, image_size: Unused. They are kept for signature
                compatibility; the anyres/unpad length adjustment is not
                implemented for video.

        Returns:
            ``(new_input_ids, offset)``, where ``offset`` is the index of the
            first placeholder id (the former position of the image token).

        Raises:
            ValueError: If the image token is absent (raised by ``list.index``).
        """
        new_image_feature_len = self.image_feature_len
        # Tile pad_value so that it covers the whole span; it is truncated below.
        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # The single image token is removed, so the length grows by len - 1.
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode a stack of frames into pooled, projected features.

        The leading dimension of ``pixel_values`` is treated as the frame
        count. The result has shape (frames, pooled_patches, hidden).
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        # NOTE: output_hidden_states=True keeps every layer's hidden states,
        # which is not memory efficient; only one layer is used.
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            # Drop the leading class token and keep only the patch tokens.
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            pass
        else:
            raise ValueError(
                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
            )

        height = width = self.num_patches_per_side
        num_of_frames = selected_image_feature.shape[0]
        # Layout change (frames, tokens, dim) -> (frames, dim, h, w), so that
        # AvgPool2d pools the spatial grid.
        selected_image_feature = selected_image_feature.view(
            num_of_frames, height, width, -1
        )
        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
        # Pool, then restore the (frames, pooled_tokens, dim) layout.
        selected_image_feature = (
            self.resampler(selected_image_feature)
            .flatten(2)
            .transpose(1, 2)
            .contiguous()
        )

        return self.multi_modal_projector(selected_image_feature)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Run one batch through the model.

        In EXTEND mode, the text embeddings are computed first. A request is
        given video features when it carries pixel data and its extend chunk
        starts at a position below ``image_feature_len``. For such a request,
        the span starting at ``extend_start_loc + image_offsets[i]`` is
        overwritten with its features before the language model runs.

        In DECODE mode, the call is forwarded to the language model unchanged.
        Any other mode returns None.
        """
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to substract the length of the system prompt
            # A missing pixel_values list means that no request has images.
            has_pixel = np.array(
                [
                    pixel_values is not None and pixel_values[i] is not None
                    for i in range(bs)
                ]
            )
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########

                if pixel_values[0].ndim == 4:
                    # Assumes one (frames, C, H, W) array per request (TODO
                    # confirm). All frames of all requests are encoded in one
                    # pass and then split back per request.
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                else:
                    # NOTE(review): here each request's feature is 2-D, so the
                    # flatten(0, 1) below yields a 1-D tensor and the
                    # `pad_len, pad_dim` unpacking fails. This path appears
                    # unsupported for video.
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values)

                # Merge each request's (frames, tokens, dim) into (frames*tokens, dim).
                image_features = [feature.flatten(0, 1) for feature in image_features]

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    begin = start_idx + image_offsets[i]
                    try:
                        input_embeds[begin : begin + pad_len] = image_features[pt]
                    except RuntimeError as e:
                        # Best effort: report the failure and continue without
                        # the video features for this request.
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def _compute_image_feature_len(self) -> int:
        """Return how many placeholder tokens one video occupies.

        Each frame has a ``num_patches_per_side`` x ``num_patches_per_side``
        patch grid. The resampler (AvgPool2d, kernel == stride, no padding)
        reduces it to ``num_patches_per_side // stride`` tokens per side.
        Integer division is used so that the count matches the pooled tensor
        even when the sizes do not divide evenly.

        Raises:
            ValueError: If ``vision_feature_select_strategy`` is neither
                "patch" nor "cls_patch".
        """
        tokens_per_side = self.num_patches_per_side // self.mm_spatial_pool_stride
        feature_len = self.num_frames * tokens_per_side**2
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            feature_len += 1
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )
        return feature_len

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the vision tower, projector and language-model weights.

        The steps are:
        1. Load the CLIP tower named by ``config.mm_vision_tower`` (fp16, on CUDA).
        2. Derive the feature-layout attributes.
        3. Copy projector and vision-tower tensors from ``weights``.
        4. Hand the full weight list to the language model.
        5. Patch ``CLIPVisionEmbeddings.forward``.

        Raises:
            ValueError: If ``config.mm_vision_select_feature`` is unsupported.
        """
        # load clip vision model by cfg['mm_vision_tower']:
        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        print(f"target_frames: {self.num_frames}")
        self.image_feature_len = self._compute_image_feature_len()

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
        }
        params_dict = dict(self.named_parameters())
        # Materialize the weights: they are scanned here and again by the
        # language model below.
        weights = list(weights)
        for name, loaded_weight in weights:
            # FIXME: why projector weights read two times?
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                if name in params_dict:
                    param = params_dict[name]
                else:
                    print(f"Warning: {name} not found in the model")
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(weights)

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        """Patches along one side of the vision tower's input image."""
        return self.image_size // self.patch_size
|
264
|
+
|
265
|
+
|
266
|
+
# One-shot flag: True until the patch-embedding conv has been moved to the CPU.
first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Replacement ``forward`` for ``CLIPVisionEmbeddings``.

    The patch-embedding convolution runs on the CPU in float32; it is moved
    there on the first call. Its output is moved back to CUDA as float16.
    The class embedding is then prepended and the position embeddings added.
    """
    global first_call
    n_images = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    if first_call:
        self.patch_embedding.cpu().float()
        first_call = False
    cpu_pixels = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_tokens = self.patch_embedding(cpu_pixels).cuda().half()
    # Collapse the spatial dims and put the channel dim last.
    patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

    cls_tokens = self.class_embedding.expand(n_images, 1, -1)
    all_tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
    return all_tokens + self.position_embedding(self.position_ids)
|
286
|
+
|
287
|
+
|
288
|
+
def monkey_path_clip_vision_embed_forward():
    """Install ``clip_vision_embed_forward`` as ``CLIPVisionEmbeddings.forward``.

    This patches the transformers class itself, so it affects every CLIP
    vision model in the process.
    """
    import transformers

    embeddings_cls = transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
    embeddings_cls.forward = clip_vision_embed_forward
|
296
|
+
|
297
|
+
|
298
|
+
# Model class exported by this module; presumably looked up by the model
# loader under this name — confirm.
EntryClass = LlavaVidForCausalLM
|