sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama_classification.py
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.layers.logits_processor import LogitProcessorOutput
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
  from sglang.srt.models.llama2 import LlamaModel

@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
  (input_metadata.batch_size, self.config.classification_out_size)
  ).to(input_ids.device)

- return LogitProcessorOutput(
+ return LogitsProcessorOutput(
  next_token_logits=scores,
  next_token_logprobs=scores,
  normalized_prompt_logprobs=scores,
sglang/srt/models/llama_embedding.py
@@ -29,7 +29,11 @@ class LlamaEmbeddingModel(nn.Module):
  positions: torch.Tensor,
  input_metadata: InputMetadata,
  input_embeds: torch.Tensor = None,
+ get_embedding: bool = True,
  ) -> EmbeddingPoolerOutput:
+ assert (
+ get_embedding
+ ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
  return self.pooler(hidden_states, input_metadata)

sglang/srt/models/llava.py
@@ -15,6 +15,8 @@ limitations under the License.

  """Inference-only LLaVa model compatible with HuggingFace weights."""

+ import math
+ import re
  from typing import Iterable, List, Optional, Tuple

  import numpy as np
@@ -26,6 +28,8 @@ from transformers import (
  LlavaConfig,
  MistralConfig,
  Qwen2Config,
+ SiglipVisionConfig,
+ SiglipVisionModel,
  )
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
  from vllm.config import CacheConfig
@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
  )

  def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
- new_image_feature_len = self.image_feature_len
- # now only support spatial_unpad + anyres
- if self.mm_patch_merge_type.startswith("spatial"):
+
+ # hardcode for spatial_unpad + anyres
+ image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+ offset_list = []
+ for image_s in image_size:
+ if len(image_size) > 16:
+ # 2x2 pooling with stride 2
+ new_image_feature_len = (
+ math.ceil(self.image_size / self.patch_size / 2) ** 2
+ )
+ else:
+ new_image_feature_len = self.image_feature_len # multiimage
+
  height = width = self.num_patches_per_side
- if pt_shape[0] > 1:
- if self.image_aspect_ratio == "anyres":
- num_patch_width, num_patch_height = get_anyres_image_grid_shape(
- image_size,
- self.image_grid_pinpoints,
- self.vision_tower.config.image_size,
+ if "anyres" in image_aspect_ratio:
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+ image_s,
+ self.image_grid_pinpoints,
+ self.vision_tower.config.image_size,
+ )
+ h = num_patch_height * height
+ w = num_patch_width * width
+ new_h, new_w = unpad_image_shape(h, w, image_s)
+
+ if "anyres_max" in self.config.image_aspect_ratio:
+ matched_anyres_max_num_patches = re.match(
+ r"anyres_max_(\d+)", self.config.image_aspect_ratio
+ )
+ if matched_anyres_max_num_patches:
+ max_num_patches = int(matched_anyres_max_num_patches.group(1))
+ # times = math.sqrt(h * w / (max_num_patches * unit**2))
+ times = math.sqrt(
+ new_h * new_w / (max_num_patches * self.image_feature_len)
  )
- if "unpad" in self.mm_patch_merge_type:
- h = num_patch_height * height
- w = num_patch_width * width
- new_h, new_w = unpad_image_shape(h, w, image_size)
- new_image_feature_len += new_h * (new_w + 1)
-
- pad_ids = pad_value * (
- (new_image_feature_len + len(pad_value)) // len(pad_value)
- )
- offset = input_ids.index(self.config.image_token_index)
- # old_len + pad_len - 1, because we need to remove image_token_id
- new_input_ids = (
- input_ids[:offset]
- + pad_ids[:new_image_feature_len]
- + input_ids[offset + 1 :]
- )
- return new_input_ids, offset
+ if times > 1.1:
+ new_h = int(new_h // times)
+ new_w = int(new_w // times)
+ new_image_feature_len += new_h * (new_w + 1)
+
+ pad_ids = pad_value * (
+ (new_image_feature_len + len(pad_value)) // len(pad_value)
+ )
+ # print("calculated new_image_feature_len: ", new_image_feature_len)
+ try:
+ offset = input_ids.index(self.config.image_token_index)
+ except ValueError:
+ offset = 0
+ # old_len + pad_len - 1, because we need to remove image_token_id
+ input_ids = (
+ input_ids[:offset]
+ + pad_ids[:new_image_feature_len]
+ + input_ids[offset + 1 :]
+ )
+ offset_list.append(offset)
+ return input_ids, offset_list

  def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
  image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
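Note on the padding arithmetic above: the number of pad tokens reserved for an anyres image is the base tile length plus one extra slot per unpadded row (the image-newline token appended later in the unpad branch). A minimal standalone sketch with illustrative numbers, not values taken from the package:

import math

num_patches_per_side = 24                    # e.g. a 336px tower with 14px patches
base_feature_len = num_patches_per_side**2   # 576 base tokens for the resized tile

# Assume unpad_image_shape() reduced a 2x2 tile grid to 44 x 40 patches.
new_h, new_w = 44, 40

# Mirrors `new_image_feature_len += new_h * (new_w + 1)` above.
new_image_feature_len = base_feature_len + new_h * (new_w + 1)
print(new_image_feature_len)  # 576 + 44 * 41 = 2380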
@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):

  # Embed text input
  input_embeds = self.language_model.model.embed_tokens(input_ids)
-
  # Embed vision input
  need_vision = (
  (positions[input_metadata.extend_start_loc] < self.image_feature_len)
@@ -163,27 +193,73 @@

  if self.mm_patch_merge_type.startswith("spatial"):
  new_image_features = []
+ height = width = self.num_patches_per_side
  for image_idx, image_feature in enumerate(image_features):
- if image_feature.shape[0] > 1:
+ if len(image_sizes[image_idx]) == 1:
+ image_aspect_ratio = (
+ self.config.image_aspect_ratio
+ ) # single image
+ else:
+ image_aspect_ratio = "pad" # multi image
+ # image_aspect_ratio = (
+ # "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
+ # )
+ if (
+ image_feature.shape[0] > 1
+ and "anyres" in image_aspect_ratio
+ ):
  base_image_feature = image_feature[0]
  image_feature = image_feature[1:]
- height = width = self.num_patches_per_side
  assert height * width == base_image_feature.shape[0]
- if self.image_aspect_ratio == "anyres":
- (
- num_patch_width,
- num_patch_height,
- ) = get_anyres_image_grid_shape(
- image_sizes[image_idx],
- self.image_grid_pinpoints,
- self.vision_tower.config.image_size,
+
+ if "anyres_max" in image_aspect_ratio:
+ matched_anyres_max_num_patches = re.match(
+ r"anyres_max_(\d+)", image_aspect_ratio
  )
+ if matched_anyres_max_num_patches:
+ max_num_patches = int(
+ matched_anyres_max_num_patches.group(1)
+ )
+
+ if (
+ image_aspect_ratio == "anyres"
+ or "anyres_max" in image_aspect_ratio
+ ):
+ vision_tower_image_size = self.image_size
+ try:
+ num_patch_width, num_patch_height = (
+ get_anyres_image_grid_shape(
+ image_sizes[image_idx][0],
+ self.config.image_grid_pinpoints,
+ vision_tower_image_size,
+ )
+ )
+ except Exception as e:
+ print(f"Error: {e}")
+ num_patch_width, num_patch_height = 2, 2
  image_feature = image_feature.view(
  num_patch_height, num_patch_width, height, width, -1
  )
  else:
- raise NotImplementedError()
+ image_feature = image_feature.view(
+ 2, 2, height, width, -1
+ )
+
+ # (
+ # num_patch_width,
+ # num_patch_height,
+ # ) = get_anyres_image_grid_shape(
+ # image_sizes[image_idx][0],
+ # self.image_grid_pinpoints,
+ # self.vision_tower.config.image_size,
+ # )
+
+ # image_feature = image_feature.view(
+ # num_patch_height, num_patch_width, height, width, -1
+ # )
+
  if "unpad" in self.mm_patch_merge_type:
+ unit = image_feature.shape[2]
  image_feature = image_feature.permute(
  4, 0, 2, 1, 3
  ).contiguous()
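The `anyres_max_<N>` branches above only need the numeric patch budget from the aspect-ratio string; a minimal sketch of that parsing step (the example value is assumed, not read from a real config):

import re

image_aspect_ratio = "anyres_max_9"  # assumed example value

matched = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
max_num_patches = int(matched.group(1)) if matched else None
print(max_num_patches)  # 9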
@@ -191,8 +267,23 @@ class LlavaLlamaForCausalLM(nn.Module):
  2, 3
  )
  image_feature = unpad_image(
- image_feature, image_sizes[image_idx]
+ image_feature, image_sizes[image_idx][0]
  )
+ if (
+ "anyres_max" in image_aspect_ratio
+ and matched_anyres_max_num_patches
+ ):
+ c, h, w = image_feature.shape
+ times = math.sqrt(
+ h * w / (max_num_patches * unit**2)
+ )
+ if times > 1.1:
+ image_feature = image_feature[None]
+ image_feature = nn.functional.interpolate(
+ image_feature,
+ [int(h // times), int(w // times)],
+ mode="bilinear",
+ )[0]
  image_feature = torch.cat(
  (
  image_feature,
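For reference, a self-contained sketch of the downscaling rule applied above: when the unpadded feature map exceeds the `anyres_max` patch budget, it is shrunk by the square root of the overshoot using bilinear interpolation. All shapes below are illustrative:

import math
import torch
import torch.nn.functional as F

max_num_patches, unit = 9, 24            # assumed budget and per-tile side length
image_feature = torch.randn(64, 96, 72)  # (C, H, W), illustrative

c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
    image_feature = F.interpolate(
        image_feature[None],                 # add a batch dim: (1, C, H, W)
        [int(h // times), int(w // times)],
        mode="bilinear",
    )[0]                                     # drop the batch dim again
print(image_feature.shape)                   # torch.Size([64, 83, 62])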
@@ -213,16 +304,31 @@ class LlavaLlamaForCausalLM(nn.Module):
  image_feature = torch.cat(
  (base_image_feature, image_feature), dim=0
  )
+ image_feature = image_feature.unsqueeze(0)
  else:
- image_feature = image_feature[0]
- if "unpad" in self.mm_patch_merge_type:
- image_feature = torch.cat(
- (
- image_feature,
- self.language_model.model.image_newline[None],
- ),
- dim=0,
+ if image_feature.shape[0] > 16: # video
+ # 2x2 pooling
+ num_of_frames = image_feature.shape[0]
+ image_feature = image_feature.view(
+ num_of_frames, height, width, -1
  )
+ image_feature = image_feature.permute(
+ 0, 3, 1, 2
+ ).contiguous() # N, C, H, W
+ height, weight = image_feature.shape[2:]
+ scaled_shape = [
+ math.ceil(height / 2),
+ math.ceil(weight / 2),
+ ]
+ image_feature = nn.functional.interpolate(
+ image_feature, size=scaled_shape, mode="bilinear"
+ )
+ image_feature = (
+ image_feature.flatten(2)
+ .transpose(1, 2)
+ .contiguous()
+ ) # N, C, H*W
+
  new_image_features.append(image_feature)
  image_features = new_image_features

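The video branch above reduces every frame's patch grid by 2x2 before flattening it back into a token sequence. A minimal sketch with assumed shapes (small hidden size for brevity):

import math
import torch
import torch.nn.functional as F

num_frames, height, width, dim = 32, 24, 24, 64   # assumed clip and grid sizes
image_feature = torch.randn(num_frames, height, width, dim)

image_feature = image_feature.permute(0, 3, 1, 2).contiguous()          # N, C, H, W
scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]            # 12 x 12
image_feature = F.interpolate(image_feature, size=scaled_shape, mode="bilinear")
image_feature = image_feature.flatten(2).transpose(1, 2).contiguous()   # N, H*W, C
print(image_feature.shape)  # torch.Size([32, 144, 64])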
@@ -233,21 +339,22 @@ class LlavaLlamaForCausalLM(nn.Module):
  continue

  start_idx = extend_start_loc_cpu[i]
- pad_len, pad_dim = image_features[pt].shape # 576, 4096
+ pad_dim = image_features[pt].shape[-1] # 576, 4096
  dim = input_embeds.shape[1]
  assert (
  pad_dim == dim
  ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
  # Fill in the placeholder for the image
  try:
- input_embeds[
- start_idx
- + image_offsets[i] : start_idx
- + image_offsets[i]
- + pad_len
- ] = image_features[pt]
+ for j, image_off in enumerate(image_offsets[i]):
+ # print("actual image_features length: ", image_features[pt][j].shape[0])
+ pad_len = image_features[pt][j].shape[0]
+ input_embeds[
+ start_idx + image_off : start_idx + image_off + pad_len
+ ] = image_features[pt][j]
  except RuntimeError as e:
  print(f"RuntimeError in llava image encoding: {e}")
+ print(image_features[pt].shape)
  print(input_embeds.shape)
  print(start_idx, image_offsets[i])
  pt += 1
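A reduced sketch of the placeholder fill above: with per-image offsets now returned as a list by `pad_input_ids`, each image's feature block is copied into `input_embeds` at its own offset. Shapes, offsets, and names below are made up for illustration:

import torch

dim = 64
input_embeds = torch.zeros(1000, dim)                            # flattened prompt embeddings
image_features = [torch.randn(576, dim), torch.randn(144, dim)]  # one block per image
image_offsets = [10, 700]                                        # offsets from pad_input_ids
start_idx = 0                                                    # extend_start_loc of this request

for j, image_off in enumerate(image_offsets):
    pad_len = image_features[j].shape[0]
    input_embeds[start_idx + image_off : start_idx + image_off + pad_len] = image_features[j]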
@@ -262,9 +369,16 @@ class LlavaLlamaForCausalLM(nn.Module):
  # load clip vision model by cfg['mm_vision_tower']:
  # huggingface_name or path_of_clip_relative_to_llava_model_dir
  vision_path = self.config.mm_vision_tower
- self.vision_tower = CLIPVisionModel.from_pretrained(
- vision_path, torch_dtype=torch.float16
- ).cuda()
+ if "clip" in vision_path:
+ self.vision_tower = CLIPVisionModel.from_pretrained(
+ vision_path, torch_dtype=torch.float16
+ ).cuda()
+ elif "siglip" in vision_path:
+ self.vision_tower = SiglipVisionModel.from_pretrained(
+ vision_path, torch_dtype=torch.float16
+ ).cuda()
+ # Siglip needs all feature tokens
+ self.config.mm_vision_select_feature = "full"
  self.vision_tower.eval()

  self.vision_feature_layer = self.config.mm_vision_select_layer
@@ -276,8 +390,11 @@ class LlavaLlamaForCausalLM(nn.Module):
  self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
  self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

- self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
- if self.vision_feature_select_strategy == "patch":
+ self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+ if (
+ self.vision_feature_select_strategy == "patch"
+ or self.vision_feature_select_strategy == "full"
+ ):
  pass
  elif self.vision_feature_select_strategy == "cls_patch":
  self.image_feature_len += 1
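The `/` to `//` change above keeps the patch count an integer before squaring. For example, with a CLIP ViT-L/14-336 vision tower (image size 336, patch size 14) this yields the 576 base feature length referenced in the comments elsewhere in this file:

image_size, patch_size = 336, 14                    # CLIP ViT-L/14-336
image_feature_len = int((image_size // patch_size) ** 2)
print(image_feature_len)                            # 24 ** 2 = 576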
sglang/srt/models/minicpm.py
@@ -22,8 +22,6 @@ import torch
  from torch import nn
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
  MergedColumnParallelLinear,
  QKVParallelLinear,
@@ -37,8 +35,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
  self.scale_width = self.config.hidden_size / self.config.dim_model_base

  self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -314,9 +316,11 @@
  lm_head_weight = self.model.embed_tokens.weight
  else:
  lm_head_weight = self.lm_head.weight
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, lm_head_weight, input_metadata
  )
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
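With the model-side `Sampler` added above, `forward` now returns a `(sample_output, logits_output)` pair instead of the logits-processor output alone. A toy module sketching the shape of that contract (it stands in for the sglang classes and uses greedy selection purely for illustration):

import torch


class TinyLM(torch.nn.Module):
    """Illustrative stand-in: compute logits, sample, return both."""

    def __init__(self, hidden_size=32, vocab_size=128):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states):
        logits_output = self.lm_head(hidden_states)           # LogitsProcessor stand-in
        sample_output = torch.argmax(logits_output, dim=-1)   # Sampler stand-in (greedy)
        return sample_output, logits_output


sample_output, logits_output = TinyLM()(torch.randn(4, 32))
print(sample_output.shape, logits_output.shape)  # torch.Size([4]) torch.Size([4, 128])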