sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/launch_server_llavavid.py +26 -0
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/hf_transformers_utils.py +0 -149
- sglang/srt/layers/activation.py +93 -11
- sglang/srt/layers/layernorm.py +47 -4
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +15 -68
- sglang/srt/managers/io_struct.py +5 -4
- sglang/srt/managers/schedule_batch.py +20 -25
- sglang/srt/managers/tokenizer_manager.py +74 -61
- sglang/srt/managers/tp_worker.py +49 -43
- sglang/srt/model_executor/cuda_graph_runner.py +17 -31
- sglang/srt/model_executor/forward_batch_info.py +9 -26
- sglang/srt/model_executor/model_runner.py +20 -17
- sglang/srt/models/chatglm.py +13 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/gemma.py +3 -7
- sglang/srt/models/gemma2.py +2 -56
- sglang/srt/models/gpt_bigcode.py +2 -6
- sglang/srt/models/grok.py +10 -8
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama2.py +6 -11
- sglang/srt/models/llama_classification.py +2 -6
- sglang/srt/models/llama_embedding.py +3 -4
- sglang/srt/models/llava.py +69 -91
- sglang/srt/models/llavavid.py +40 -86
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/mixtral.py +1 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +2 -5
- sglang/srt/models/qwen2.py +5 -10
- sglang/srt/models/qwen2_moe.py +21 -24
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/yivl.py +2 -7
- sglang/srt/openai_api/adapter.py +85 -4
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -74
- sglang/srt/sampling/sampling_params.py +4 -0
- sglang/srt/server.py +11 -4
- sglang/srt/utils.py +18 -33
- sglang/test/runners.py +2 -2
- sglang/test/test_layernorm.py +53 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -33,9 +32,9 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py
CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -274,9 +273,9 @@ class Grok1Model(nn.Module):
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
+            hidden_states.mul_(self.config.embedding_multiplier_scale)
         else:
             hidden_states = input_embeds
-        hidden_states.mul_(self.config.embedding_multiplier_scale)
 
         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
@@ -285,7 +284,7 @@ class Grok1Model(nn.Module):
         return hidden_states
 
 
-class Grok1ModelForCausalLM(nn.Module):
+class Grok1ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -298,7 +297,6 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -419,4 +415,10 @@ def _prepare_presharded_weights(
     return hf_folder, hf_weights_files, use_safetensors
 
 
-EntryClass = Grok1ModelForCausalLM
+class Grok1ModelForCausalLM(Grok1ForCausalLM):
+    """An alias for backward-compatbility."""
+
+    pass
+
+
+EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
sglang/srt/models/internlm2.py
CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/llama2.py
CHANGED
@@ -39,9 +39,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -303,7 +302,6 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -312,13 +310,11 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) ->
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def get_module_name(self, name):
         stacked_params_mapping = [
@@ -361,6 +357,9 @@ class LlamaForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -368,8 +367,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -378,8 +375,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/llama_classification.py
CHANGED
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
 
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)
 
-        return
+        return LogitProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/llama_embedding.py
CHANGED
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/llava.py
CHANGED
@@ -28,7 +28,6 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
-    SiglipVisionConfig,
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -47,32 +46,19 @@ from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 
 
-class LlavaLlamaForCausalLM(nn.Module):
-    def __init__(
+class LlavaBaseForCausalLM(nn.Module):
+    def pad_input_ids(
         self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
-        for image_s in image_size:
-            if len(image_size) > 16:
+        for image_s in image_sizes:
+            if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
                     math.ceil(self.image_size / self.patch_size / 2) ** 2
@@ -153,17 +139,15 @@ class LlavaLlamaForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size
 
-            # Embed text
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
-
-
-
-
-                .numpy()
+
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
             )
-
-
-            need_vision = need_vision & has_pixel
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset
 
             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
@@ -332,31 +316,35 @@ class LlavaLlamaForCausalLM(nn.Module):
                     new_image_features.append(image_feature)
                 image_features = new_image_features
 
+            # Fill in the placeholder for the image
            extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
            pt = 0
            for i in range(bs):
                if not need_vision[i]:
                    continue
 
                start_idx = extend_start_loc_cpu[i]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                prefix_len = prefix_lens_cpu[i]
+
+                # Multiple images
+                for j, image_offset in enumerate(image_offsets[i]):
+                    if image_offset < prefix_len:
+                        continue
+
+                    tmp_image_feature = image_features[pt][j]
+                    pad_len = tmp_image_feature.shape[0]
+
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                    try:
+                        input_embeds[left_idx:right_idx] = tmp_image_feature
+                    except RuntimeError as e:
+                        print(f"RuntimeError in image encoding: {e}")
+                        print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                        print(
+                            f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                        )
                pt += 1
 
            return self.language_model(
@@ -366,8 +354,9 @@ class LlavaLlamaForCausalLM(nn.Module):
            return self.language_model(input_ids, positions, input_metadata)
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        #
-        #
+        # Load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
        vision_path = self.config.mm_vision_tower
        if "clip" in vision_path:
            self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,21 +411,41 @@ class LlavaLlamaForCausalLM(nn.Module):
        # load language model
        self.language_model.load_weights(weights)
 
-        monkey_path_clip_vision_embed_forward()
-
    @property
    def num_patches_per_side(self):
        return self.image_size // self.patch_size
 
 
-class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
-        super().__init__(
+        super().__init__()
+
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+
        self.config = config
        self.vision_tower = None
        if getattr(self.config, "vision_config", None) is None:
@@ -462,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
            )
 
 
-class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
-        super().__init__(
+        super().__init__()
+
        self.config = config
        self.vision_tower = None
        if getattr(self.config, "vision_config", None) is None:
@@ -495,36 +505,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
            )
 
 
-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]