sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0

sglang/srt/models/llama2.py
@@ -1,34 +1,36 @@
  # Adapted from
- # https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
  """Inference-only LLaMA model compatible with HuggingFace weights."""
- from typing import Any, Dict, List, Optional, Tuple
+
+ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
- from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.router.model_runner import InputMetadata
+ import tqdm
  from torch import nn
  from transformers import LlamaConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
  from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
-     LinearMethodBase,
      MergedColumnParallelLinear,
      QKVParallelLinear,
      RowParallelLinear,
  )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from vllm.model_executor.parallel_utils.parallel_state import (
-     get_tensor_model_parallel_world_size,
- )
- from vllm.model_executor.weight_utils import (
-     default_weight_loader,
-     hf_model_weights_iterator,
- )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata


  class LlamaMLP(nn.Module):
@@ -37,17 +39,20 @@ class LlamaMLP(nn.Module):
          hidden_size: int,
          intermediate_size: int,
          hidden_act: str,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.gate_up_proj = MergedColumnParallelLinear(
              hidden_size,
              [intermediate_size] * 2,
              bias=False,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )
          self.down_proj = RowParallelLinear(
-             intermediate_size, hidden_size, bias=False, linear_method=linear_method
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
          )
          if hidden_act != "silu":
              raise ValueError(
@@ -72,8 +77,9 @@ class LlamaAttention(nn.Module):
          layer_id: int = 0,
          rope_theta: float = 10000,
          rope_scaling: Optional[Dict[str, Any]] = None,
+         rope_is_neox_style: bool = True,
          max_position_embeddings: int = 8192,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = hidden_size
@@ -104,13 +110,13 @@ class LlamaAttention(nn.Module):
              self.total_num_heads,
              self.total_num_kv_heads,
              bias=False,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              hidden_size,
              bias=False,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )

          self.rotary_emb = get_rope(
@@ -119,6 +125,7 @@ class LlamaAttention(nn.Module):
              max_position=max_position_embeddings,
              base=rope_theta,
              rope_scaling=rope_scaling,
+             is_neox_style=rope_is_neox_style,
          )
          self.attn = RadixAttention(
              self.num_heads,
@@ -147,12 +154,19 @@ class LlamaDecoderLayer(nn.Module):
          self,
          config: LlamaConfig,
          layer_id: int = 0,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = config.hidden_size
          rope_theta = getattr(config, "rope_theta", 10000)
          rope_scaling = getattr(config, "rope_scaling", None)
+         if rope_scaling is not None and getattr(
+             config, "original_max_position_embeddings", None
+         ):
+             rope_scaling[
+                 "original_max_position_embeddings"
+             ] = config.original_max_position_embeddings
+         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
          max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
          self.self_attn = LlamaAttention(
              hidden_size=self.hidden_size,
@@ -161,14 +175,15 @@ class LlamaDecoderLayer(nn.Module):
              layer_id=layer_id,
              rope_theta=rope_theta,
              rope_scaling=rope_scaling,
+             rope_is_neox_style=rope_is_neox_style,
              max_position_embeddings=max_position_embeddings,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )
          self.mlp = LlamaMLP(
              hidden_size=self.hidden_size,
              intermediate_size=config.intermediate_size,
              hidden_act=config.hidden_act,
-             linear_method=linear_method,
+             quant_config=quant_config,
          )
          self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = RMSNorm(
@@ -204,7 +219,7 @@ class LlamaModel(nn.Module):
      def __init__(
          self,
          config: LlamaConfig,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -216,7 +231,7 @@ class LlamaModel(nn.Module):
          )
          self.layers = nn.ModuleList(
              [
-                 LlamaDecoderLayer(config, i, linear_method)
+                 LlamaDecoderLayer(config, i, quant_config=quant_config)
                  for i in range(config.num_hidden_layers)
              ]
          )
@@ -250,12 +265,13 @@ class LlamaForCausalLM(nn.Module):
      def __init__(
          self,
          config: LlamaConfig,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
-         self.linear_method = linear_method
-         self.model = LlamaModel(config, linear_method)
+         self.quant_config = quant_config
+         self.model = LlamaModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

@@ -271,13 +287,7 @@ class LlamaForCausalLM(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -287,9 +297,9 @@ class LlamaForCausalLM(nn.Module):
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         if get_tensor_model_parallel_rank() == 0:
+             weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
                  continue
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:

sglang/srt/models/llama_classification.py (new file)
@@ -0,0 +1,107 @@
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ import tqdm
+ from torch import nn
+ from transformers import LlamaConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import get_tensor_model_parallel_rank
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.logits_processor import LogitProcessorOutput
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+ from sglang.srt.models.llama2 import LlamaModel
+
+
+ class LlamaForClassification(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = LlamaModel(config, quant_config=quant_config)
+
+         self.classification_head = nn.Linear(
+             config.hidden_size, config.classification_out_size
+         )
+         self.eos_token_id = config.eos_token_id
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         is_eos_token = input_ids == self.eos_token_id
+         hidden_states = hidden_states[is_eos_token]
+         scores = self.classification_head(hidden_states)
+
+         if scores.shape[0] != input_metadata.batch_size:
+             print("Warning: the EOS tokens are missing in some sentences.")
+             scores = torch.ones(
+                 (input_metadata.batch_size, self.config.classification_out_size)
+             ).to(input_ids.device)
+
+         return LogitProcessorOutput(
+             next_token_logits=scores,
+             next_token_logprobs=scores,
+             normalized_prompt_logprobs=scores,
+             prefill_token_logprobs=torch.ones_like(input_ids),
+             prefill_top_logprobs=None,
+             decode_top_logprobs=None,
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         if get_tensor_model_parallel_rank() == 0:
+             weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+             if "lm_head" in name:
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = LlamaForClassification
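
The forward pass in this new file classifies each sequence by gathering the hidden state at its EOS token with a boolean mask before applying the linear head. A small self-contained illustration of that gather (the shapes and token id below are made up for the example):

    import torch

    eos_token_id = 2                                # illustrative value
    input_ids = torch.tensor([5, 7, 2, 9, 4, 2])    # flattened tokens of two sequences
    hidden_states = torch.randn(6, 16)              # one hidden vector per token

    is_eos_token = input_ids == eos_token_id        # boolean mask, shape (6,)
    eos_hidden = hidden_states[is_eos_token]        # (2, 16): one row per sequence
    scores = torch.nn.Linear(16, 3)(eos_hidden)     # (2, 3) classification scores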

sglang/srt/models/llava.py
@@ -1,32 +1,40 @@
  """Inference-only LLaVa model compatible with HuggingFace weights."""

- from typing import List, Optional
+ from typing import Iterable, List, Optional, Tuple

  import numpy as np
  import torch
- from sglang.srt.managers.router.infer_batch import ForwardMode
- from sglang.srt.managers.router.model_runner import InputMetadata
+ from torch import nn
+ from transformers import (
+     CLIPVisionConfig,
+     CLIPVisionModel,
+     LlavaConfig,
+     MistralConfig,
+     Qwen2Config,
+ )
+ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+ from vllm.config import CacheConfig
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.managers.controller.infer_batch import ForwardMode
+ from sglang.srt.managers.controller.model_runner import InputMetadata
  from sglang.srt.mm_utils import (
      get_anyres_image_grid_shape,
      unpad_image,
      unpad_image_shape,
  )
  from sglang.srt.models.llama2 import LlamaForCausalLM
- from torch import nn
- from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
- from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
- from vllm.model_executor.layers.linear import LinearMethodBase
- from vllm.model_executor.weight_utils import (
-     default_weight_loader,
-     hf_model_weights_iterator,
- )
+ from sglang.srt.models.mistral import MistralForCausalLM
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM


  class LlavaLlamaForCausalLM(nn.Module):
      def __init__(
          self,
          config: LlavaConfig,
-         linear_method: Optional[LinearMethodBase] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -34,7 +42,7 @@ class LlavaLlamaForCausalLM(nn.Module):
          self.config.vision_config.hidden_size = config.mm_hidden_size
          self.config.text_config.hidden_size = config.hidden_size
          self.multi_modal_projector = LlavaMultiModalProjector(config)
-         self.language_model = LlamaForCausalLM(config, linear_method)
+         self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
          if "unpad" in getattr(config, "mm_patch_merge_type", ""):
              self.language_model.model.image_newline = nn.Parameter(
                  torch.empty(config.text_config.hidden_size, dtype=torch.float16)
@@ -235,13 +243,7 @@ class LlavaLlamaForCausalLM(nn.Module):
          elif input_metadata.forward_mode == ForwardMode.DECODE:
              return self.language_model(input_ids, positions, input_metadata)

-     def load_weights(
-         self,
-         model_name_or_path: str,
-         cache_dir: Optional[str] = None,
-         load_format: str = "auto",
-         revision: Optional[str] = None,
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          # load clip vision model by cfg['mm_vision_tower']:
          # huggingface_name or path_of_clip_relative_to_llava_model_dir
          vision_path = self.config.mm_vision_tower
@@ -274,9 +276,8 @@ class LlavaLlamaForCausalLM(nn.Module):
              "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
          }
          params_dict = dict(self.named_parameters())
-         for name, loaded_weight in hf_model_weights_iterator(
-             model_name_or_path, cache_dir, load_format, revision
-         ):
+         weights = list(weights)
+         for name, loaded_weight in weights:
              # FIXME: why projector weights read two times?
              if "projector" in name or "vision_tower" in name:
                  for weight_name, param_name in projector_weights.items():
@@ -287,9 +288,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                  weight_loader(param, loaded_weight)

          # load language model
-         self.language_model.load_weights(
-             model_name_or_path, cache_dir, load_format, revision
-         )
+         self.language_model.load_weights(weights)

          monkey_path_clip_vision_embed_forward()

@@ -298,6 +297,72 @@ class LlavaLlamaForCausalLM(nn.Module):
          return self.image_size // self.patch_size


+ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+     def __init__(
+         self,
+         config: LlavaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+         self.config = config
+         self.vision_tower = None
+         if getattr(self.config, "vision_config", None) is None:
+             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+         if getattr(self.config, "text_config", None) is None:
+             self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+         self.config.vision_config.hidden_size = config.mm_hidden_size
+         self.config.text_config.hidden_size = config.hidden_size
+
+         if getattr(self.config, "projector_hidden_act", None) is None:
+             self.config.projector_hidden_act = "gelu"
+
+         if getattr(self.config, "image_token_index", None) is None:
+             self.config.image_token_index = 151646
+
+         self.multi_modal_projector = LlavaMultiModalProjector(config)
+         self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+             self.language_model.model.image_newline = nn.Parameter(
+                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+             )
+
+
+ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+     def __init__(
+         self,
+         config: LlavaConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__(config, quant_config=quant_config, cache_config=cache_config)
+         self.config = config
+         self.vision_tower = None
+         if getattr(self.config, "vision_config", None) is None:
+             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+         if getattr(self.config, "text_config", None) is None:
+             self.config.text_config = MistralConfig(self.config._name_or_path)
+
+         self.config.vision_config.hidden_size = config.mm_hidden_size
+         self.config.text_config.hidden_size = config.hidden_size
+
+         if getattr(self.config, "projector_hidden_act", None) is None:
+             self.config.projector_hidden_act = "gelu"
+
+         if getattr(self.config, "image_token_index", None) is None:
+             self.config.image_token_index = 32000
+
+         self.multi_modal_projector = LlavaMultiModalProjector(config)
+         self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+             self.language_model.model.image_newline = nn.Parameter(
+                 torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+             )
+
+
  first_call = True


@@ -330,4 +395,4 @@ def monkey_path_clip_vision_embed_forward():
      )


- EntryClass = LlavaLlamaForCausalLM
+ EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
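
EntryClass in llava.py now exports a list of model classes rather than a single class, so anything consuming the attribute has to accept both shapes. A hedged sketch of how a consumer could normalize it (the helper below is illustrative, not sglang's actual model loader):

    def entry_classes(module):
        # Accept both the old single-class form and the new list form of EntryClass.
        entry = getattr(module, "EntryClass", None)
        if entry is None:
            return []
        return list(entry) if isinstance(entry, (list, tuple)) else [entry]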