sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between these two publicly released package versions, as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma.py
CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return (sample_output, logits_output)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
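This gemma.py change is the template for most of the model-file diffs below (gemma2, gpt_bigcode, grok, internlm2, minicpm, and others): the per-model Sampler import and attribute are deleted, and forward() returns the LogitsProcessor output directly instead of a (sample_output, logits_output) pair. Sampling is evidently centralized outside the model code; note that sampler.py, tp_worker.py, and model_runner.py all change in this release. A schematic sketch of the before/after contract, with stand-in names rather than the real sglang classes:

```python
# Schematic sketch of the forward() contract change seen in these hunks.
# self.model, self.logits_processor, and self.sampler stand in for the
# real sglang components; this is not the actual class.

class ForCausalLM_0_2_15:
    def forward(self, input_ids, positions, input_metadata, input_embeds=None):
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        logits_output = self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )
        # 0.2.15: every model owned a Sampler and returned sampled tokens
        # together with the logits.
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output


class ForCausalLM_0_3_1:
    def forward(self, input_ids, positions, input_metadata, input_embeds=None):
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        # 0.3.1: the model stops at logits; sampling happens once in shared
        # runner code instead of being duplicated in every architecture file.
        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )
```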
sglang/srt/models/gemma2.py
CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py
CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/internlm2.py
CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/{llama2.py → llama.py}
RENAMED
@@ -41,7 +41,8 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -295,15 +296,16 @@ class LlamaForCausalLM(nn.Module):
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
-        efficient_weight_load=False,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
+
+        self.param_dict = dict(self.named_parameters())

     @torch.no_grad()
     def forward(
@@ -314,13 +316,35 @@ class LlamaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
+
+    def get_hidden_dim(self, module_name):
+        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size
+        elif module_name in ["kv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size // (
+                self.config.num_attention_heads // self.config.num_key_value_heads
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()

     def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
+    def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id, num_shard)
             ("qkv_proj", "q_proj", "q", 3),
@@ -341,28 +365,26 @@ class LlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         return len(params_dict)

-    def load_weights(
-        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
+        params_dict = self.param_dict

-        def load_weights_per_param(name, loaded_weight):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                return
+                continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                return
+                continue
             if name.startswith("model.vision_tower") and name not in params_dict:
-                return
+                continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -378,16 +400,16 @@ class LlamaForCausalLM(nn.Module):
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
-                    return
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

-        if name is None or loaded_weight is None:
-            for name, loaded_weight in weights:
-                load_weights_per_param(name, loaded_weight)
-        else:
-            load_weights_per_param(name, loaded_weight)
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
+
+class Phi3ForCausalLM(LlamaForCausalLM):
+    pass


-EntryClass = LlamaForCausalLM
+EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]
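The renamed llama.py also gains two introspection hooks, get_module_name and get_hidden_dim, which let external code map checkpoint-side projection names onto the fused runtime modules and query each module's input/output widths; the likely consumer is the new sglang/srt/lora package listed at the top of this diff. A hypothetical walk-through of the hooks, using Llama-7B-like dimensions for illustration:

```python
# Hypothetical use of the two new hooks; the concrete caller (presumably
# sglang/srt/lora/lora_manager.py, added in this release) may differ.
from types import SimpleNamespace

# Llama-7B-like dimensions, for illustration only.
config = SimpleNamespace(hidden_size=4096, intermediate_size=11008)


def get_module_name(name):
    # Same mapping as the hunk above: checkpoint projection names fold into
    # the fused runtime modules.
    params_mapping = {
        "q_proj": "qkv_proj", "k_proj": "qkv_proj", "v_proj": "qkv_proj",
        "gate_proj": "gate_up_proj", "up_proj": "gate_up_proj",
    }
    return params_mapping.get(name, name)


def get_hidden_dim(module_name):
    # Reduced version of the shape table from the hunk above.
    if module_name in ("q_proj", "o_proj", "qkv_proj"):
        return config.hidden_size, config.hidden_size
    if module_name == "gate_up_proj":
        return config.hidden_size, config.intermediate_size
    if module_name == "down_proj":
        return config.intermediate_size, config.hidden_size
    raise NotImplementedError(module_name)


# A LoRA adapter saved against "q_proj" targets the fused qkv_proj module,
# whose widths determine the shapes of the A/B matrices to allocate.
module = get_module_name("q_proj")
print(module, get_hidden_dim(module))  # qkv_proj (4096, 4096)
```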
sglang/srt/models/llama_classification.py
CHANGED
@@ -16,17 +16,15 @@ limitations under the License.
 from typing import Iterable, Optional, Tuple

 import torch
-import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


 class LlamaForClassification(nn.Module):
@@ -42,10 +40,12 @@ class LlamaForClassification(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)

         self.classification_head = nn.Linear(
-            config.hidden_size, config.classification_out_size
+            config.hidden_size, config.classification_out_size, bias=False
         )
         self.eos_token_id = config.eos_token_id

+        self.param_dict = dict(self.named_parameters())
+
     @torch.no_grad()
     def forward(
         self,
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)

-        return LogitsProcessorOutput(
+        logits_output = LogitsProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
@@ -74,46 +74,20 @@ class LlamaForClassification(nn.Module):
             output_top_logprobs=None,
         )

+        return logits_output
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
-            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
-                # Models trained using ColossalAI may include these tensors in
-                # the checkpoint. Skip them.
-                continue
-            if "lm_head" in name:
-                continue
+        params_dict = self.param_dict

-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])


 EntryClass = LlamaForClassification
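The rewritten load_weights above illustrates a delegation pattern that also appears in the llava.py and llavavid.py hunks below: the wrapper keeps the tensors it owns, drops the ones it deliberately lacks (lm_head), and forwards everything else to the base loader one (name, tensor) pair at a time, which works because the base load_weights now accepts any iterable of pairs. A toy reduction of the pattern:

```python
# Toy reduction of the per-tensor delegation pattern; these stand-in classes
# are not the sglang models.
from typing import Iterable, Tuple


class BaseLM:
    def __init__(self):
        self.loaded = []

    def load_weights(self, weights: Iterable[Tuple[str, object]]):
        for name, tensor in weights:
            self.loaded.append(name)


class ClassifierLM(BaseLM):
    def load_weights(self, weights: Iterable[Tuple[str, object]]):
        for name, tensor in weights:
            if "classification_head" in name:
                self.loaded.append(name)  # owned by the wrapper itself
            elif "lm_head" in name:
                continue  # dropped: the classifier has no LM head
            else:
                # Reuse the base loader for a single (name, tensor) pair.
                BaseLM.load_weights(self, [(name, tensor)])


model = ClassifierLM()
model.load_weights([
    ("classification_head.weight", ...),
    ("lm_head.weight", ...),
    ("model.layers.0.self_attn.qkv_proj.weight", ...),
])
print(model.loaded)
# ['classification_head.weight', 'model.layers.0.self_attn.qkv_proj.weight']
```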
sglang/srt/models/llama_embedding.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Tuple

 import torch
 from torch import nn
@@ -7,7 +7,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import InputMetadata
-from sglang.srt.models.llama2 import LlamaModel
+from sglang.srt.models.llama import LlamaModel


 class LlamaEmbeddingModel(nn.Module):
@@ -16,7 +16,6 @@ class LlamaEmbeddingModel(nn.Module):
         config: LlamaConfig,
         quant_config=None,
         cache_config=None,
-        efficient_weight_load=False,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
@@ -86,6 +85,8 @@ class LlamaEmbeddingModel(nn.Module):
             load_weights_per_param(name, loaded_weight)


-EntryClass = LlamaEmbeddingModel
-
-
+class MistralModel(LlamaEmbeddingModel):
+    pass
+
+
+EntryClass = [LlamaEmbeddingModel, MistralModel]
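Note that both llama.py and llama_embedding.py now export EntryClass as a list rather than a single class, so one implementation can register under several architecture names through trivial alias subclasses (Phi3ForCausalLM, MistralModel). A toy sketch of how a loader could consume either form; sglang's real registration code is not shown in this diff:

```python
# Toy registry illustrating the EntryClass-as-list convention; illustrative
# only, not sglang's actual model loader.
class LlamaEmbeddingModel:
    pass


class MistralModel(LlamaEmbeddingModel):
    pass


EntryClass = [LlamaEmbeddingModel, MistralModel]

entries = EntryClass if isinstance(EntryClass, list) else [EntryClass]
registry = {cls.__name__: cls for cls in entries}
assert registry["MistralModel"] is MistralModel
```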
sglang/srt/models/llava.py
CHANGED
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -136,8 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
+            # Got List[List[str]] extend it to List[str]
+            # The length of the List should be equal to batch size
+            modalities_list = []
+            for modalities in input_metadata.modalities:
+                if modalities is not None:
+                    modalities_list.extend(modalities)

             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
@@ -179,11 +185,14 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_image_features = []
                 height = width = self.num_patches_per_side
                 for image_idx, image_feature in enumerate(image_features):
-                    if len(image_sizes[image_idx]) == 1:
+                    if modalities_list[image_idx] == "image":
                         image_aspect_ratio = (
                             self.config.image_aspect_ratio
                         )  # single image
-                    else:
+                    elif (
+                        modalities_list[image_idx] == "multi-images"
+                        or modalities_list[image_idx] == "video"
+                    ):
                         image_aspect_ratio = "pad"  # multi image
                         # image_aspect_ratio = (
                         #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
@@ -191,6 +200,7 @@ class LlavaBaseForCausalLM(nn.Module):
                     if (
                         image_feature.shape[0] > 1
                         and "anyres" in image_aspect_ratio
+                        and modalities_list[image_idx] == "image"
                     ):
                         base_image_feature = image_feature[0]
                         image_feature = image_feature[1:]
@@ -290,7 +300,7 @@ class LlavaBaseForCausalLM(nn.Module):
                             )
                             image_feature = image_feature.unsqueeze(0)
                         else:
-                            if image_feature.shape[0] > 16:  # video
+                            if modalities_list[image_idx] == "video":  # video
                                 # 2x2 pooling
                                 num_of_frames = image_feature.shape[0]
                                 image_feature = image_feature.view(
@@ -312,6 +322,21 @@ class LlavaBaseForCausalLM(nn.Module):
                                     .transpose(1, 2)
                                     .contiguous()
                                 )  # N, C, H*W
+                                if "unpad" in self.mm_patch_merge_type:
+                                    image_feature = torch.cat(
+                                        (
+                                            image_feature,
+                                            # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
+                                            self.language_model.model.image_newline[
+                                                None, None
+                                            ].expand(
+                                                image_feature.shape[0],
+                                                1,
+                                                image_feature.shape[-1],
+                                            ),
+                                        ),
+                                        dim=1,
+                                    )

                             new_image_features.append(image_feature)
                 image_features = new_image_features
@@ -350,7 +375,7 @@ class LlavaBaseForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -395,21 +420,19 @@ class LlavaBaseForCausalLM(nn.Module):
             "model.mm_projector.0": "multi_modal_projector.linear_1",
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
-        weights = list(weights)
         for name, loaded_weight in weights:
-            # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
+            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-
-
-        self.language_model.load_weights(weights)
+            else:
+                self.language_model.load_weights([(name, loaded_weight)])

     @property
     def num_patches_per_side(self):
@@ -429,6 +452,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self.vision_tower = None
         self.config.vision_config.hidden_size = config.mm_hidden_size
         self.config.text_config.hidden_size = config.hidden_size
+
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
         if "unpad" in getattr(config, "mm_patch_merge_type", ""):
@@ -448,9 +472,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):

         self.config = config
         self.vision_tower = None
+
         if getattr(self.config, "vision_config", None) is None:
             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-
         if getattr(self.config, "text_config", None) is None:
             self.config.text_config = Qwen2Config(self.config._name_or_path)

@@ -459,7 +483,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):

         if getattr(self.config, "projector_hidden_act", None) is None:
             self.config.projector_hidden_act = "gelu"
-
         if getattr(self.config, "image_token_index", None) is None:
             self.config.image_token_index = 151646

@@ -482,9 +505,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):

         self.config = config
         self.vision_tower = None
+
         if getattr(self.config, "vision_config", None) is None:
             self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-
         if getattr(self.config, "text_config", None) is None:
             self.config.text_config = MistralConfig(self.config._name_or_path)

@@ -493,7 +516,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):

         if getattr(self.config, "projector_hidden_act", None) is None:
             self.config.projector_hidden_act = "gelu"
-
         if getattr(self.config, "image_token_index", None) is None:
             self.config.image_token_index = 32000
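The llava.py forward now receives per-request modality tags and flattens them so that modalities_list[i] lines up with image_features[i]; the tag then selects the aspect-ratio strategy per image instead of the old shape-based heuristics. A toy illustration with hypothetical data, not the sglang structures:

```python
# Toy illustration of the modalities bookkeeping added in llava.py above.
# One tag per image; requests without images contribute nothing.
batch_modalities = [["image"], None, ["video"], ["multi-images", "multi-images"]]

modalities_list = []
for modalities in batch_modalities:
    if modalities is not None:
        modalities_list.extend(modalities)

print(modalities_list)  # ['image', 'video', 'multi-images', 'multi-images']

# Routing rule from the diff: single images keep the configured strategy
# (e.g. "anyres"); multi-image and video inputs fall back to "pad".
for i, modality in enumerate(modalities_list):
    aspect = "anyres" if modality == "image" else "pad"
    print(i, modality, aspect)
```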
sglang/srt/models/llavavid.py
CHANGED
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.llama import LlamaForCausalLM


 class LlavaVidForCausalLM(nn.Module):
@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size

             # Embed text inputs
@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -239,12 +239,12 @@ class LlavaVidForCausalLM(nn.Module):
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
-        weights = list(weights)
         for name, loaded_weight in weights:
             # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
+            if "projector" in name or "vision_tower" in name or "image_newline" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
                         name = name.replace(weight_name, param_name)
@@ -255,9 +255,8 @@ class LlavaVidForCausalLM(nn.Module):
                         continue
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-
-
-        self.language_model.load_weights(weights)
+            else:
+                self.language_model.load_weights([(name, loaded_weight)])

     @property
     def num_patches_per_side(self):
sglang/srt/models/minicpm.py
CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base

         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [