sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. sglang/launch_server_llavavid.py +26 -0
  2. sglang/srt/constrained/fsm_cache.py +11 -2
  3. sglang/srt/constrained/jump_forward.py +1 -0
  4. sglang/srt/hf_transformers_utils.py +0 -149
  5. sglang/srt/layers/activation.py +93 -11
  6. sglang/srt/layers/layernorm.py +47 -4
  7. sglang/srt/layers/logits_processor.py +4 -4
  8. sglang/srt/layers/sampler.py +15 -68
  9. sglang/srt/managers/io_struct.py +5 -4
  10. sglang/srt/managers/schedule_batch.py +20 -25
  11. sglang/srt/managers/tokenizer_manager.py +74 -61
  12. sglang/srt/managers/tp_worker.py +49 -43
  13. sglang/srt/model_executor/cuda_graph_runner.py +17 -31
  14. sglang/srt/model_executor/forward_batch_info.py +9 -26
  15. sglang/srt/model_executor/model_runner.py +20 -17
  16. sglang/srt/models/chatglm.py +13 -5
  17. sglang/srt/models/commandr.py +1 -5
  18. sglang/srt/models/dbrx.py +1 -5
  19. sglang/srt/models/deepseek.py +1 -5
  20. sglang/srt/models/deepseek_v2.py +1 -5
  21. sglang/srt/models/gemma.py +3 -7
  22. sglang/srt/models/gemma2.py +2 -56
  23. sglang/srt/models/gpt_bigcode.py +2 -6
  24. sglang/srt/models/grok.py +10 -8
  25. sglang/srt/models/internlm2.py +1 -5
  26. sglang/srt/models/llama2.py +6 -11
  27. sglang/srt/models/llama_classification.py +2 -6
  28. sglang/srt/models/llama_embedding.py +3 -4
  29. sglang/srt/models/llava.py +69 -91
  30. sglang/srt/models/llavavid.py +40 -86
  31. sglang/srt/models/minicpm.py +1 -5
  32. sglang/srt/models/mixtral.py +1 -5
  33. sglang/srt/models/mixtral_quant.py +1 -5
  34. sglang/srt/models/qwen.py +2 -5
  35. sglang/srt/models/qwen2.py +5 -10
  36. sglang/srt/models/qwen2_moe.py +21 -24
  37. sglang/srt/models/stablelm.py +1 -5
  38. sglang/srt/models/yivl.py +2 -7
  39. sglang/srt/openai_api/adapter.py +85 -4
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +1 -74
  42. sglang/srt/sampling/sampling_params.py +4 -0
  43. sglang/srt/server.py +11 -4
  44. sglang/srt/utils.py +18 -33
  45. sglang/test/runners.py +2 -2
  46. sglang/test/test_layernorm.py +53 -1
  47. sglang/version.py +1 -1
  48. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
  49. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
  50. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
  51. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
  from transformers import GPTBigCodeConfig
  from vllm.config import CacheConfig, LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import get_act_fn
  from vllm.model_executor.layers.linear import (
  ColumnParallelLinear,
  QKVParallelLinear,
@@ -33,9 +32,9 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
  if lora_config:
  self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
  input_metadata: InputMetadata,
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, input_metadata)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  params_dict = dict(self.named_parameters(remove_duplicate=False))
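Note: the gpt_bigcode.py hunks above show the pattern repeated across most model files in this release: the per-model Sampler is removed, and forward() returns the logits-processor output directly instead of a (sample_output, logits_output) tuple, with sampling presumably handled by the caller (consistent with the sampler.py and model_runner.py changes listed above). A minimal, self-contained sketch of the new calling convention, using toy stand-in classes rather than the real sglang/vLLM APIs:

import torch
import torch.nn as nn


class ToyLogitsProcessor(nn.Module):
    """Toy stand-in for sglang's LogitsProcessor; it only projects onto the vocab."""

    def forward(self, hidden_states: torch.Tensor, lm_head_weight: torch.Tensor) -> torch.Tensor:
        return hidden_states @ lm_head_weight.t()


class ToyCausalLM(nn.Module):
    def __init__(self, hidden_size: int = 16, vocab_size: int = 32) -> None:
        super().__init__()
        self.backbone = nn.Linear(hidden_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.logits_processor = ToyLogitsProcessor()
        # Before this release the model also built a Sampler here and returned
        # (sample_output, logits_output) from forward().

    @torch.no_grad()
    def forward(self, hidden_in: torch.Tensor) -> torch.Tensor:
        hidden_states = self.backbone(hidden_in)
        # New convention: return the logits-processor output directly; sampling
        # happens outside the model.
        return self.logits_processor(hidden_states, self.lm_head.weight)


if __name__ == "__main__":
    model = ToyCausalLM()
    logits = model(torch.randn(2, 16))
    next_tokens = torch.argmax(logits, dim=-1)  # caller-side "sampling" (greedy here)
    print(next_tokens)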
sglang/srt/models/grok.py CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -274,9 +273,9 @@ class Grok1Model(nn.Module):
  ) -> torch.Tensor:
  if input_embeds is None:
  hidden_states = self.embed_tokens(input_ids)
+ hidden_states.mul_(self.config.embedding_multiplier_scale)
  else:
  hidden_states = input_embeds
- hidden_states.mul_(self.config.embedding_multiplier_scale)

  for i in range(len(self.layers)):
  hidden_states = self.layers[i](positions, hidden_states, input_metadata)
@@ -285,7 +284,7 @@ class Grok1Model(nn.Module):
  return hidden_states


- class Grok1ModelForCausalLM(nn.Module):
+ class Grok1ForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
@@ -298,7 +297,6 @@ class Grok1ModelForCausalLM(nn.Module):
  self.model = Grok1Model(config, quant_config=quant_config)
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  # Monkey patch _prepare_weights to load pre-sharded weights
  setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ModelForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
@@ -419,4 +415,10 @@ def _prepare_presharded_weights(
  return hf_folder, hf_weights_files, use_safetensors


- EntryClass = Grok1ModelForCausalLM
+ class Grok1ModelForCausalLM(Grok1ForCausalLM):
+ """An alias for backward-compatbility."""
+
+ pass
+
+
+ EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
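Besides the sampler removal, the grok.py diff makes two behavioral changes: the embedding multiplier is now applied only to embeddings looked up from the token table (caller-supplied input_embeds are no longer rescaled), and the main class is renamed to Grok1ForCausalLM with Grok1ModelForCausalLM kept as a backward-compatible alias, both registered as entry classes. A hedged sketch of the same structure with toy classes and a placeholder scale value (not the real Grok configuration):

from typing import List, Optional

import torch
import torch.nn as nn

EMBEDDING_MULTIPLIER_SCALE = 2.0  # placeholder value for illustration only


class ToyGrokModel(nn.Module):
    def __init__(self, vocab_size: int = 100, hidden_size: int = 8) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
            # Scale only embeddings we just looked up; precomputed input_embeds
            # are passed through unchanged (the old code scaled both).
            hidden_states.mul_(EMBEDDING_MULTIPLIER_SCALE)
        else:
            hidden_states = input_embeds
        return hidden_states


class ToyGrokForCausalLM(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = ToyGrokModel()


class ToyGrokModelForCausalLM(ToyGrokForCausalLM):
    """Backward-compatible alias, mirroring Grok1ModelForCausalLM above."""


# Registering both names keeps older configs that reference the old class working.
EntryClass: List[type] = [ToyGrokForCausalLM, ToyGrokModelForCausalLM]

if __name__ == "__main__":
    with torch.no_grad():
        model = ToyGrokForCausalLM()
        out = model.model(torch.tensor([[1, 2, 3]]))
        print(out.shape)  # torch.Size([1, 3, 8])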
sglang/srt/models/internlm2.py CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
  self.model = InternLM2Model(config, quant_config)
  self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.output.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/llama2.py CHANGED
@@ -39,9 +39,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
- from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -303,7 +302,6 @@ class LlamaForCausalLM(nn.Module):
  self.model = LlamaModel(config, quant_config=quant_config)
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -312,13 +310,11 @@ class LlamaForCausalLM(nn.Module):
  positions: torch.Tensor,
  input_metadata: InputMetadata,
  input_embeds: torch.Tensor = None,
- ) -> LogitsProcessorOutput:
+ ) -> LogitProcessorOutput:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def get_module_name(self, name):
  stacked_params_mapping = [
@@ -361,6 +357,9 @@ class LlamaForCausalLM(nn.Module):
  # Models trained using ColossalAI may include these tensors in
  # the checkpoint. Skip them.
  return
+ if name.startswith("model.vision_tower") and name not in params_dict:
+ return
+
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
@@ -368,8 +367,6 @@ class LlamaForCausalLM(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
- if name.startswith("model.vision_tower") and name not in params_dict:
- continue
  param = params_dict[name]
  weight_loader = param.weight_loader
  weight_loader(param, loaded_weight, shard_id)
@@ -378,8 +375,6 @@ class LlamaForCausalLM(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  return
- if name.startswith("model.vision_tower") and name not in params_dict:
- return
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)
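The remaining llama2.py hunks (and similarly llama_embedding.py below) hoist the model.vision_tower skip out of the stacked-parameter loop, so vision-tower weights from multimodal checkpoints are filtered once per weight name instead of being re-checked in both branches. A rough, self-contained sketch of that loading structure with hypothetical parameter names (not the real LlamaForCausalLM loader):

from typing import Dict, Iterable, Tuple

import torch

# (fused param name, checkpoint shard name, shard id) -- illustrative subset
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def load_weights(
    params_dict: Dict[str, torch.Tensor],
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> None:
    for name, loaded_weight in weights:
        # Hoisted check: ignore vision-tower weights up front, before any
        # stacked-parameter matching is attempted.
        if name.startswith("model.vision_tower") and name not in params_dict:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            print(f"would load shard {shard_id} into {name.replace(weight_name, param_name)}")
            break
        else:
            print(f"would load {name} directly")


if __name__ == "__main__":
    params = {"model.layers.0.self_attn.qkv_proj": torch.zeros(1)}
    checkpoint = [
        ("model.vision_tower.embeddings.patch_embedding.weight", torch.zeros(1)),
        ("model.layers.0.self_attn.q_proj.weight", torch.zeros(1)),
        ("model.norm.weight", torch.zeros(1)),
    ]
    load_weights(params, checkpoint)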
sglang/srt/models/llama_classification.py CHANGED
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.logits_processor import LogitProcessorOutput
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
  from sglang.srt.models.llama2 import LlamaModel

@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
  (input_metadata.batch_size, self.config.classification_out_size)
  ).to(input_ids.device)

- return LogitsProcessorOutput(
+ return LogitProcessorOutput(
  next_token_logits=scores,
  next_token_logprobs=scores,
  normalized_prompt_logprobs=scores,
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
- if name.startswith("model.vision_tower") and name not in params_dict:
- continue
  param = params_dict[name]
  weight_loader = param.weight_loader
  weight_loader(param, loaded_weight, shard_id)
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
- if name.startswith("model.vision_tower") and name not in params_dict:
- continue
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)
sglang/srt/models/llama_embedding.py CHANGED
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
  # Models trained using ColossalAI may include these tensors in
  # the checkpoint. Skip them.
  return
+ if name.startswith("model.vision_tower") and name not in params_dict:
+ return
+
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
- if name.startswith("model.vision_tower") and name not in params_dict:
- continue
  param = params_dict[name]
  weight_loader = param.weight_loader
  weight_loader(param, loaded_weight, shard_id)
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  return
- if name.startswith("model.vision_tower") and name not in params_dict:
- return
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)
sglang/srt/models/llava.py CHANGED
@@ -28,7 +28,6 @@ from transformers (
  LlavaConfig,
  MistralConfig,
  Qwen2Config,
- SiglipVisionConfig,
  SiglipVisionModel,
  )
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -47,32 +46,19 @@ from sglang.srt.models.mistral import MistralForCausalLM
  from sglang.srt.models.qwen2 import Qwen2ForCausalLM


- class LlavaLlamaForCausalLM(nn.Module):
- def __init__(
+ class LlavaBaseForCausalLM(nn.Module):
+ def pad_input_ids(
  self,
- config: LlavaConfig,
- quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
- ) -> None:
- super().__init__()
- self.config = config
- self.vision_tower = None
- self.config.vision_config.hidden_size = config.mm_hidden_size
- self.config.text_config.hidden_size = config.hidden_size
- self.multi_modal_projector = LlavaMultiModalProjector(config)
- self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
- if "unpad" in getattr(config, "mm_patch_merge_type", ""):
- self.language_model.model.image_newline = nn.Parameter(
- torch.empty(config.text_config.hidden_size, dtype=torch.float16)
- )
-
- def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-
+ input_ids: List[int],
+ pad_value: List[int],
+ pixel_values: List,
+ image_sizes: List[List[int]],
+ ):
  # hardcode for spatial_unpad + anyres
- image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+ image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
  offset_list = []
- for image_s in image_size:
- if len(image_size) > 16:
+ for image_s in image_sizes:
+ if len(image_sizes) > 16:
  # 2x2 pooling with stride 2
  new_image_feature_len = (
  math.ceil(self.image_size / self.patch_size / 2) ** 2
@@ -153,17 +139,15 @@ class LlavaLlamaForCausalLM(nn.Module):
  if input_metadata.forward_mode == ForwardMode.EXTEND:
  bs = input_metadata.batch_size

- # Embed text input
+ # Embed text inputs
  input_embeds = self.language_model.model.embed_tokens(input_ids)
- # Embed vision input
- need_vision = (
- (positions[input_metadata.extend_start_loc] < self.image_feature_len)
- .cpu()
- .numpy()
+
+ # Whether the requests need vision inputs
+ max_image_offset = np.array(
+ [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
  )
- # FIXME: We need to substract the length of the system prompt
- has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
- need_vision = need_vision & has_pixel
+ start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+ need_vision = start_positions <= max_image_offset

  if need_vision.any():
  pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
@@ -332,31 +316,35 @@ class LlavaLlamaForCausalLM(nn.Module):
  new_image_features.append(image_feature)
  image_features = new_image_features

+ # Fill in the placeholder for the image
  extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+ prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
  pt = 0
  for i in range(bs):
  if not need_vision[i]:
  continue

  start_idx = extend_start_loc_cpu[i]
- pad_dim = image_features[pt].shape[-1] # 576, 4096
- dim = input_embeds.shape[1]
- assert (
- pad_dim == dim
- ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
- # Fill in the placeholder for the image
- try:
- for j, image_off in enumerate(image_offsets[i]):
- # print("actual image_features length: ", image_features[pt][j].shape[0])
- pad_len = image_features[pt][j].shape[0]
- input_embeds[
- start_idx + image_off : start_idx + image_off + pad_len
- ] = image_features[pt][j]
- except RuntimeError as e:
- print(f"RuntimeError in llava image encoding: {e}")
- print(image_features[pt].shape)
- print(input_embeds.shape)
- print(start_idx, image_offsets[i])
+ prefix_len = prefix_lens_cpu[i]
+
+ # Multiple images
+ for j, image_offset in enumerate(image_offsets[i]):
+ if image_offset < prefix_len:
+ continue
+
+ tmp_image_feature = image_features[pt][j]
+ pad_len = tmp_image_feature.shape[0]
+
+ left_idx = start_idx + (image_offset - prefix_len)
+ right_idx = start_idx + (image_offset - prefix_len) + pad_len
+ try:
+ input_embeds[left_idx:right_idx] = tmp_image_feature
+ except RuntimeError as e:
+ print(f"RuntimeError in image encoding: {e}")
+ print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+ print(
+ f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+ )
  pt += 1

  return self.language_model(
@@ -366,8 +354,9 @@ class LlavaLlamaForCausalLM(nn.Module):
  return self.language_model(input_ids, positions, input_metadata)

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- # load clip vision model by cfg['mm_vision_tower']:
- # huggingface_name or path_of_clip_relative_to_llava_model_dir
+ # Load clip vision model by cfg['mm_vision_tower']:
+ # huggingface_name or path_of_clip_relative_to_llava_model_dir
+ # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
  vision_path = self.config.mm_vision_tower
  if "clip" in vision_path:
  self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,21 +411,41 @@ class LlavaLlamaForCausalLM(nn.Module):
  # load language model
  self.language_model.load_weights(weights)

- monkey_path_clip_vision_embed_forward()
-
  @property
  def num_patches_per_side(self):
  return self.image_size // self.patch_size


- class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
  def __init__(
  self,
  config: LlavaConfig,
  quant_config: Optional[QuantizationConfig] = None,
  cache_config: Optional[CacheConfig] = None,
  ) -> None:
- super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+ super().__init__()
+
+ self.config = config
+ self.vision_tower = None
+ self.config.vision_config.hidden_size = config.mm_hidden_size
+ self.config.text_config.hidden_size = config.hidden_size
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
+ self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+ self.language_model.model.image_newline = nn.Parameter(
+ torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+ )
+
+
+ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
+ def __init__(
+ self,
+ config: LlavaConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ cache_config: Optional[CacheConfig] = None,
+ ) -> None:
+ super().__init__()
+
  self.config = config
  self.vision_tower = None
  if getattr(self.config, "vision_config", None) is None:
@@ -462,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
  )


- class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
  def __init__(
  self,
  config: LlavaConfig,
  quant_config: Optional[QuantizationConfig] = None,
  cache_config: Optional[CacheConfig] = None,
  ) -> None:
- super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+ super().__init__()
+
  self.config = config
  self.vision_tower = None
  if getattr(self.config, "vision_config", None) is None:
@@ -495,36 +505,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
  )


- first_call = True
-
-
- def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
- batch_size = pixel_values.shape[0]
-
- # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
- global first_call
- if first_call:
- self.patch_embedding.cpu().float()
- first_call = False
- pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
- patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
- class_embeds = self.class_embedding.expand(batch_size, 1, -1)
- embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
- embeddings = embeddings + self.position_embedding(self.position_ids)
- return embeddings
-
-
- def monkey_path_clip_vision_embed_forward():
- import transformers
-
- setattr(
- transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
- "forward",
- clip_vision_embed_forward,
- )
-
-
  EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
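The llava.py changes above restructure the multimodal path: a shared LlavaBaseForCausalLM now holds the common logic, need_vision is decided per request by comparing the request's first extend position with its largest image offset, and image features are written into input_embeds at offsets shifted by the length of the cached prefix. A small self-contained sketch of that placeholder-filling index math with made-up tensors (the real code works on the batch's input_embeds and per-request image_offsets):

import torch

hidden_size = 4
start_idx = 0                                # extend_start_loc for this request
prefix_len = 3                               # tokens already served from the prefix cache
input_embeds = torch.zeros(10, hidden_size)  # embeddings of the non-cached (extend) tokens
image_offsets = [1, 5]                       # placeholder positions in the full sequence
image_features = [                           # one feature block per image
    torch.ones(2, hidden_size),
    2 * torch.ones(2, hidden_size),
]

for j, image_offset in enumerate(image_offsets):
    if image_offset < prefix_len:
        # This image's placeholder sits inside the cached prefix, so its
        # embeddings were already materialized in an earlier forward pass.
        continue
    pad_len = image_features[j].shape[0]
    left_idx = start_idx + (image_offset - prefix_len)
    right_idx = left_idx + pad_len
    input_embeds[left_idx:right_idx] = image_features[j]

print(input_embeds)  # rows 2-3 now hold the second image's features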