sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff shows the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as published in their public registry.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
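Across the model files below (gemma.py, llama2.py, llava.py, and the new llavavid.py), the import layout changes in the same way: get_tensor_model_parallel_world_size now comes from vllm.distributed, quantization is configured via QuantizationConfig from vllm.model_executor.layers.quantization.base_config, and the weight-loading helpers move from vllm.model_executor.weight_utils to the new sglang.srt.weight_utils module, with sglang's own imports grouped after the third-party ones. A condensed view of the new-style import block, assembled from those diffs (illustrative; no single file contains exactly this set):

from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator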
sglang/srt/models/gemma.py

@@ -4,29 +4,25 @@
 from typing import Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
-from vllm.model_executor.input_metadata import InputMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class GemmaMLP(nn.Module):
@@ -34,17 +30,20 @@ class GemmaMLP(nn.Module):
         self,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()
 
@@ -65,7 +64,7 @@ class GemmaAttention(nn.Module):
         layer_id: int = 0,
         max_position_embeddings: int = 8192,
         rope_theta: float = 10000,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -95,13 +94,13 @@ class GemmaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -138,7 +137,7 @@ class GemmaDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -150,12 +149,12 @@ class GemmaDecoderLayer(nn.Module):
             layer_id=layer_id,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -191,7 +190,7 @@ class GemmaModel(nn.Module):
     def __init__(
         self,
        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
@@ -202,7 +201,7 @@ class GemmaModel(nn.Module):
        )
        self.layers = nn.ModuleList(
            [
-                GemmaDecoderLayer(config, i, linear_method)
+                GemmaDecoderLayer(config, i, quant_config=quant_config)
                for i in range(config.num_hidden_layers)
            ]
        )
@@ -263,14 +262,14 @@ class GemmaForCausalLM(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config  # Unused.
        super().__init__()
        self.config = config
-        self.linear_method = linear_method
-        self.model = GemmaModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = GemmaModel(config, quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config)
 
    @torch.no_grad()
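The gemma.py hunks above illustrate the signature change that repeats across every model in this release: constructors that previously took linear_method: Optional[LinearMethodBase] now take quant_config: Optional[QuantizationConfig] and pass it straight through to the vLLM parallel linear layers. A minimal sketch of the resulting pattern, using a hypothetical ExampleMLP (the forward method is not part of this diff; it assumes the usual vLLM convention that these layers return an (output, bias) tuple):

from typing import Optional

from torch import nn
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class ExampleMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,  # was: linear_method
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,  # was: linear_method=linear_method
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.act_fn = GeluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)  # parallel linears return (output, bias)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x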
sglang/srt/models/llama2.py

@@ -1,34 +1,30 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import LlamaConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlamaMLP(nn.Module):
@@ -37,17 +33,20 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -73,7 +72,7 @@ class LlamaAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,13 +103,13 @@ class LlamaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -147,7 +146,7 @@ class LlamaDecoderLayer(nn.Module):
         self,
         config: LlamaConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -162,13 +161,13 @@ class LlamaDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -204,7 +203,7 @@ class LlamaModel(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -216,7 +215,7 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i, linear_method)
+                LlamaDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -250,12 +249,12 @@ class LlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = LlamaModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
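The llama2.py diff above also swaps vllm.model_executor.weight_utils for the vendored sglang.srt.weight_utils module added in this release (see sglang/srt/weight_utils.py, +417, in the file list). Judging from how the model files call these helpers, hf_model_weights_iterator yields (name, tensor) pairs for a checkpoint and default_weight_loader copies a tensor into a parameter. A hedged sketch of a minimal load_weights loop built on them; load_weights_sketch and its model argument are illustrative, and real models additionally remap stacked/fused parameter names:

from typing import Optional

from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


def load_weights_sketch(
    model,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
):
    # Copy each checkpoint tensor into the matching parameter, mirroring the
    # loop in LlavaVidForCausalLM.load_weights further down in this diff.
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in hf_model_weights_iterator(
        model_name_or_path, cache_dir, load_format, revision
    ):
        if name not in params_dict:
            continue  # tensor not defined by this model
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)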
sglang/srt/models/llava.py

@@ -4,6 +4,11 @@ from typing import List, Optional
 
 import numpy as np
 import torch
+from torch import nn
+from transformers import CLIPVisionModel, LlavaConfig
+from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
 from sglang.srt.managers.router.infer_batch import ForwardMode
 from sglang.srt.managers.router.model_runner import InputMetadata
 from sglang.srt.mm_utils import (
@@ -12,21 +17,14 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
-from torch import nn
-from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
-from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.linear import LinearMethodBase
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class LlavaLlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlavaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -34,7 +32,7 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.config.vision_config.hidden_size = config.mm_hidden_size
         self.config.text_config.hidden_size = config.hidden_size
         self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, linear_method)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
sglang/srt/models/llavavid.py (new file)

@@ -0,0 +1,307 @@
+"""Inference-only LLaVa video model compatible with HuggingFace weights."""
+
+import os
+from typing import List, Optional
+
+import numpy as np
+import torch
+from torch import nn
+from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
+from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
+from sglang.srt.managers.router.infer_batch import ForwardMode
+from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
+from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+
+
+class LlavaVidForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
+        self.resampler = nn.AvgPool2d(
+            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
+        )
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        self.num_frames = getattr(self.config, "num_frames", 16)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        # if self.mm_patch_merge_type.startswith("spatial"):
+        #     height = width = self.num_patches_per_side
+        #     if pt_shape[0] > 1:
+        #         if self.image_aspect_ratio == "anyres":
+        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+        #                 image_size,
+        #                 self.image_grid_pinpoints,
+        #                 self.vision_tower.config.image_size,
+        #             )
+        #         if "unpad" in self.mm_patch_merge_type:
+        #             h = num_patch_height * height
+        #             w = num_patch_width * width
+        #             new_h, new_w = unpad_image_shape(h, w, image_size)
+        #             new_image_feature_len += new_h * (new_w + 1)
+
+        pad_ids = pad_value * (
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        )
+        # print(input_ids)
+        offset = input_ids.index(self.config.image_token_index)
+        # old_len + pad_len - 1, because we need to remove image_token_id
+        new_input_ids = (
+            input_ids[:offset]
+            + pad_ids[:new_image_feature_len]
+            + input_ids[offset + 1 :]
+        )
+        return new_input_ids, offset
+
+    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
+
+        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
+        if self.vision_feature_select_strategy in ["default", "patch"]:
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+            )
+
+        height = width = self.num_patches_per_side
+        num_of_frames = selected_image_feature.shape[0]
+        selected_image_feature = selected_image_feature.view(
+            num_of_frames, height, width, -1
+        )
+        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
+        selected_image_feature = (
+            self.resampler(selected_image_feature)
+            .flatten(2)
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+        image_features = self.multi_modal_projector(selected_image_feature)
+
+        return image_features
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        pixel_values: Optional[List[Optional[np.array]]] = None,
+        image_sizes: Optional[List[List[int]]] = None,
+        image_offsets: Optional[List[int]] = None,
+    ) -> torch.Tensor:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
+            bs = input_metadata.batch_size
+
+            # Embed text input
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
+            # Embed vision input
+            need_vision = (
+                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
+                .cpu()
+                .numpy()
+            )
+            # FIXME: We need to substract the length of the system prompt
+            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
+            need_vision = need_vision & has_pixel
+
+            if need_vision.any():
+                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+
+                ########## Encode Image ########
+
+                if pixel_values[0].ndim == 4:
+                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
+                    np.concatenate(pixel_values, axis=0)
+                    # ndim=4
+                    concat_images = torch.tensor(
+                        np.concatenate(pixel_values, axis=0),
+                        device=self.vision_tower.device,
+                    )
+                    # image_features = self.encode_images(concat_images)
+                    # split_sizes = [image.shape[0] for image in pixel_values]
+                    # image_features = torch.split(image_features, split_sizes, dim=0)
+                    image_features = self.encode_images(
+                        concat_images
+                    )  # , prompts)#, image_counts, long_video=long_video)
+                    split_sizes = [image.shape[0] for image in pixel_values]
+                    image_features = torch.split(image_features, split_sizes, dim=0)
+
+                    # hd image_features: BS, num_patch, 576, 4096
+                else:
+                    # normal pixel: BS, C=3, H=336, W=336
+                    pixel_values = torch.tensor(
+                        np.array(pixel_values), device=self.vision_tower.device
+                    )
+                    image_features = self.encode_images(pixel_values)
+                    # image_features: BS, 576, 4096
+
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    new_image_features.append(image_feature.flatten(0, 1))
+                image_features = new_image_features
+
+                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                pt = 0
+                for i in range(bs):
+                    if not need_vision[i]:
+                        continue
+
+                    start_idx = extend_start_loc_cpu[i]
+                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
+                    dim = input_embeds.shape[1]
+                    assert (
+                        pad_dim == dim
+                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
+                    # Fill in the placeholder for the image
+                    try:
+                        input_embeds[
+                            start_idx
+                            + image_offsets[i] : start_idx
+                            + image_offsets[i]
+                            + pad_len
+                        ] = image_features[pt]
+                    except RuntimeError as e:
+                        print(f"RuntimeError in llava image encoding: {e}")
+                        print(input_embeds.shape)
+                        print(start_idx, image_offsets[i])
+                    pt += 1
+
+            return self.language_model(
+                input_ids, positions, input_metadata, input_embeds=input_embeds
+            )
+        elif input_metadata.forward_mode == ForwardMode.DECODE:
+            return self.language_model(input_ids, positions, input_metadata)
+
+    def load_weights(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        load_format: str = "auto",
+        revision: Optional[str] = None,
+    ):
+        # load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        vision_path = self.config.mm_vision_tower
+        self.vision_tower = CLIPVisionModel.from_pretrained(
+            vision_path, torch_dtype=torch.float16
+        ).cuda()
+        self.vision_tower.eval()
+
+        self.vision_feature_layer = self.config.mm_vision_select_layer
+        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
+        self.image_size = self.vision_tower.config.image_size
+        self.patch_size = self.vision_tower.config.patch_size
+
+        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
+
+        print(f"target_frames: {self.num_frames}")
+        self.image_feature_len = self.num_frames * int(
+            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
+        )
+        if self.vision_feature_select_strategy == "patch":
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(f"Unexpected select feature: {self.select_feature}")
+
+        # load mm_projector
+        projector_weights = {
+            "model.mm_projector.0": "multi_modal_projector.linear_1",
+            "model.mm_projector.2": "multi_modal_projector.linear_2",
+            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
+            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
+            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+        }
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+            model_name_or_path, cache_dir, load_format, revision
+        ):
+            # FIXME: why projector weights read two times?
+            if "projector" in name or "vision_tower" in name:
+                for weight_name, param_name in projector_weights.items():
+                    if weight_name in name:
+                        name = name.replace(weight_name, param_name)
+                if name in params_dict:
+                    param = params_dict[name]
+                else:
+                    print(f"Warning: {name} not found in the model")
+                    continue
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        # load language model
+        self.language_model.load_weights(
+            model_name_or_path, cache_dir, load_format, revision
+        )
+
+        monkey_path_clip_vision_embed_forward()
+
+    @property
+    def num_patches_per_side(self):
+        return self.image_size // self.patch_size
+
+
+first_call = True
+
+
+def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    batch_size = pixel_values.shape[0]
+
+    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
+    global first_call
+    if first_call:
+        self.patch_embedding.cpu().float()
+        first_call = False
+    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
+    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
+
+    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+    embeddings = embeddings + self.position_embedding(self.position_ids)
+    return embeddings
+
+
+def monkey_path_clip_vision_embed_forward():
+    import transformers
+
+    setattr(
+        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
+        "forward",
+        clip_vision_embed_forward,
+    )
+
+
+EntryClass = LlavaVidForCausalLM
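For reference, the number of vision placeholder tokens the new LlavaVidForCausalLM reserves per request follows directly from load_weights above: image_feature_len = num_frames * (image_size / patch_size / mm_spatial_pool_stride) ** 2. Assuming a CLIP ViT-L/14-336 vision tower (image_size 336, patch_size 14, i.e. the 24 x 24 = 576 patches per frame noted in the llava.py comments) and the defaults visible above (16 frames, pooling stride 2), that is 16 * 12 ** 2 = 2304 tokens spliced into the prompt by pad_input_ids. A quick arithmetic check:

# Sanity check of image_feature_len; the CLIP dimensions are assumptions,
# not values taken from this diff.
num_frames = 16              # getattr(config, "num_frames", 16)
image_size = 336             # assumed vision tower input resolution
patch_size = 14              # assumed vision tower patch size
mm_spatial_pool_stride = 2   # getattr(config, "mm_spatial_pool_stride", 2)

tokens_per_frame = int((image_size / patch_size / mm_spatial_pool_stride) ** 2)
image_feature_len = num_frames * tokens_per_frame
print(tokens_per_frame, image_feature_len)  # 144 2304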