sglang 0.2.14.post1__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/interpreter.py +3 -0
  4. sglang/lang/ir.py +5 -0
  5. sglang/launch_server_llavavid.py +26 -0
  6. sglang/srt/configs/__init__.py +5 -0
  7. sglang/srt/configs/exaone.py +195 -0
  8. sglang/srt/constrained/fsm_cache.py +1 -1
  9. sglang/srt/conversation.py +24 -2
  10. sglang/srt/hf_transformers_utils.py +11 -160
  11. sglang/srt/layers/activation.py +10 -4
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/layernorm.py +47 -1
  14. sglang/srt/layers/logits_processor.py +4 -4
  15. sglang/srt/layers/sampler.py +69 -16
  16. sglang/srt/managers/controller_multi.py +5 -5
  17. sglang/srt/managers/controller_single.py +5 -5
  18. sglang/srt/managers/io_struct.py +11 -5
  19. sglang/srt/managers/schedule_batch.py +25 -13
  20. sglang/srt/managers/tokenizer_manager.py +76 -63
  21. sglang/srt/managers/tp_worker.py +47 -36
  22. sglang/srt/model_config.py +3 -3
  23. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  24. sglang/srt/model_executor/forward_batch_info.py +78 -43
  25. sglang/srt/model_executor/model_runner.py +29 -18
  26. sglang/srt/models/chatglm.py +5 -13
  27. sglang/srt/models/commandr.py +5 -1
  28. sglang/srt/models/dbrx.py +5 -1
  29. sglang/srt/models/deepseek.py +5 -1
  30. sglang/srt/models/deepseek_v2.py +57 -25
  31. sglang/srt/models/exaone.py +399 -0
  32. sglang/srt/models/gemma.py +7 -3
  33. sglang/srt/models/gemma2.py +6 -52
  34. sglang/srt/models/gpt_bigcode.py +5 -1
  35. sglang/srt/models/grok.py +14 -4
  36. sglang/srt/models/internlm2.py +5 -1
  37. sglang/srt/models/llama2.py +10 -7
  38. sglang/srt/models/llama_classification.py +2 -6
  39. sglang/srt/models/llama_embedding.py +3 -4
  40. sglang/srt/models/llava.py +69 -91
  41. sglang/srt/models/llavavid.py +40 -86
  42. sglang/srt/models/minicpm.py +5 -1
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +9 -6
  47. sglang/srt/models/qwen2_moe.py +12 -33
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/models/yivl.py +2 -7
  50. sglang/srt/openai_api/adapter.py +16 -1
  51. sglang/srt/openai_api/protocol.py +5 -5
  52. sglang/srt/sampling/sampling_batch_info.py +79 -6
  53. sglang/srt/server.py +9 -9
  54. sglang/srt/utils.py +18 -36
  55. sglang/test/runners.py +2 -2
  56. sglang/test/test_layernorm.py +53 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/METADATA +8 -8
  59. sglang-0.2.15.dist-info/RECORD +118 -0
  60. sglang-0.2.14.post1.dist-info/RECORD +0 -114
  61. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2.py CHANGED
@@ -22,11 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
-
-# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,8 +34,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -50,52 +47,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1


-class GemmaRMSNorm(CustomOp):
-    """RMS normalization for Gemma.
-
-    Two differences from the above RMSNorm:
-        1. x * (1 + w) instead of x * w.
-        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__()
-        self.weight = nn.Parameter(torch.zeros(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward_native(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """PyTorch-native implementation equivalent to forward()."""
-        orig_dtype = x.dtype
-        if residual is not None:
-            x = x + residual
-            residual = x
-
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
-        return self.forward_native(x, residual)
-
-
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

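Note (illustration, not part of the package diff): the GemmaRMSNorm class removed above now ships in sglang/srt/layers/layernorm.py (see the +47 -1 change in the file list and the import added in the previous hunk). A minimal standalone sketch of the two differences its docstring describes, using plain PyTorch and illustrative function names:

    import torch

    def llama_style_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Normalize in fp32, cast back to the original dtype, then scale by w.
        orig_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x.to(orig_dtype) * w

    def gemma_style_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Scale by (1 + w) and cast back only after the multiply, as in forward_native() above.
        orig_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        x = x * (1.0 + w.float())
        return x.to(orig_dtype)

    x = torch.randn(2, 8, dtype=torch.float16)
    w = torch.zeros(8)  # Gemma stores the weight as an offset from 1.0
    print(llama_style_rms_norm(x, 1.0 + w))  # differs from the line below only in cast/rounding order
    print(gemma_style_rms_norm(x, w))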
@@ -396,6 +347,7 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -406,9 +358,11 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
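Note (illustration, not part of the package diff): the same refactor recurs in most model files below (gpt_bigcode, grok, internlm2, llama2, ...). Each *ForCausalLM now owns a Sampler from sglang.srt.layers.sampler, and forward() returns a (sample_output, logits_output) pair instead of the bare logits-processor output, so sampling happens inside the model's forward pass. A self-contained sketch of the new return shape; LogitsOutput and sample() here are illustrative stand-ins, not sglang's actual classes:

    from dataclasses import dataclass

    import torch

    @dataclass
    class LogitsOutput:  # stand-in for LogitsProcessorOutput
        next_token_logits: torch.Tensor

    def sample(out: LogitsOutput, temperature: float = 0.7) -> torch.Tensor:
        # Stand-in for Sampler()(logits_output, input_metadata.sampling_info).
        probs = torch.softmax(out.next_token_logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    logits_output = LogitsOutput(next_token_logits=torch.randn(4, 32000))
    sample_output = sample(logits_output)
    result = (sample_output, logits_output)  # what forward() now returns; previously just logits_output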
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -273,9 +274,9 @@ class Grok1Model(nn.Module):
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
+            hidden_states.mul_(self.config.embedding_multiplier_scale)
         else:
             hidden_states = input_embeds
-        hidden_states.mul_(self.config.embedding_multiplier_scale)

         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
@@ -284,7 +285,7 @@ class Grok1Model(nn.Module):
         return hidden_states


-class Grok1ModelForCausalLM(nn.Module):
+class Grok1ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -297,6 +298,7 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -313,9 +315,11 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -415,4 +419,10 @@ def _prepare_presharded_weights(
     return hf_folder, hf_weights_files, use_safetensors


-EntryClass = Grok1ModelForCausalLM
+class Grok1ModelForCausalLM(Grok1ForCausalLM):
+    """An alias for backward-compatbility."""
+
+    pass
+
+
+EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
sglang/srt/models/internlm2.py CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/llama2.py CHANGED
@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> LogitProcessorOutput:
+    ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def get_module_name(self, name):
         stacked_params_mapping = [
@@ -357,6 +361,9 @@ class LlamaForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -364,8 +371,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -374,8 +379,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/llama_classification.py CHANGED
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel

@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)

-        return LogitProcessorOutput(
+        return LogitsProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/llama_embedding.py CHANGED
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/llava.py CHANGED
@@ -28,7 +28,6 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
-    SiglipVisionConfig,
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -47,32 +46,19 @@ from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


-class LlavaLlamaForCausalLM(nn.Module):
-    def __init__(
+class LlavaBaseForCausalLM(nn.Module):
+    def pad_input_ids(
         self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
-        for image_s in image_size:
-            if len(image_size) > 16:
+        for image_s in image_sizes:
+            if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
                     math.ceil(self.image_size / self.patch_size / 2) ** 2
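Note (illustration, not part of the package diff): in the many-image branch shown above, each image contributes ceil(image_size / patch_size / 2) ** 2 feature tokens after the 2x2 pooling. With the common CLIP ViT-L/336 geometry (336-px images, 14-px patches; assumed numbers, not values taken from this diff):

    import math

    image_size, patch_size = 336, 14  # assumed CLIP ViT-L/336 geometry
    pooled_len = math.ceil(image_size / patch_size / 2) ** 2
    unpooled_len = (image_size // patch_size) ** 2
    print(pooled_len, unpooled_len)  # 144 vs 576 tokens per image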
@@ -153,17 +139,15 @@ class LlavaLlamaForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size

-            # Embed text input
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
+
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
             )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset

             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
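Note (illustration, not part of the package diff): the rewritten check compares each request's first extend position with the offset of its last image placeholder, so under prefix caching a request whose uncached segment starts after all of its images skips the vision tower, and a request without images gets the -1 sentinel. A small worked sketch with made-up numbers:

    import numpy as np

    # Per-request offsets of image placeholders in the full prompt (empty list = no image).
    image_offsets = [[5], [3, 120], []]
    # Position where each request's uncached (extend) segment starts.
    start_positions = np.array([0, 200, 0])

    max_image_offset = np.array([max(offs) if offs else -1 for offs in image_offsets])
    need_vision = start_positions <= max_image_offset
    print(need_vision)  # [ True False False]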
@@ -332,31 +316,35 @@ class LlavaLlamaForCausalLM(nn.Module):
                     new_image_features.append(image_feature)
                 image_features = new_image_features

+            # Fill in the placeholder for the image
             extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
             pt = 0
             for i in range(bs):
                 if not need_vision[i]:
                     continue

                 start_idx = extend_start_loc_cpu[i]
-                pad_dim = image_features[pt].shape[-1]  # 576, 4096
-                dim = input_embeds.shape[1]
-                assert (
-                    pad_dim == dim
-                ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                # Fill in the placeholder for the image
-                try:
-                    for j, image_off in enumerate(image_offsets[i]):
-                        # print("actual image_features length: ", image_features[pt][j].shape[0])
-                        pad_len = image_features[pt][j].shape[0]
-                        input_embeds[
-                            start_idx + image_off : start_idx + image_off + pad_len
-                        ] = image_features[pt][j]
-                except RuntimeError as e:
-                    print(f"RuntimeError in llava image encoding: {e}")
-                    print(image_features[pt].shape)
-                    print(input_embeds.shape)
-                    print(start_idx, image_offsets[i])
+                prefix_len = prefix_lens_cpu[i]
+
+                # Multiple images
+                for j, image_offset in enumerate(image_offsets[i]):
+                    if image_offset < prefix_len:
+                        continue
+
+                    tmp_image_feature = image_features[pt][j]
+                    pad_len = tmp_image_feature.shape[0]
+
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                    try:
+                        input_embeds[left_idx:right_idx] = tmp_image_feature
+                    except RuntimeError as e:
+                        print(f"RuntimeError in image encoding: {e}")
+                        print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                        print(
+                            f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                        )
                 pt += 1

             return self.language_model(
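Note (illustration, not part of the package diff): the new fill-in loop works in extend-segment coordinates: an image placed at image_offset in the full prompt, with prefix_len tokens already served from the cache, lands at start_idx + (image_offset - prefix_len) in the flattened input_embeds, and images that fall entirely inside the cached prefix are skipped. A toy sketch of the index arithmetic with illustrative numbers:

    import torch

    start_idx = 7        # where this request's extend segment begins in input_embeds
    prefix_len = 10      # tokens of this request already covered by the cache
    image_offset = 12    # placeholder position of the image in the full prompt
    pad_len = 4          # number of feature rows produced for this image

    input_embeds = torch.zeros(32, 8)           # flattened extend tokens of the batch
    tmp_image_feature = torch.ones(pad_len, 8)  # stand-in for image_features[pt][j]

    left_idx = start_idx + (image_offset - prefix_len)
    right_idx = left_idx + pad_len
    input_embeds[left_idx:right_idx] = tmp_image_feature
    print(left_idx, right_idx)  # 9 13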
@@ -366,8 +354,9 @@ class LlavaLlamaForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
-        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # Load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         if "clip" in vision_path:
             self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,21 +411,41 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)

-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size


-class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -462,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
             )


-class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -495,36 +505,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
             )


-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]