sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/llava.py
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Inference-only LLaVa model compatible with HuggingFace weights."""
 
+import math
+import re
 from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
@@ -26,6 +28,8 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
+    SiglipVisionConfig,
+    SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
         )
 
     def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-        new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        if self.mm_patch_merge_type.startswith("spatial"):
+
+        # hardcode for spatial_unpad + anyres
+        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        offset_list = []
+        for image_s in image_size:
+            if len(image_size) > 16:
+                # 2x2 pooling with stride 2
+                new_image_feature_len = (
+                    math.ceil(self.image_size / self.patch_size / 2) ** 2
+                )
+            else:
+                new_image_feature_len = self.image_feature_len  # multiimage
+
             height = width = self.num_patches_per_side
-            if pt_shape[0] > 1:
-                if self.image_aspect_ratio == "anyres":
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_size,
-                        self.image_grid_pinpoints,
-                        self.vision_tower.config.image_size,
+            if "anyres" in image_aspect_ratio:
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_s,
+                    self.image_grid_pinpoints,
+                    self.vision_tower.config.image_size,
+                )
+                h = num_patch_height * height
+                w = num_patch_width * width
+                new_h, new_w = unpad_image_shape(h, w, image_s)
+
+                if "anyres_max" in self.config.image_aspect_ratio:
+                    matched_anyres_max_num_patches = re.match(
+                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
+                    )
+                    if matched_anyres_max_num_patches:
+                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
+                    # times = math.sqrt(h * w / (max_num_patches * unit**2))
+                    times = math.sqrt(
+                        new_h * new_w / (max_num_patches * self.image_feature_len)
                     )
-                if "unpad" in self.mm_patch_merge_type:
-                    h = num_patch_height * height
-                    w = num_patch_width * width
-                    new_h, new_w = unpad_image_shape(h, w, image_size)
-                    new_image_feature_len += new_h * (new_w + 1)
-
-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
-        )
-        offset = input_ids.index(self.config.image_token_index)
-        # old_len + pad_len - 1, because we need to remove image_token_id
-        new_input_ids = (
-            input_ids[:offset]
-            + pad_ids[:new_image_feature_len]
-            + input_ids[offset + 1 :]
-        )
-        return new_input_ids, offset
+                    if times > 1.1:
+                        new_h = int(new_h // times)
+                        new_w = int(new_w // times)
+                new_image_feature_len += new_h * (new_w + 1)
+
+            pad_ids = pad_value * (
+                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            )
+            # print("calculated new_image_feature_len: ", new_image_feature_len)
+            try:
+                offset = input_ids.index(self.config.image_token_index)
+            except ValueError:
+                offset = 0
+            # old_len + pad_len - 1, because we need to remove image_token_id
+            input_ids = (
+                input_ids[:offset]
+                + pad_ids[:new_image_feature_len]
+                + input_ids[offset + 1 :]
+            )
+            offset_list.append(offset)
+        return input_ids, offset_list
 
     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
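The rewritten `pad_input_ids` now walks over every image in a request, expands each image token into a run of pad tokens, and returns one offset per image instead of a single scalar. A minimal standalone sketch of that expansion (hypothetical helper name, with the anyres sizing simplified to a precomputed per-image feature length):

```python
def pad_for_images(input_ids, image_token_id, pad_value, feature_lens):
    """Replace each image token with feature_len pad tokens; record the offsets."""
    offsets = []
    for feature_len in feature_lens:
        # Repeat the pad pattern until it covers feature_len tokens.
        pad_ids = pad_value * ((feature_len + len(pad_value)) // len(pad_value))
        try:
            offset = input_ids.index(image_token_id)
        except ValueError:
            offset = 0  # no image token left; fall back to the front
        input_ids = (
            input_ids[:offset] + pad_ids[:feature_len] + input_ids[offset + 1 :]
        )
        offsets.append(offset)
    return input_ids, offsets

# pad_for_images([1, 32000, 5, 32000, 2], 32000, [0], [4, 4])
# -> ([1, 0, 0, 0, 0, 5, 0, 0, 0, 0, 2], [1, 6])
```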
@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):
 
             # Embed text input
             input_embeds = self.language_model.model.embed_tokens(input_ids)
-
             # Embed vision input
             need_vision = (
                 (positions[input_metadata.extend_start_loc] < self.image_feature_len)
@@ -163,27 +193,73 @@ class LlavaLlamaForCausalLM(nn.Module):
 
                 if self.mm_patch_merge_type.startswith("spatial"):
                     new_image_features = []
+                    height = width = self.num_patches_per_side
                     for image_idx, image_feature in enumerate(image_features):
-                        if image_feature.shape[0] > 1:
+                        if len(image_sizes[image_idx]) == 1:
+                            image_aspect_ratio = (
+                                self.config.image_aspect_ratio
+                            )  # single image
+                        else:
+                            image_aspect_ratio = "pad"  # multi image
+                        # image_aspect_ratio = (
+                        #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
+                        # )
+                        if (
+                            image_feature.shape[0] > 1
+                            and "anyres" in image_aspect_ratio
+                        ):
                             base_image_feature = image_feature[0]
                             image_feature = image_feature[1:]
-                            height = width = self.num_patches_per_side
                             assert height * width == base_image_feature.shape[0]
-                            if self.image_aspect_ratio == "anyres":
-                                (
-                                    num_patch_width,
-                                    num_patch_height,
-                                ) = get_anyres_image_grid_shape(
-                                    image_sizes[image_idx],
-                                    self.image_grid_pinpoints,
-                                    self.vision_tower.config.image_size,
+
+                            if "anyres_max" in image_aspect_ratio:
+                                matched_anyres_max_num_patches = re.match(
+                                    r"anyres_max_(\d+)", image_aspect_ratio
                                 )
+                                if matched_anyres_max_num_patches:
+                                    max_num_patches = int(
+                                        matched_anyres_max_num_patches.group(1)
+                                    )
+
+                            if (
+                                image_aspect_ratio == "anyres"
+                                or "anyres_max" in image_aspect_ratio
+                            ):
+                                vision_tower_image_size = self.image_size
+                                try:
+                                    num_patch_width, num_patch_height = (
+                                        get_anyres_image_grid_shape(
+                                            image_sizes[image_idx][0],
+                                            self.config.image_grid_pinpoints,
+                                            vision_tower_image_size,
+                                        )
+                                    )
+                                except Exception as e:
+                                    print(f"Error: {e}")
+                                    num_patch_width, num_patch_height = 2, 2
                                 image_feature = image_feature.view(
                                     num_patch_height, num_patch_width, height, width, -1
                                 )
                             else:
-                                raise NotImplementedError()
+                                image_feature = image_feature.view(
+                                    2, 2, height, width, -1
+                                )
+
+                                # (
+                                #     num_patch_width,
+                                #     num_patch_height,
+                                # ) = get_anyres_image_grid_shape(
+                                #     image_sizes[image_idx][0],
+                                #     self.image_grid_pinpoints,
+                                #     self.vision_tower.config.image_size,
+                                # )
+
+                                # image_feature = image_feature.view(
+                                #     num_patch_height, num_patch_width, height, width, -1
+                                # )
+
                             if "unpad" in self.mm_patch_merge_type:
+                                unit = image_feature.shape[2]
                                 image_feature = image_feature.permute(
                                     4, 0, 2, 1, 3
                                 ).contiguous()
@@ -191,8 +267,23 @@ class LlavaLlamaForCausalLM(nn.Module):
                                     2, 3
                                 )
                                 image_feature = unpad_image(
-                                    image_feature, image_sizes[image_idx]
+                                    image_feature, image_sizes[image_idx][0]
                                 )
+                                if (
+                                    "anyres_max" in image_aspect_ratio
+                                    and matched_anyres_max_num_patches
+                                ):
+                                    c, h, w = image_feature.shape
+                                    times = math.sqrt(
+                                        h * w / (max_num_patches * unit**2)
+                                    )
+                                    if times > 1.1:
+                                        image_feature = image_feature[None]
+                                        image_feature = nn.functional.interpolate(
+                                            image_feature,
+                                            [int(h // times), int(w // times)],
+                                            mode="bilinear",
+                                        )[0]
                                 image_feature = torch.cat(
                                     (
                                         image_feature,
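The `anyres_max_N` branch above caps how many high-resolution tokens a single image may contribute: once the unpadded patch grid exceeds roughly N tiles' worth of features, it is shrunk with bilinear interpolation. A hedged sketch of that cap in isolation (assumed `(C, H, W)` feature layout and a hypothetical helper name, not the sglang API):

```python
import math

import torch
import torch.nn.functional as F


def cap_anyres_features(feat: torch.Tensor, max_num_patches: int, unit: int) -> torch.Tensor:
    """feat: (C, H, W) unpadded grid of vision features; unit: patches per tile side."""
    c, h, w = feat.shape
    times = math.sqrt(h * w / (max_num_patches * unit**2))
    if times > 1.1:  # only shrink when clearly over budget
        feat = F.interpolate(
            feat[None], [int(h // times), int(w // times)], mode="bilinear"
        )[0]
    return feat

# e.g. with unit=24 and "anyres_max_9", the grid is scaled so it holds at most
# about 9 * 24**2 feature tokens (plus the per-row newline tokens added later).
```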
@@ -213,16 +304,31 @@ class LlavaLlamaForCausalLM(nn.Module):
                             image_feature = torch.cat(
                                 (base_image_feature, image_feature), dim=0
                             )
+                            image_feature = image_feature.unsqueeze(0)
                         else:
-                            image_feature = image_feature[0]
-                            if "unpad" in self.mm_patch_merge_type:
-                                image_feature = torch.cat(
-                                    (
-                                        image_feature,
-                                        self.language_model.model.image_newline[None],
-                                    ),
-                                    dim=0,
+                            if image_feature.shape[0] > 16:  # video
+                                # 2x2 pooling
+                                num_of_frames = image_feature.shape[0]
+                                image_feature = image_feature.view(
+                                    num_of_frames, height, width, -1
                                 )
+                                image_feature = image_feature.permute(
+                                    0, 3, 1, 2
+                                ).contiguous()  # N, C, H, W
+                                height, weight = image_feature.shape[2:]
+                                scaled_shape = [
+                                    math.ceil(height / 2),
+                                    math.ceil(weight / 2),
+                                ]
+                                image_feature = nn.functional.interpolate(
+                                    image_feature, size=scaled_shape, mode="bilinear"
+                                )
+                                image_feature = (
+                                    image_feature.flatten(2)
+                                    .transpose(1, 2)
+                                    .contiguous()
+                                )  # N, C, H*W
+
                             new_image_features.append(image_feature)
                     image_features = new_image_features
 
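For inputs with more than 16 feature slices (treated as video frames), the new else branch halves each frame's patch grid with a 2x2 bilinear pool before flattening it back into a token sequence. A standalone sketch under the same assumptions (per-frame features of shape `(N, H*W, C)`; hypothetical helper name):

```python
import math

import torch
import torch.nn.functional as F


def pool_video_features(frames: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """frames: (N, H*W, C) vision features for N frames."""
    n = frames.shape[0]
    x = frames.view(n, height, width, -1).permute(0, 3, 1, 2).contiguous()  # N, C, H, W
    scaled = [math.ceil(height / 2), math.ceil(width / 2)]
    x = F.interpolate(x, size=scaled, mode="bilinear")
    return x.flatten(2).transpose(1, 2).contiguous()  # N, ceil(H/2)*ceil(W/2), C

# A 27x27 grid (729 tokens per frame) becomes a 14x14 grid (196 tokens per frame),
# matching the halved feature length computed in pad_input_ids for the video case.
```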
@@ -233,21 +339,22 @@ class LlavaLlamaForCausalLM(nn.Module):
                     continue
 
                 start_idx = extend_start_loc_cpu[i]
-                pad_len, pad_dim = image_features[pt].shape  # 576, 4096
+                pad_dim = image_features[pt].shape[-1]  # 576, 4096
                 dim = input_embeds.shape[1]
                 assert (
                     pad_dim == dim
                 ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                 # Fill in the placeholder for the image
                 try:
-                    input_embeds[
-                        start_idx
-                        + image_offsets[i] : start_idx
-                        + image_offsets[i]
-                        + pad_len
-                    ] = image_features[pt]
+                    for j, image_off in enumerate(image_offsets[i]):
+                        # print("actual image_features length: ", image_features[pt][j].shape[0])
+                        pad_len = image_features[pt][j].shape[0]
+                        input_embeds[
+                            start_idx + image_off : start_idx + image_off + pad_len
+                        ] = image_features[pt][j]
                 except RuntimeError as e:
                     print(f"RuntimeError in llava image encoding: {e}")
+                    print(image_features[pt].shape)
                     print(input_embeds.shape)
                     print(start_idx, image_offsets[i])
                 pt += 1
@@ -262,9 +369,16 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
-        self.vision_tower = CLIPVisionModel.from_pretrained(
-            vision_path, torch_dtype=torch.float16
-        ).cuda()
+        if "clip" in vision_path:
+            self.vision_tower = CLIPVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+        elif "siglip" in vision_path:
+            self.vision_tower = SiglipVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+            # Siglip needs all feature tokens
+            self.config.mm_vision_select_feature = "full"
         self.vision_tower.eval()
 
         self.vision_feature_layer = self.config.mm_vision_select_layer
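The new dispatch chooses the vision tower from the `mm_vision_tower` path and, for SigLIP, forces `mm_vision_select_feature = "full"`. A hedged sketch of why (based on the usual LLaVA feature-selection convention that this file follows; `select_vision_features` is a hypothetical name):

```python
def select_vision_features(hidden_states, strategy: str):
    # hidden_states: (batch, tokens, dim) taken from the vision tower's selected layer
    if strategy == "patch":
        return hidden_states[:, 1:]  # drop the leading CLS token (CLIP-style)
    elif strategy == "full":
        return hidden_states  # keep every token; SigLIP emits no CLS token
    raise ValueError(f"unexpected select feature: {strategy}")
```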
@@ -276,8 +390,11 @@ class LlavaLlamaForCausalLM(nn.Module):
         self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
         self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
 
-        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
-        if self.vision_feature_select_strategy == "patch":
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
             pass
         elif self.vision_feature_select_strategy == "cls_patch":
             self.image_feature_len += 1
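Switching to integer division and admitting the "full" strategy keeps the per-image token count exact for both tower families. Worked numbers for two common checkpoint configurations (assumed, for illustration):

```python
# Worked examples of the formula above:
assert int((336 // 14) ** 2) == 576  # CLIP ViT-L/14 at 336px
assert int((384 // 14) ** 2) == 729  # SigLIP so400m/14 at 384px
# "cls_patch" would add one extra token for the CLS embedding; "patch" and "full" do not.
```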
sglang/srt/models/qwen2.py
@@ -38,6 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -275,6 +276,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     @torch.no_grad()
     def forward(
@@ -283,11 +285,15 @@ class Qwen2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, input_metadata
+            )
+        else:
+            return self.pooler(hidden_states, input_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
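With `get_embedding=True`, the Qwen2 forward pass now bypasses the logits processor and returns what `Pooler(pooling_type=PoolingType.LAST, normalize=True)` produces. A minimal sketch of that pooling step (assumed packed hidden-state layout and a hypothetical helper; not the sglang `Pooler` class itself):

```python
import torch
import torch.nn.functional as F


def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """hidden_states: (total_tokens, hidden), packed across the batch;
    seq_lens: (batch,) number of tokens per sequence."""
    last_idx = torch.cumsum(seq_lens, dim=0) - 1  # index of each sequence's last token
    embeddings = hidden_states[last_idx]          # (batch, hidden)
    return F.normalize(embeddings, p=2, dim=-1)   # unit-norm embeddings

# This is why the pooler also receives input_metadata: it needs the sequence
# boundaries to locate each request's final token inside the packed tensor.
```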