sglang 0.2.14.post1__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +2 -0
- sglang/bench_latency.py +39 -28
- sglang/lang/interpreter.py +3 -0
- sglang/lang/ir.py +5 -0
- sglang/launch_server_llavavid.py +26 -0
- sglang/srt/configs/__init__.py +5 -0
- sglang/srt/configs/exaone.py +195 -0
- sglang/srt/constrained/fsm_cache.py +1 -1
- sglang/srt/conversation.py +24 -2
- sglang/srt/hf_transformers_utils.py +11 -160
- sglang/srt/layers/activation.py +10 -4
- sglang/srt/layers/extend_attention.py +13 -8
- sglang/srt/layers/layernorm.py +47 -1
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +69 -16
- sglang/srt/managers/controller_multi.py +5 -5
- sglang/srt/managers/controller_single.py +5 -5
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/schedule_batch.py +25 -13
- sglang/srt/managers/tokenizer_manager.py +76 -63
- sglang/srt/managers/tp_worker.py +47 -36
- sglang/srt/model_config.py +3 -3
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +78 -43
- sglang/srt/model_executor/model_runner.py +29 -18
- sglang/srt/models/chatglm.py +5 -13
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +57 -25
- sglang/srt/models/exaone.py +399 -0
- sglang/srt/models/gemma.py +7 -3
- sglang/srt/models/gemma2.py +6 -52
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +14 -4
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +10 -7
- sglang/srt/models/llama_classification.py +2 -6
- sglang/srt/models/llama_embedding.py +3 -4
- sglang/srt/models/llava.py +69 -91
- sglang/srt/models/llavavid.py +40 -86
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +6 -2
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_moe.py +12 -33
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/models/yivl.py +2 -7
- sglang/srt/openai_api/adapter.py +16 -1
- sglang/srt/openai_api/protocol.py +5 -5
- sglang/srt/sampling/sampling_batch_info.py +79 -6
- sglang/srt/server.py +9 -9
- sglang/srt/utils.py +18 -36
- sglang/test/runners.py +2 -2
- sglang/test/test_layernorm.py +53 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/METADATA +8 -8
- sglang-0.2.15.dist-info/RECORD +118 -0
- sglang-0.2.14.post1.dist-info/RECORD +0 -114
- {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
- {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2.py
CHANGED
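The gemma2.py diff below shows the two changes that repeat across almost every model file in this release: the file-local GemmaRMSNorm is dropped in favor of the shared sglang.srt.layers.layernorm.GemmaRMSNorm, and each *ForCausalLM module now owns a Sampler, so forward() returns the sampled tokens together with the logits instead of the logits alone. A minimal sketch of that new forward() contract follows; LogitsOutput, StubSampler, and CausalLMPattern are illustrative stand-ins, not the real sglang classes.

# Sketch of the 0.2.15 forward() contract; stand-in classes, not sglang's.
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class LogitsOutput:  # stand-in for LogitsProcessorOutput
    next_token_logits: torch.Tensor


class StubSampler:  # stand-in for sglang.srt.layers.sampler.Sampler
    def __call__(self, logits_output: LogitsOutput, sampling_info: Optional[Any]):
        # Greedy argmax as a placeholder for the real sampling kernels.
        return torch.argmax(logits_output.next_token_logits, dim=-1)


class CausalLMPattern(torch.nn.Module):
    """The per-model change: the module owns a sampler and forward()
    returns (sample_output, logits_output) instead of bare logits."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        self.sampler = StubSampler()

    def forward(self, hidden_states: torch.Tensor, sampling_info: Optional[Any] = None):
        logits_output = LogitsOutput(self.lm_head(hidden_states))
        sample_output = self.sampler(logits_output, sampling_info)
        return sample_output, logits_output


# Usage: callers now unpack two values instead of one.
tokens, logits = CausalLMPattern(8, 32)(torch.randn(2, 8))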
@@ -22,11 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
-
-# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,8 +34,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -50,52 +47,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1


-class GemmaRMSNorm(CustomOp):
-    """RMS normalization for Gemma.
-
-    Two differences from the above RMSNorm:
-        1. x * (1 + w) instead of x * w.
-        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__()
-        self.weight = nn.Parameter(torch.zeros(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward_native(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """PyTorch-native implementation equivalent to forward()."""
-        orig_dtype = x.dtype
-        if residual is not None:
-            x = x + residual
-            residual = x
-
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
-        return self.forward_native(x, residual)
-
-
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

@@ -396,6 +347,7 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -406,9 +358,11 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py
CHANGED
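Besides adopting the Sampler pattern, the grok.py diff below renames Grok1ModelForCausalLM to Grok1ForCausalLM and keeps the old name alive as an empty subclass, registering both in EntryClass. A generic sketch of that backward-compatible rename (class names here are illustrative, not sglang's):

# Backward-compatible rename: the old class name survives as a thin alias.
class NewNameForCausalLM:
    """The renamed implementation."""

    def __init__(self, config=None):
        self.config = config


class OldNameForCausalLM(NewNameForCausalLM):
    """Alias so code and model registries that still use the old name keep working."""

    pass


# Register both names so either architecture string resolves to the same model.
EntryClass = [NewNameForCausalLM, OldNameForCausalLM]

assert isinstance(OldNameForCausalLM(), NewNameForCausalLM)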
@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -273,9 +274,9 @@ class Grok1Model(nn.Module):
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
+            hidden_states.mul_(self.config.embedding_multiplier_scale)
         else:
             hidden_states = input_embeds
-        hidden_states.mul_(self.config.embedding_multiplier_scale)

         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
@@ -284,7 +285,7 @@ class Grok1Model(nn.Module):
         return hidden_states


-class Grok1ModelForCausalLM(nn.Module):
+class Grok1ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -297,6 +298,7 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -313,9 +315,11 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -415,4 +419,10 @@ def _prepare_presharded_weights(
     return hf_folder, hf_weights_files, use_safetensors


-
+class Grok1ModelForCausalLM(Grok1ForCausalLM):
+    """An alias for backward-compatbility."""
+
+    pass
+
+
+EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
sglang/srt/models/internlm2.py
CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/llama2.py
CHANGED
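The llama2.py diff below also tightens the typing: forward() is now annotated as returning a LogitsProcessorOutput, and llama_classification.py (further down) constructs one directly. Only three of its fields are visible in this diff, so the following is a reduced, assumed sketch rather than the real definition in sglang/srt/layers/logits_processor.py.

# Reduced sketch of a structured logits result; the field names are the ones
# visible in this diff, the real LogitsProcessorOutput has more.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class LogitsProcessorOutputSketch:
    # Next-token logits for every request in the batch.
    next_token_logits: torch.Tensor
    # Next-token log-probabilities (classification heads just reuse the scores).
    next_token_logprobs: Optional[torch.Tensor] = None
    # Normalized prompt log-probability, used for scoring requests.
    normalized_prompt_logprobs: Optional[torch.Tensor] = None


# llama_classification.py fills all three fields with the same score tensor,
# so downstream code can treat the result like a decoder output.
scores = torch.zeros(2, 4)
out = LogitsProcessorOutputSketch(scores, scores, scores)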
@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) ->
+    ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def get_module_name(self, name):
         stacked_params_mapping = [
@@ -357,6 +361,9 @@ class LlamaForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -364,8 +371,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -374,8 +379,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/llama_classification.py
CHANGED
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel

@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
                 (input_metadata.batch_size, self.config.classification_out_size)
             ).to(input_ids.device)

-        return
+        return LogitsProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/llama_embedding.py
CHANGED
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
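llama2.py, llama_classification.py, and llama_embedding.py above all move the `model.vision_tower` filter to the top of the weight-loading path instead of repeating it next to every branch, so vision-tower tensors from multimodal checkpoints are skipped once per weight. A condensed sketch of that restructuring, with the actual per-parameter plumbing stubbed out:

# Condensed sketch of the hoisted skip; real loaders dispatch through
# per-parameter weight_loader callbacks rather than a plain copy_().
from typing import Dict, Iterable, Tuple

import torch


def load_weights(params_dict: Dict[str, torch.nn.Parameter],
                 weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    for name, loaded_weight in weights:
        # 0.2.15: check for vision-tower tensors once, up front ...
        if name.startswith("model.vision_tower") and name not in params_dict:
            continue
        # ... instead of re-checking inside every branch-specific skip below.
        if name.endswith(".bias") and name not in params_dict:
            continue  # skip loading extra bias for GPTQ models
        param = params_dict[name]
        param.data.copy_(loaded_weight)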
sglang/srt/models/llava.py
CHANGED
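The llava.py diff below collapses the three LLaVA variants onto a shared LlavaBaseForCausalLM: the input padding, image-feature injection, and weight-loading logic live once in the base class, while LlavaLlamaForCausalLM, LlavaQwenForCausalLM, and LlavaMistralForCausalLM keep only their own constructors. A toy sketch of that shape (everything except the pattern itself is an illustrative stand-in):

# Toy sketch of the base-class refactor; nn.Linear stands in for the real
# language models and projectors.
import torch.nn as nn


class MultimodalBaseForCausalLM(nn.Module):
    """Shared multimodal plumbing, written once."""

    def pad_input_ids(self, input_ids, pad_value, pixel_values, image_sizes):
        # Placeholder for the shared padding / feature-injection logic.
        return input_ids


class VariantAForCausalLM(MultimodalBaseForCausalLM):
    def __init__(self, hidden_size: int) -> None:
        # Each variant only wires up its own backbone.
        super().__init__()
        self.language_model = nn.Linear(hidden_size, hidden_size)


class VariantBForCausalLM(MultimodalBaseForCausalLM):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.language_model = nn.Linear(hidden_size, hidden_size)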
@@ -28,7 +28,6 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
-    SiglipVisionConfig,
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -47,32 +46,19 @@ from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


-class LlavaLlamaForCausalLM(nn.Module):
-    def __init__(
+class LlavaBaseForCausalLM(nn.Module):
+    def pad_input_ids(
         self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(
+        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
-        for image_s in
-        if len(
+        for image_s in image_sizes:
+            if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
                     math.ceil(self.image_size / self.patch_size / 2) ** 2
@@ -153,17 +139,15 @@ class LlavaLlamaForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size

-            # Embed text
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
-
-
-
-
-                .numpy()
+
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
             )
-
-
-            need_vision = need_vision & has_pixel
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset

             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
@@ -332,31 +316,35 @@ class LlavaLlamaForCausalLM(nn.Module):
                 new_image_features.append(image_feature)
             image_features = new_image_features

+            # Fill in the placeholder for the image
             extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+            prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
             pt = 0
             for i in range(bs):
                 if not need_vision[i]:
                     continue

                 start_idx = extend_start_loc_cpu[i]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                prefix_len = prefix_lens_cpu[i]
+
+                # Multiple images
+                for j, image_offset in enumerate(image_offsets[i]):
+                    if image_offset < prefix_len:
+                        continue
+
+                    tmp_image_feature = image_features[pt][j]
+                    pad_len = tmp_image_feature.shape[0]
+
+                    left_idx = start_idx + (image_offset - prefix_len)
+                    right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                    try:
+                        input_embeds[left_idx:right_idx] = tmp_image_feature
+                    except RuntimeError as e:
+                        print(f"RuntimeError in image encoding: {e}")
+                        print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                        print(
+                            f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                        )
                 pt += 1

             return self.language_model(
@@ -366,8 +354,9 @@ class LlavaLlamaForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        #
-        #
+        # Load clip vision model by cfg['mm_vision_tower']:
+        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         if "clip" in vision_path:
             self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,21 +411,41 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)

-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size


-class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(
+        super().__init__()
+
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -462,14 +471,15 @@ class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
             )


-class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
     def __init__(
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
-        super().__init__(
+        super().__init__()
+
         self.config = config
         self.vision_tower = None
         if getattr(self.config, "vision_config", None) is None:
@@ -495,36 +505,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
             )


-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|