sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma.py CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
  self.quant_config = quant_config
  self.model = GemmaModel(config, quant_config=quant_config)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return (sample_output, logits_output)

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
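A pattern repeated throughout this release (gemma, gemma2, gpt_bigcode, grok, internlm2, minicpm, and the renamed llama.py): the per-model `Sampler` is removed, and `forward` now returns the `LogitsProcessor` output directly instead of a `(sample_output, logits_output)` tuple, leaving sampling to the runtime outside the model. The sketch below is not code from this diff; it is a minimal, self-contained illustration of that control-flow change, with `DummyLogitsProcessor`, `DummyCausalLM`, and `sample_outside_model` as hypothetical stand-ins for the real sglang components.

```python
import torch


class DummyLogitsProcessor:
    """Hypothetical stand-in for sglang's LogitsProcessor."""

    def __call__(self, hidden_states: torch.Tensor, lm_head_weight: torch.Tensor) -> torch.Tensor:
        # Project hidden states to vocabulary logits.
        return hidden_states @ lm_head_weight.t()


class DummyCausalLM(torch.nn.Module):
    """Illustrates the 0.3.x pattern: forward() returns logits only, no in-model sampling."""

    def __init__(self, hidden_size: int = 8, vocab_size: int = 16):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
        self.logits_processor = DummyLogitsProcessor()
        # 0.2.x models kept `self.sampler = Sampler()` here; 0.3.x removes it.

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)  # toy "model body"
        # New pattern: return the logits-processor output directly.
        return self.logits_processor(hidden_states, self.embed_tokens.weight)


def sample_outside_model(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical caller-side sampling, now decoupled from the model."""
    return torch.argmax(logits, dim=-1)


if __name__ == "__main__":
    model = DummyCausalLM()
    logits = model(torch.tensor([[1, 2, 3]]))
    print(sample_outside_model(logits))
```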
sglang/srt/models/gemma2.py CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
  self.quant_config = quant_config
  self.model = Gemma2Model(config, cache_config, quant_config)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def get_attention_sliding_window_size(self):
  return get_attention_sliding_window_size(self.config)
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
  if lora_config:
  self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
  input_metadata: InputMetadata,
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, input_metadata)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
  self.model = Grok1Model(config, quant_config=quant_config)
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  # Monkey patch _prepare_weights to load pre-sharded weights
  setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/internlm2.py CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
  self.model = InternLM2Model(config, quant_config)
  self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.output.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/{llama2.py → llama.py} RENAMED
@@ -41,7 +41,8 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
+ from sglang.srt.layers.torchao_utils import apply_torchao_config_
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -295,15 +296,16 @@ class LlamaForCausalLM(nn.Module):
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
  cache_config: Optional[CacheConfig] = None,
- efficient_weight_load=False,
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
+ self.torchao_config = global_server_args_dict["torchao_config"]
  self.model = LlamaModel(config, quant_config=quant_config)
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()
+
+ self.param_dict = dict(self.named_parameters())

  @torch.no_grad()
  def forward(
@@ -314,13 +316,35 @@ class LlamaForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> LogitsProcessorOutput:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output
+
+ def get_hidden_dim(self, module_name):
+ if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+ return self.config.hidden_size, self.config.hidden_size
+ elif module_name in ["kv_proj"]:
+ return self.config.hidden_size, self.config.hidden_size // (
+ self.config.num_attention_heads // self.config.num_key_value_heads
+ )
+ elif module_name == "gate_up_proj":
+ return self.config.hidden_size, self.config.intermediate_size
+ elif module_name == "down_proj":
+ return self.config.intermediate_size, self.config.hidden_size
+ else:
+ raise NotImplementedError()

  def get_module_name(self, name):
+ params_mapping = {
+ "q_proj": "qkv_proj",
+ "k_proj": "qkv_proj",
+ "v_proj": "qkv_proj",
+ "gate_proj": "gate_up_proj",
+ "up_proj": "gate_up_proj",
+ }
+ return params_mapping.get(name, name)
+
+ def get_module_name_from_weight_name(self, name):
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id, num_shard)
  ("qkv_proj", "q_proj", "q", 3),
@@ -341,28 +365,26 @@ class LlamaForCausalLM(nn.Module):
  params_dict = dict(self.named_parameters())
  return len(params_dict)

- def load_weights(
- self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
- ):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
- ("qkv_proj", "q_proj", "q"),
- ("qkv_proj", "k_proj", "k"),
- ("qkv_proj", "v_proj", "v"),
- ("gate_up_proj", "gate_proj", 0),
- ("gate_up_proj", "up_proj", 1),
+ (".qkv_proj", ".q_proj", "q"),
+ (".qkv_proj", ".k_proj", "k"),
+ (".qkv_proj", ".v_proj", "v"),
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
  ]
- params_dict = dict(self.named_parameters())
+ params_dict = self.param_dict

- def load_weights_per_param(name, loaded_weight):
+ for name, loaded_weight in weights:
  if "rotary_emb.inv_freq" in name or "projector" in name:
- return
+ continue
  if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
  # Models trained using ColossalAI may include these tensors in
  # the checkpoint. Skip them.
- return
+ continue
  if name.startswith("model.vision_tower") and name not in params_dict:
- return
+ continue

  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
@@ -378,16 +400,16 @@ class LlamaForCausalLM(nn.Module):
  else:
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
- return
+ continue
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)

- if name is None or loaded_weight is None:
- for name, loaded_weight in weights:
- load_weights_per_param(name, loaded_weight)
- else:
- load_weights_per_param(name, loaded_weight)
+ apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
+
+ class Phi3ForCausalLM(LlamaForCausalLM):
+ pass


- EntryClass = LlamaForCausalLM
+ EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]
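Besides dropping the sampler, the llama.py rewrite removes the `efficient_weight_load` per-parameter callback path: `load_weights` now makes a single pass over the checkpoint, folding `q_proj`/`k_proj`/`v_proj` and `gate_proj`/`up_proj` shards into fused parameters via `stacked_params_mapping`, then calls `apply_torchao_config_`. The sketch below only illustrates the stacked-mapping idea in isolation; `FusedParam` and the toy parameter names are hypothetical and not part of sglang.

```python
import torch


class FusedParam:
    """Hypothetical fused parameter: three equal shards (q, k, v) stacked along dim 0."""

    def __init__(self, shard_size: int, hidden: int):
        self.data = torch.zeros(3 * shard_size, hidden)
        self.shard_size = shard_size

    def weight_loader(self, loaded_weight: torch.Tensor, shard_id: str) -> None:
        # Copy the shard into its slice of the fused tensor.
        offset = {"q": 0, "k": 1, "v": 2}[shard_id] * self.shard_size
        self.data[offset : offset + self.shard_size] = loaded_weight


# Mirrors the (param_name, shard_name, shard_id) triples in the diff above.
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]


def load_weights(params: dict, weights) -> None:
    """Single pass over the checkpoint, folding shards into fused params."""
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in name:
                fused_name = name.replace(weight_name, param_name)
                params[fused_name].weight_loader(loaded_weight, shard_id)
                break


if __name__ == "__main__":
    params = {"layers.0.self_attn.qkv_proj": FusedParam(4, 4)}
    checkpoint = [
        ("layers.0.self_attn.q_proj", torch.ones(4, 4)),
        ("layers.0.self_attn.k_proj", torch.full((4, 4), 2.0)),
        ("layers.0.self_attn.v_proj", torch.full((4, 4), 3.0)),
    ]
    load_weights(params, checkpoint)
    print(params["layers.0.self_attn.qkv_proj"].data)
```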
sglang/srt/models/llama_classification.py CHANGED
@@ -16,17 +16,15 @@ limitations under the License.
  from typing import Iterable, Optional, Tuple

  import torch
- import tqdm
  from torch import nn
  from transformers import LlamaConfig
  from vllm.config import CacheConfig
- from vllm.distributed import get_tensor_model_parallel_rank
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
- from sglang.srt.models.llama2 import LlamaModel
+ from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


  class LlamaForClassification(nn.Module):
@@ -42,10 +40,12 @@ class LlamaForClassification(nn.Module):
  self.model = LlamaModel(config, quant_config=quant_config)

  self.classification_head = nn.Linear(
- config.hidden_size, config.classification_out_size
+ config.hidden_size, config.classification_out_size, bias=False
  )
  self.eos_token_id = config.eos_token_id

+ self.param_dict = dict(self.named_parameters())
+
  @torch.no_grad()
  def forward(
  self,
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
  (input_metadata.batch_size, self.config.classification_out_size)
  ).to(input_ids.device)

- return LogitsProcessorOutput(
+ logits_output = LogitsProcessorOutput(
  next_token_logits=scores,
  next_token_logprobs=scores,
  normalized_prompt_logprobs=scores,
@@ -74,46 +74,20 @@ class LlamaForClassification(nn.Module):
  output_top_logprobs=None,
  )

+ return logits_output
+
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id)
- ("qkv_proj", "q_proj", "q"),
- ("qkv_proj", "k_proj", "k"),
- ("qkv_proj", "v_proj", "v"),
- ("gate_up_proj", "gate_proj", 0),
- ("gate_up_proj", "up_proj", 1),
- ]
- params_dict = dict(self.named_parameters())
- if get_tensor_model_parallel_rank() == 0:
- weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name or "projector" in name:
- continue
- if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
- # Models trained using ColossalAI may include these tensors in
- # the checkpoint. Skip them.
- continue
- if "lm_head" in name:
- continue
+ params_dict = self.param_dict

- for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
+ for name, loaded_weight in weights:
+ if "classification_head" in name:
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)
+ elif "lm_head" in name:
+ continue
+ else:
+ LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])


  EntryClass = LlamaForClassification
sglang/srt/models/llama_embedding.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Iterable, Optional, Tuple
+ from typing import Iterable, Tuple

  import torch
  from torch import nn
@@ -7,7 +7,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
  from sglang.srt.model_executor.model_runner import InputMetadata
- from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+ from sglang.srt.models.llama import LlamaModel


  class LlamaEmbeddingModel(nn.Module):
@@ -16,7 +16,6 @@ class LlamaEmbeddingModel(nn.Module):
  config: LlamaConfig,
  quant_config=None,
  cache_config=None,
- efficient_weight_load=False,
  ) -> None:
  super().__init__()
  self.model = LlamaModel(config, quant_config=quant_config)
@@ -86,6 +85,8 @@ class LlamaEmbeddingModel(nn.Module):
  load_weights_per_param(name, loaded_weight)


- EntryClass = LlamaEmbeddingModel
- # compat: e5-mistral model.config class == MistralModel
- EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
+ class MistralModel(LlamaEmbeddingModel):
+ pass
+
+
+ EntryClass = [LlamaEmbeddingModel, MistralModel]
sglang/srt/models/llava.py CHANGED
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
  unpad_image_shape,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
- from sglang.srt.models.llama2 import LlamaForCausalLM
+ from sglang.srt.models.llama import LlamaForCausalLM
  from sglang.srt.models.mistral import MistralForCausalLM
  from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -136,8 +136,14 @@ class LlavaBaseForCausalLM(nn.Module):
  image_sizes: Optional[List[List[int]]] = None,
  image_offsets: Optional[List[int]] = None,
  ) -> torch.Tensor:
- if input_metadata.forward_mode == ForwardMode.EXTEND:
+ if input_metadata.forward_mode.is_extend():
  bs = input_metadata.batch_size
+ # Got List[List[str]] extend it to List[str]
+ # The length of the List should be equal to batch size
+ modalities_list = []
+ for modalities in input_metadata.modalities:
+ if modalities is not None:
+ modalities_list.extend(modalities)

  # Embed text inputs
  input_embeds = self.language_model.model.embed_tokens(input_ids)
@@ -179,11 +185,14 @@ class LlavaBaseForCausalLM(nn.Module):
  new_image_features = []
  height = width = self.num_patches_per_side
  for image_idx, image_feature in enumerate(image_features):
- if len(image_sizes[image_idx]) == 1:
+ if modalities_list[image_idx] == "image":
  image_aspect_ratio = (
  self.config.image_aspect_ratio
  ) # single image
- else:
+ elif (
+ modalities_list[image_idx] == "multi-images"
+ or modalities_list[image_idx] == "video"
+ ):
  image_aspect_ratio = "pad" # multi image
  # image_aspect_ratio = (
  # "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
@@ -191,6 +200,7 @@ class LlavaBaseForCausalLM(nn.Module):
  if (
  image_feature.shape[0] > 1
  and "anyres" in image_aspect_ratio
+ and modalities_list[image_idx] == "image"
  ):
  base_image_feature = image_feature[0]
  image_feature = image_feature[1:]
@@ -290,7 +300,7 @@ class LlavaBaseForCausalLM(nn.Module):
  )
  image_feature = image_feature.unsqueeze(0)
  else:
- if image_feature.shape[0] > 16: # video
+ if modalities_list[image_idx] == "video": # video
  # 2x2 pooling
  num_of_frames = image_feature.shape[0]
  image_feature = image_feature.view(
@@ -312,6 +322,21 @@ class LlavaBaseForCausalLM(nn.Module):
  .transpose(1, 2)
  .contiguous()
  ) # N, C, H*W
+ if "unpad" in self.mm_patch_merge_type:
+ image_feature = torch.cat(
+ (
+ image_feature,
+ # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
+ self.language_model.model.image_newline[
+ None, None
+ ].expand(
+ image_feature.shape[0],
+ 1,
+ image_feature.shape[-1],
+ ),
+ ),
+ dim=1,
+ )

  new_image_features.append(image_feature)
  image_features = new_image_features
@@ -350,7 +375,7 @@ class LlavaBaseForCausalLM(nn.Module):
  return self.language_model(
  input_ids, positions, input_metadata, input_embeds=input_embeds
  )
- elif input_metadata.forward_mode == ForwardMode.DECODE:
+ elif input_metadata.forward_mode.is_decode():
  return self.language_model(input_ids, positions, input_metadata)

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -395,21 +420,19 @@ class LlavaBaseForCausalLM(nn.Module):
  "model.mm_projector.0": "multi_modal_projector.linear_1",
  "model.mm_projector.2": "multi_modal_projector.linear_2",
  "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+ "model.image_newline": "language_model.model.image_newline",
  }
  params_dict = dict(self.named_parameters())
- weights = list(weights)
  for name, loaded_weight in weights:
- # FIXME: why projector weights read two times?
- if "projector" in name or "vision_tower" in name:
+ if "projector" in name or "vision_tower" in name or "image_newline" in name:
  for weight_name, param_name in projector_weights.items():
  if weight_name in name:
  name = name.replace(weight_name, param_name)
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)
-
- # load language model
- self.language_model.load_weights(weights)
+ else:
+ self.language_model.load_weights([(name, loaded_weight)])

  @property
  def num_patches_per_side(self):
@@ -429,6 +452,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
  self.vision_tower = None
  self.config.vision_config.hidden_size = config.mm_hidden_size
  self.config.text_config.hidden_size = config.hidden_size
+
  self.multi_modal_projector = LlavaMultiModalProjector(config)
  self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
  if "unpad" in getattr(config, "mm_patch_merge_type", ""):
@@ -448,9 +472,9 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):

  self.config = config
  self.vision_tower = None
+
  if getattr(self.config, "vision_config", None) is None:
  self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-
  if getattr(self.config, "text_config", None) is None:
  self.config.text_config = Qwen2Config(self.config._name_or_path)

@@ -459,7 +483,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):

  if getattr(self.config, "projector_hidden_act", None) is None:
  self.config.projector_hidden_act = "gelu"
-
  if getattr(self.config, "image_token_index", None) is None:
  self.config.image_token_index = 151646

@@ -482,9 +505,9 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):

  self.config = config
  self.vision_tower = None
+
  if getattr(self.config, "vision_config", None) is None:
  self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-
  if getattr(self.config, "text_config", None) is None:
  self.config.text_config = MistralConfig(self.config._name_or_path)

@@ -493,7 +516,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):

  if getattr(self.config, "projector_hidden_act", None) is None:
  self.config.projector_hidden_act = "gelu"
-
  if getattr(self.config, "image_token_index", None) is None:
  self.config.image_token_index = 32000

sglang/srt/models/llavavid.py CHANGED
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
- from sglang.srt.models.llama2 import LlamaForCausalLM
+ from sglang.srt.models.llama import LlamaForCausalLM


  class LlavaVidForCausalLM(nn.Module):
@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
  image_sizes: Optional[List[List[int]]] = None,
  image_offsets: Optional[List[int]] = None,
  ) -> torch.Tensor:
- if input_metadata.forward_mode == ForwardMode.EXTEND:
+ if input_metadata.forward_mode.is_extend():
  bs = input_metadata.batch_size

  # Embed text inputs
@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
  return self.language_model(
  input_ids, positions, input_metadata, input_embeds=input_embeds
  )
- elif input_metadata.forward_mode == ForwardMode.DECODE:
+ elif input_metadata.forward_mode.is_decode():
  return self.language_model(input_ids, positions, input_metadata)

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -239,12 +239,12 @@ class LlavaVidForCausalLM(nn.Module):
  "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
  "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
  "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+ "model.image_newline": "language_model.model.image_newline",
  }
  params_dict = dict(self.named_parameters())
- weights = list(weights)
  for name, loaded_weight in weights:
  # FIXME: why projector weights read two times?
- if "projector" in name or "vision_tower" in name:
+ if "projector" in name or "vision_tower" in name or "image_newline" in name:
  for weight_name, param_name in projector_weights.items():
  if weight_name in name:
  name = name.replace(weight_name, param_name)
@@ -255,9 +255,8 @@ class LlavaVidForCausalLM(nn.Module):
  continue
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)
-
- # load language model
- self.language_model.load_weights(weights)
+ else:
+ self.language_model.load_weights([(name, loaded_weight)])

  @property
  def num_patches_per_side(self):
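Both llava.py and llavavid.py switch from comparing `input_metadata.forward_mode == ForwardMode.EXTEND` / `ForwardMode.DECODE` to calling `forward_mode.is_extend()` / `is_decode()`. The real `ForwardMode` lives in `sglang/srt/model_executor/forward_batch_info.py` and is not shown in this diff, so the sketch below is only an assumed reconstruction of the enum-with-predicate pattern the call sites now rely on.

```python
from enum import Enum, auto


class ForwardMode(Enum):
    """Assumed shape of the enum; members beyond EXTEND/DECODE are guesses."""

    PREFILL = auto()
    EXTEND = auto()
    DECODE = auto()

    def is_extend(self) -> bool:
        # Call sites query the mode through predicates instead of comparing members.
        return self == ForwardMode.EXTEND

    def is_decode(self) -> bool:
        return self == ForwardMode.DECODE


if __name__ == "__main__":
    mode = ForwardMode.EXTEND
    print(mode.is_extend(), mode.is_decode())  # True False
```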
sglang/srt/models/minicpm.py CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
  self.scale_width = self.config.hidden_size / self.config.dim_model_base

  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
  lm_head_weight = self.model.embed_tokens.weight
  else:
  lm_head_weight = self.lm_head.weight
- logits_output = self.logits_processor(
+ return self.logits_processor(
  input_ids, hidden_states, lm_head_weight, input_metadata
  )
- sample_output = self.sampler(logits_output, input_metadata.sampling_info)
- return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [