sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/models/llama.py
+++ b/sglang/srt/models/llama.py
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class LlamaMLP(nn.Module):
@@ -162,12 +162,12 @@ class LlamaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -221,7 +221,7 @@ class LlamaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -233,7 +233,7 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -270,7 +270,7 @@ class LlamaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -283,7 +283,7 @@ class LlamaModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -310,15 +310,16 @@ class LlamaForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
             return self.config.hidden_size, self.config.hidden_size
         elif module_name in ["kv_proj"]:
@@ -399,6 +400,9 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
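
The llama.py hunks above capture the pattern repeated across every model file in this release: `InputMetadata` is renamed to `ForwardBatch`, and the object is threaded unchanged through each decoder layer, into `RadixAttention`, and into the logits processor. Below is a minimal sketch of what an out-of-tree model layer would change, assuming only the renamed import and argument; the layer itself is hypothetical, not part of sglang.

```python
# Hypothetical layer sketch, not taken from the diff: only the annotation and
# argument name change when porting a custom model from 0.3.2 to 0.3.3.
import torch
from torch import nn

from sglang.srt.model_executor.forward_batch_info import ForwardBatch  # was InputMetadata


class MyDecoderLayer(nn.Module):
    def __init__(self, self_attn: nn.Module):
        super().__init__()
        self.self_attn = self_attn  # e.g. a module wrapping RadixAttention

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,  # was: input_metadata: InputMetadata
    ) -> torch.Tensor:
        # The batch object is passed through untouched, exactly as in the
        # LlamaDecoderLayer hunk above.
        return self.self_attn(positions, hidden_states, forward_batch)
```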
--- a/sglang/srt/models/llama_classification.py
+++ b/sglang/srt/models/llama_classification.py
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -50,18 +50,18 @@ class LlamaForClassification(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         is_eos_token = input_ids == self.eos_token_id
         hidden_states = hidden_states[is_eos_token]
         scores = self.classification_head(hidden_states)

-        if scores.shape[0] != input_metadata.batch_size:
+        if scores.shape[0] != forward_batch.batch_size:
             print("Warning: the EOS tokens are missing in some sentences.")
             scores = torch.ones(
-                (input_metadata.batch_size, self.config.classification_out_size)
+                (forward_batch.batch_size, self.config.classification_out_size)
             ).to(input_ids.device)

         logits_output = LogitsProcessorOutput(
--- a/sglang/srt/models/llama_embedding.py
+++ b/sglang/srt/models/llama_embedding.py
@@ -6,7 +6,7 @@ from transformers import LlamaConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.models.llama import LlamaModel


@@ -26,15 +26,15 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
         assert (
             get_embedding
         ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.pooler(hidden_states, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.pooler(hidden_states, forward_batch)

     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
--- /dev/null
+++ b/sglang/srt/models/llama_reward.py
@@ -0,0 +1,142 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
+
+
+class LlamaForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.torchao_config = None
+        self.quant_config = quant_config
+        self.num_labels = config.num_labels
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> EmbeddingPoolerOutput:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        scores = self.score(hidden_states)
+
+        return self.pooler(scores, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+
+
+class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
+    class Weights(torch.nn.Module):
+        def __init__(self, hidden_size, num_label):
+            super().__init__()
+            self.fc = torch.nn.Sequential(
+                torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
+                torch.nn.SELU(),
+                torch.nn.Linear(hidden_size, hidden_size, dtype=torch.float16),
+                torch.nn.SELU(),
+                torch.nn.Linear(hidden_size, num_label // 2, dtype=torch.float16),
+            )
+
+        def forward(self, x):
+            return self.fc(x.to(torch.float16))
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config, cache_config)
+        self.weights = self.Weights(config.hidden_size, self.num_labels)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForSequenceClassification is only used for embedding"
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        weights = self.weights(hidden_states)
+
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+        pooled_weights = self.pooler(weights, forward_batch).embeddings
+
+        rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
+            -1, self.num_labels // 2
+        )
+        scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
+        return EmbeddingPoolerOutput(scores)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "classification_head" in name:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            elif "lm_head" in name:
+                continue
+            else:
+                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+
+
+EntryClass = [
+    LlamaForSequenceClassification,
+    LlamaForSequenceClassificationWithNormal_Weights,
+]
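
The aggregation in `LlamaForSequenceClassificationWithNormal_Weights` is easiest to read with concrete shapes: the pooled logits hold `num_labels // 2` pairs, the even-indexed element of each pair is kept, and the learned per-pair weights collapse them into one scalar per sequence. Below is a standalone, shape-level sketch with random tensors in place of the pooled model outputs; it does not involve the sglang runtime.

```python
# Shape-level sketch of the score aggregation in llama_reward.py above,
# using dummy tensors in place of pooled model outputs.
import torch

batch_size, num_labels = 2, 8                        # num_labels must be even
pooled_logits = torch.randn(batch_size, num_labels)
pooled_weights = torch.randn(batch_size, num_labels // 2)

# Keep element 0 of each pair: (bs, num_labels) -> (bs, num_labels // 2)
rews = pooled_logits.view(-1, num_labels // 2, 2)[:, :, 0]
# Weighted sum over the pairs gives one score per sequence: (bs, 1)
scores = (rews * pooled_weights).sum(dim=-1).view(-1, 1)
print(scores.shape)  # torch.Size([2, 1])
```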
--- a/sglang/srt/models/llava.py
+++ b/sglang/srt/models/llava.py
@@ -35,25 +35,22 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+
         # hardcode for spatial_unpad + anyres
         image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
@@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module):
                         new_w = int(new_w // times)
                 new_image_feature_len += new_h * (new_w + 1)

-            pad_ids = pad_value * (
-                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            pad_ids = pad_values * (
+                (new_image_feature_len + len(pad_values)) // len(pad_values)
             )
             # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
@@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
-        return input_ids, offset_list
+
+        image_inputs.image_offsets = offset_list
+        return input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -131,33 +130,40 @@ class LlavaBaseForCausalLM(nn.Module):
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode.is_extend():
-            bs = input_metadata.batch_size
+        image_inputs = forward_batch.image_inputs
+
+        if forward_batch.forward_mode.is_extend():
+            bs = forward_batch.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
-            for modalities in input_metadata.modalities:
-                if modalities is not None:
-                    modalities_list.extend(modalities)
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.modalities is not None:
+                    modalities_list.extend(im.modalities)
+                if im and im.image_offsets is not None:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)

             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

-            # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
-            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_sizes = [
+                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                 ########## Encode Image ########

@@ -342,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module):
                 image_features = new_image_features

                 # Fill in the placeholder for the image
-                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
-                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
+                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
@@ -373,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module):
                     pt += 1

             return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
+                input_ids, positions, forward_batch, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode.is_decode():
-            return self.language_model(input_ids, positions, input_metadata)
+        elif forward_batch.forward_mode.is_decode():
+            return self.language_model(input_ids, positions, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # Load clip vision model by cfg['mm_vision_tower']:
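
The llava.py changes also alter the multimodal preprocessing contract, and the llavavid.py hunks below make the same change: `pad_input_ids` now takes a single `ImageInputs` object and records the computed `image_offsets` on it instead of returning them alongside the padded ids. A hedged caller-side sketch of that contract follows; the helper and its `model` argument are hypothetical.

```python
# Hypothetical caller sketch of the new pad_input_ids contract shown above:
# one ImageInputs argument in, padded ids out, offsets stored as a side effect.
from typing import List

from sglang.srt.managers.schedule_batch import ImageInputs


def pad_prompt_for_images(model, input_ids: List[int], image_inputs: ImageInputs) -> List[int]:
    padded_ids = model.pad_input_ids(input_ids, image_inputs)
    # 0.3.2 returned (padded_ids, image_offsets); 0.3.3 writes the offsets
    # back onto image_inputs.image_offsets instead.
    assert image_inputs.image_offsets is not None
    return padded_ids
```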
--- a/sglang/srt/models/llavavid.py
+++ b/sglang/srt/models/llavavid.py
@@ -26,7 +26,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM


@@ -54,17 +55,12 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )

-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        pad_values = image_inputs.pad_values
         new_image_feature_len = self.image_feature_len

-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        pad_ids = pad_values * (
+            (new_image_feature_len + len(pad_values)) // len(pad_values)
         )
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
@@ -73,7 +69,8 @@ class LlavaVidForCausalLM(nn.Module):
             + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
-        return new_input_ids, [offset]
+        image_inputs.image_offsets = [offset]
+        return new_input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -111,26 +108,32 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode.is_extend():
-            bs = input_metadata.batch_size
+        image_inputs = forward_batch.image_inputs
+        if forward_batch.forward_mode.is_extend():
+            bs = forward_batch.batch_size

             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

             # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
-            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.image_offsets:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                 ########## Encode Image ########

@@ -166,8 +169,8 @@ class LlavaVidForCausalLM(nn.Module):
                 image_features = new_image_features

                 # Fill in the placeholder for the image
-                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
-                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
+                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
@@ -197,10 +200,10 @@ class LlavaVidForCausalLM(nn.Module):
                     pt += 1

             return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
+                input_ids, positions, forward_batch, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode.is_decode():
-            return self.language_model(input_ids, positions, input_metadata)
+        elif forward_batch.forward_mode.is_decode():
+            return self.language_model(input_ids, positions, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # Load clip vision model by cfg['mm_vision_tower']:
--- a/sglang/srt/models/minicpm.py
+++ b/sglang/srt/models/minicpm.py
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class MiniCPMMLP(nn.Module):
@@ -148,7 +148,7 @@ class MiniCPMAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -156,7 +156,7 @@ class MiniCPMAttention(nn.Module):
         q, k = q.float(), k.float()
         q, k = self.rotary_emb(positions, q, k)
         q, k = q.to(orig_dtype), k.to(orig_dtype)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -199,7 +199,7 @@ class MiniCPMDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -208,7 +208,7 @@ class MiniCPMDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states * (
             self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -252,7 +252,7 @@ class MiniCPMModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -266,7 +266,7 @@ class MiniCPMModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states = self.norm(hidden_states)
@@ -303,19 +303,19 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is not None:
             input_embeds = input_embeds * self.config.scale_emb
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
         return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, input_metadata
+            input_ids, hidden_states, lm_head_weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):