sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama2.py
@@ -1,11 +1,17 @@
  # Adapted from
- # https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
  """Inference-only LLaMA model compatible with HuggingFace weights."""
- from typing import Any, Dict, Optional, Tuple
+ from typing import Any, Dict, Optional, Tuple, Iterable

  import torch
+ import tqdm
  from torch import nn
  from transformers import LlamaConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size
+ )
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
@@ -13,24 +19,17 @@ from vllm.model_executor.layers.linear import (
  QKVParallelLinear,
  RowParallelLinear,
  )
- from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig)
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
- from vllm.distributed import (
- get_tensor_model_parallel_world_size,
- )
- from sglang.srt.weight_utils import (
- default_weight_loader,
- hf_model_weights_iterator,
- )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from sglang.srt.managers.controller.model_runner import InputMetadata


  class LlamaMLP(nn.Module):
@@ -49,7 +48,10 @@ class LlamaMLP(nn.Module):
  quant_config=quant_config,
  )
  self.down_proj = RowParallelLinear(
- intermediate_size, hidden_size, bias=False, quant_config=quant_config,
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
  )
  if hidden_act != "silu":
  raise ValueError(
@@ -155,6 +157,10 @@ class LlamaDecoderLayer(nn.Module):
  self.hidden_size = config.hidden_size
  rope_theta = getattr(config, "rope_theta", 10000)
  rope_scaling = getattr(config, "rope_scaling", None)
+ if rope_scaling is not None and getattr(
+ config, "original_max_position_embeddings", None):
+ rope_scaling["original_max_position_embeddings"] = (
+ config.original_max_position_embeddings)
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
  self.self_attn = LlamaAttention(
  hidden_size=self.hidden_size,
@@ -253,6 +259,7 @@ class LlamaForCausalLM(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -273,13 +280,7 @@ class LlamaForCausalLM(nn.Module):
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )

- def load_weights(
- self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None,
- ):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
  ("qkv_proj", "q_proj", "q"),
@@ -289,9 +290,9 @@ class LlamaForCausalLM(nn.Module):
  ("gate_up_proj", "up_proj", 1),
  ]
  params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision
- ):
+ if get_tensor_model_parallel_rank() == 0:
+ weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+ for name, loaded_weight in weights:
  if "rotary_emb.inv_freq" in name or "projector" in name:
  continue
  if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
sglang/srt/models/llava.py
@@ -1,27 +1,26 @@
  """Inference-only LLaVa model compatible with HuggingFace weights."""

- from typing import List, Optional
+ from typing import List, Iterable, Optional, Tuple

  import numpy as np
  import torch
  from torch import nn
- from transformers import CLIPVisionModel, LlavaConfig
+ from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
- from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig)
- from sglang.srt.weight_utils import (
- default_weight_loader,
- hf_model_weights_iterator,
- )
+ from vllm.config import CacheConfig
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.managers.router.infer_batch import ForwardMode
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from sglang.srt.managers.controller.infer_batch import ForwardMode
+ from sglang.srt.managers.controller.model_runner import InputMetadata
  from sglang.srt.mm_utils import (
  get_anyres_image_grid_shape,
  unpad_image,
  unpad_image_shape,
  )
  from sglang.srt.models.llama2 import LlamaForCausalLM
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+ from sglang.srt.models.mistral import MistralForCausalLM


  class LlavaLlamaForCausalLM(nn.Module):
@@ -29,6 +28,7 @@ class LlavaLlamaForCausalLM(nn.Module):
  self,
  config: LlavaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -237,13 +237,7 @@ class LlavaLlamaForCausalLM(nn.Module):
  elif input_metadata.forward_mode == ForwardMode.DECODE:
  return self.language_model(input_ids, positions, input_metadata)

- def load_weights(
- self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None,
- ):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  # load clip vision model by cfg['mm_vision_tower']:
  # huggingface_name or path_of_clip_relative_to_llava_model_dir
  vision_path = self.config.mm_vision_tower
@@ -276,9 +270,8 @@ class LlavaLlamaForCausalLM(nn.Module):
  "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
  }
  params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision
- ):
+ weights = list(weights)
+ for name, loaded_weight in weights:
  # FIXME: why projector weights read two times?
  if "projector" in name or "vision_tower" in name:
  for weight_name, param_name in projector_weights.items():
@@ -289,9 +282,7 @@ class LlavaLlamaForCausalLM(nn.Module):
  weight_loader(param, loaded_weight)

  # load language model
- self.language_model.load_weights(
- model_name_or_path, cache_dir, load_format, revision
- )
+ self.language_model.load_weights(weights)

  monkey_path_clip_vision_embed_forward()

@@ -300,9 +291,74 @@ class LlavaLlamaForCausalLM(nn.Module):
  return self.image_size // self.patch_size


- first_call = True
+ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+ def __init__(
+ self,
+ config: LlavaConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+ self.config = config
+ self.vision_tower = None
+ if getattr(self.config, "vision_config", None) is None:
+ self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+ if getattr(self.config, "text_config", None) is None:
+ self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+ self.config.vision_config.hidden_size = config.mm_hidden_size
+ self.config.text_config.hidden_size = config.hidden_size
+
+ if getattr(self.config, "projector_hidden_act", None) is None:
+ self.config.projector_hidden_act = "gelu"
+
+ if getattr(self.config, "image_token_index", None) is None:
+ self.config.image_token_index = 151646
+
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
+ self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+ self.language_model.model.image_newline = nn.Parameter(
+ torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+ )


+ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+ def __init__(
+ self,
+ config: LlavaConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+ self.config = config
+ self.vision_tower = None
+ if getattr(self.config, "vision_config", None) is None:
+ self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+ if getattr(self.config, "text_config", None) is None:
+ self.config.text_config = MistralConfig(self.config._name_or_path)
+
+ self.config.vision_config.hidden_size = config.mm_hidden_size
+ self.config.text_config.hidden_size = config.hidden_size
+
+ if getattr(self.config, "projector_hidden_act", None) is None:
+ self.config.projector_hidden_act = "gelu"
+
+ if getattr(self.config, "image_token_index", None) is None:
+ self.config.image_token_index = 32000
+
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
+ self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+ self.language_model.model.image_newline = nn.Parameter(
+ torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+ )
+
+
+ first_call = True
+
  def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  batch_size = pixel_values.shape[0]

@@ -332,4 +388,8 @@ def monkey_path_clip_vision_embed_forward():
  )


- EntryClass = LlavaLlamaForCausalLM
+ EntryClass = [
+ LlavaLlamaForCausalLM,
+ LlavaQwenForCausalLM,
+ LlavaMistralForCausalLM
+ ]
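
Note on the last hunk above: `EntryClass` changes from a single model class to a list of classes, so code that resolves entry points now has to accept either shape. A small, hypothetical normalization helper (not part of sglang) illustrating the idea:

from typing import List, Type, Union

from torch import nn


def normalize_entry_class(
    entry: Union[Type[nn.Module], List[Type[nn.Module]]]
) -> List[Type[nn.Module]]:
    """Return EntryClass as a list, whether a model file exports one class or several."""
    return list(entry) if isinstance(entry, (list, tuple)) else [entry]
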
sglang/srt/models/llavavid.py
@@ -0,0 +1,298 @@
+ """Inference-only LLaVa video model compatible with HuggingFace weights."""
+
+ from typing import List, Iterable, Optional, Tuple
+
+ import numpy as np
+ import torch
+ from torch import nn
+ from transformers import CLIPVisionModel, LlavaConfig
+ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+ from vllm.config import CacheConfig
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.managers.controller.infer_batch import ForwardMode
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+ from sglang.srt.mm_utils import (
+ get_anyres_image_grid_shape,
+ unpad_image,
+ unpad_image_shape,
+ )
+ from sglang.srt.models.llama2 import LlamaForCausalLM
+
+
+ class LlavaVidForCausalLM(nn.Module):
+ def __init__(
+ self,
+ config: LlavaConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.vision_tower = None
+ self.config.vision_config.hidden_size = config.mm_hidden_size
+ self.config.text_config.hidden_size = config.hidden_size
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
+ self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
+ self.resampler = nn.AvgPool2d(
+ kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
+ )
+ self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+ self.num_frames = getattr(self.config, "num_frames", 16)
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+ self.language_model.model.image_newline = nn.Parameter(
+ torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+ )
+
+ def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+ new_image_feature_len = self.image_feature_len
+ # now only support spatial_unpad + anyres
+ # if self.mm_patch_merge_type.startswith("spatial"):
+ # height = width = self.num_patches_per_side
+ # if pt_shape[0] > 1:
+ # if self.image_aspect_ratio == "anyres":
+ # num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+ # image_size,
+ # self.image_grid_pinpoints,
+ # self.vision_tower.config.image_size,
+ # )
+ # if "unpad" in self.mm_patch_merge_type:
+ # h = num_patch_height * height
+ # w = num_patch_width * width
+ # new_h, new_w = unpad_image_shape(h, w, image_size)
+ # new_image_feature_len += new_h * (new_w + 1)
+
+ pad_ids = pad_value * (
+ (new_image_feature_len + len(pad_value)) // len(pad_value)
+ )
+ offset = input_ids.index(self.config.image_token_index)
+ # old_len + pad_len - 1, because we need to remove image_token_id
+ new_input_ids = (
+ input_ids[:offset]
+ + pad_ids[:new_image_feature_len]
+ + input_ids[offset + 1 :]
+ )
+ return new_input_ids, offset
+
+ def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+ # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
+
+ selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
+ if self.vision_feature_select_strategy in ["default", "patch"]:
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif self.vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+ else:
+ raise ValueError(
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+ )
+
+ height = width = self.num_patches_per_side
+ num_of_frames = selected_image_feature.shape[0]
+ selected_image_feature = selected_image_feature.view(
+ num_of_frames, height, width, -1
+ )
+ selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
+ selected_image_feature = (
+ self.resampler(selected_image_feature)
+ .flatten(2)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+
+ return image_features
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ positions: torch.Tensor,
+ input_metadata: InputMetadata,
+ pixel_values: Optional[List[Optional[np.array]]] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ image_offsets: Optional[List[int]] = None,
+ ) -> torch.Tensor:
+ if input_metadata.forward_mode == ForwardMode.EXTEND:
+ bs = input_metadata.batch_size
+
+ # Embed text input
+ input_embeds = self.language_model.model.embed_tokens(input_ids)
+
+ # Embed vision input
+ need_vision = (
+ (positions[input_metadata.extend_start_loc] < self.image_feature_len)
+ .cpu()
+ .numpy()
+ )
+ # FIXME: We need to substract the length of the system prompt
+ has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
+ need_vision = need_vision & has_pixel
+
+ if need_vision.any():
+ pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+ image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+
+ ########## Encode Image ########
+
+ if pixel_values[0].ndim == 4:
+ # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
+ np.concatenate(pixel_values, axis=0)
+ # ndim=4
+ concat_images = torch.tensor(
+ np.concatenate(pixel_values, axis=0),
+ device=self.vision_tower.device,
+ )
+ # image_features = self.encode_images(concat_images)
+ # split_sizes = [image.shape[0] for image in pixel_values]
+ # image_features = torch.split(image_features, split_sizes, dim=0)
+ image_features = self.encode_images(
+ concat_images
+ ) # , prompts)#, image_counts, long_video=long_video)
+ split_sizes = [image.shape[0] for image in pixel_values]
+ image_features = torch.split(image_features, split_sizes, dim=0)
+
+ # hd image_features: BS, num_patch, 576, 4096
+ else:
+ # normal pixel: BS, C=3, H=336, W=336
+ pixel_values = torch.tensor(
+ np.array(pixel_values), device=self.vision_tower.device
+ )
+ image_features = self.encode_images(pixel_values)
+ # image_features: BS, 576, 4096
+
+ new_image_features = []
+ for image_idx, image_feature in enumerate(image_features):
+ new_image_features.append(image_feature.flatten(0, 1))
+ image_features = new_image_features
+
+ extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+ pt = 0
+ for i in range(bs):
+ if not need_vision[i]:
+ continue
+
+ start_idx = extend_start_loc_cpu[i]
+ pad_len, pad_dim = image_features[pt].shape # 576, 4096
+ dim = input_embeds.shape[1]
+ assert (
+ pad_dim == dim
+ ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
+ # Fill in the placeholder for the image
+ try:
+ input_embeds[
+ start_idx
+ + image_offsets[i] : start_idx
+ + image_offsets[i]
+ + pad_len
+ ] = image_features[pt]
+ except RuntimeError as e:
+ print(f"RuntimeError in llava image encoding: {e}")
+ print(input_embeds.shape)
+ print(start_idx, image_offsets[i])
+ pt += 1
+
+ return self.language_model(
+ input_ids, positions, input_metadata, input_embeds=input_embeds
+ )
+ elif input_metadata.forward_mode == ForwardMode.DECODE:
+ return self.language_model(input_ids, positions, input_metadata)
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ # load clip vision model by cfg['mm_vision_tower']:
+ # huggingface_name or path_of_clip_relative_to_llava_model_dir
+ vision_path = self.config.mm_vision_tower
+ self.vision_tower = CLIPVisionModel.from_pretrained(
+ vision_path, torch_dtype=torch.float16
+ ).cuda()
+ self.vision_tower.eval()
+
+ self.vision_feature_layer = self.config.mm_vision_select_layer
+ self.vision_feature_select_strategy = self.config.mm_vision_select_feature
+ self.image_size = self.vision_tower.config.image_size
+ self.patch_size = self.vision_tower.config.patch_size
+
+ self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+ self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+ self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
+
+ print(f"target_frames: {self.num_frames}")
+ self.image_feature_len = self.num_frames * int(
+ (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
+ )
+ if self.vision_feature_select_strategy == "patch":
+ pass
+ elif self.vision_feature_select_strategy == "cls_patch":
+ self.image_feature_len += 1
+ else:
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
+
+ # load mm_projector
+ projector_weights = {
+ "model.mm_projector.0": "multi_modal_projector.linear_1",
+ "model.mm_projector.2": "multi_modal_projector.linear_2",
+ "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
+ "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
+ "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+ }
+ params_dict = dict(self.named_parameters())
+ weights = list(weights)
+ for name, loaded_weight in weights:
+ # FIXME: why projector weights read two times?
+ if "projector" in name or "vision_tower" in name:
+ for weight_name, param_name in projector_weights.items():
+ if weight_name in name:
+ name = name.replace(weight_name, param_name)
+ if name in params_dict:
+ param = params_dict[name]
+ else:
+ print(f"Warning: {name} not found in the model")
+ continue
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+
+ # load language model
+ self.language_model.load_weights(weights)
+
+ monkey_path_clip_vision_embed_forward()
+
+ @property
+ def num_patches_per_side(self):
+ return self.image_size // self.patch_size
+
+
+ first_call = True
+
+
+ def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.shape[0]
+
+ # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
+ global first_call
+ if first_call:
+ self.patch_embedding.cpu().float()
+ first_call = False
+ pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
+ patch_embeds = self.patch_embedding(pixel_values).cuda().half()
+
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+ def monkey_path_clip_vision_embed_forward():
+ import transformers
+
+ setattr(
+ transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
+ "forward",
+ clip_vision_embed_forward,
+ )
+
+
+ EntryClass = LlavaVidForCausalLM
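
Note on the new llavavid model: `load_weights` sets `image_feature_len = num_frames * int((image_size / patch_size / mm_spatial_pool_stride) ** 2)`. A quick sanity check of that arithmetic, assuming a CLIP ViT-L/14-336 vision tower (image_size=336, patch_size=14 come from the vision tower config, not from this diff) and the defaults shown above (mm_spatial_pool_stride=2, num_frames=16):

# Assumed vision tower dimensions; stride and frame count are the defaults from the diff.
image_size, patch_size = 336, 14
mm_spatial_pool_stride, num_frames = 2, 16

tokens_per_frame = int((image_size / patch_size / mm_spatial_pool_stride) ** 2)  # (336/14/2)**2 = 144
image_feature_len = num_frames * tokens_per_frame  # 16 * 144 = 2304 placeholder tokens per video
print(image_feature_len)  # 2304
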