sglang 0.4.3__py3-none-any.whl → 0.4.3.post1__py3-none-any.whl

@@ -115,6 +115,9 @@ class Engine:
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        # The image input. It can be a file name, a url, or base64 encoded string.
+        # See also python/sglang/srt/utils.py:load_image.
+        image_data: Optional[Union[List[str], str]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -126,14 +129,20 @@ class Engine:
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
         Please refer to `GenerateReqInput` for the documentation.
         """
+        modalities_list = []
+        if image_data is not None:
+            modalities_list.append("image")
+
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
             sampling_params=sampling_params,
+            image_data=image_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
             lora_path=lora_path,
+            modalities=modalities_list,
             custom_logit_processor=custom_logit_processor,
             stream=stream,
         )
@@ -162,6 +171,9 @@ class Engine:
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        # The image input. It can be a file name, a url, or base64 encoded string.
+        # See also python/sglang/srt/utils.py:load_image.
+        image_data: Optional[Union[List[str], str]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -177,6 +189,7 @@ class Engine:
             text=prompt,
             input_ids=input_ids,
             sampling_params=sampling_params,
+            image_data=image_data,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
@@ -425,7 +438,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
-        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
+        load_chat_template_for_openai_api(
+            tokenizer_manager, server_args.chat_template, server_args.model_path
+        )

     # Wait for the model to finish loading
     scheduler_infos = []
@@ -30,16 +30,15 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

-from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig
+from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2_5_VLConfig

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-    Qwen2VLConfig.model_type: Qwen2VLConfig,
+    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
 }

-
 for name, cls in _CONFIG_REGISTRY.items():
     with contextlib.suppress(ValueError):
         AutoConfig.register(name, cls)
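
Each entry in _CONFIG_REGISTRY maps a model_type string to an in-tree config class and registers it with AutoConfig, so checkpoints of that type resolve even when the installed transformers release does not ship the class; the ValueError suppression covers the case where transformers already knows the type. The same pattern in isolation, with a hypothetical toy config that is not part of sglang:

    import contextlib

    from transformers import AutoConfig, PretrainedConfig

    class ToyVLConfig(PretrainedConfig):
        model_type = "toy_vl"  # hypothetical model_type, for illustration only

    # Same idiom as the registry above: ignore the ValueError raised when the
    # model_type is already registered by the installed transformers version.
    with contextlib.suppress(ValueError):
        AutoConfig.register(ToyVLConfig.model_type, ToyVLConfig)

    cfg = AutoConfig.for_model("toy_vl")
    print(type(cfg).__name__)  # ToyVLConfig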
@@ -1,6 +1,7 @@
 # TODO: also move pad_input_ids into this module
 import asyncio
 import concurrent.futures
+import dataclasses
 import logging
 import multiprocessing as mp
 import os
@@ -8,6 +9,7 @@ from abc import ABC, abstractmethod
 from typing import List, Optional, Union

 import numpy as np
+import PIL
 import transformers
 from decord import VideoReader, cpu
 from PIL import Image
@@ -34,11 +36,22 @@ def init_global_processor(server_args: ServerArgs):
     )


+@dataclasses.dataclass
+class BaseImageProcessorOutput:
+    image_hashes: list[int]
+    image_sizes: list[int]
+    all_frames: [PIL.Image]
+    # input_text, with each frame of video/image represented with a image_token
+    input_text: str
+
+
 class BaseImageProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
         self.server_args = server_args
+        # FIXME: not accurate, model and image specific
+        self.NUM_TOKEN_PER_FRAME = 330

         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
@@ -48,9 +61,128 @@ class BaseImageProcessor(ABC):
         )

     @abstractmethod
-    async def process_images_async(self, image_data, input_text, **kwargs):
+    async def process_images_async(
+        self, image_data, input_text, max_req_input_len, **kwargs
+    ):
         pass

+    def get_estimated_frames_list(self, image_data):
+        """
+        estimate the total frame count from all visual input
+        """
+        # Before processing inputs
+        estimated_frames_list = []
+        for image in image_data:
+            if isinstance(image, str) and image.startswith("video:"):
+                path = image[len("video:") :]
+                # Estimate frames for the video
+                vr = VideoReader(path, ctx=cpu(0))
+                num_frames = len(vr)
+            else:
+                # For images, each contributes one frame
+                num_frames = 1
+            estimated_frames_list.append(num_frames)
+
+        return estimated_frames_list
+
+    def encode_video(self, video_path, frame_count_limit=None):
+        if not os.path.exists(video_path):
+            logger.error(f"Video {video_path} does not exist")
+            return []
+
+        if frame_count_limit == 0:
+            return []
+
+        def uniform_sample(l, n):
+            gap = len(l) / n
+            idxs = [int(i * gap + gap / 2) for i in range(n)]
+            return [l[i] for i in idxs]
+
+        vr = VideoReader(video_path, ctx=cpu(0))
+        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+        frame_idx = [i for i in range(0, len(vr), sample_fps)]
+        if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
+            frame_idx = uniform_sample(frame_idx, frame_count_limit)
+        frames = vr.get_batch(frame_idx).asnumpy()
+        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+        return frames
+
+    def load_images(
+        self,
+        max_req_input_len: int,
+        input_ids: list,
+        image_data,
+        image_token: str,
+    ) -> BaseImageProcessorOutput:
+        """
+        Each frame of video/image will be replaced by a single image token
+        """
+        image_hashes, image_sizes = [], []
+        all_frames = []
+        new_text_parts = []
+
+        if isinstance(input_ids, list):
+            assert len(input_ids) and isinstance(input_ids[0], int)
+            input_text = self._processor.tokenizer.decode(input_ids)
+        else:
+            input_text = input_ids
+
+        text_parts = input_text.split(image_token)
+
+        # roughly calculate the max number of frames under the max_req_input_len limit
+        def calculate_max_num_frames() -> int:
+            ret = (max_req_input_len - len(input_ids)) // self.NUM_TOKEN_PER_FRAME
+            return min(ret, 100)
+
+        MAX_NUM_FRAMES = calculate_max_num_frames()
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+
+        # Process each input with allocated frames
+        for image_index, (image, estimated_frames) in enumerate(
+            zip(image_data, estimated_frames_list)
+        ):
+            if len(all_frames) >= MAX_NUM_FRAMES:
+                frames_to_process = 0
+            else:
+                frames_to_process = max(1, int(estimated_frames * scaling_factor))
+
+            if frames_to_process == 0:
+                frames = []
+            else:
+                try:
+                    if isinstance(image, str) and image.startswith("video:"):
+                        path = image[len("video:") :]
+                        frames = self.encode_video(
+                            path, frame_count_limit=frames_to_process
+                        )
+                    else:
+                        raw_image, _size = load_image(image)
+                        frames = [raw_image]
+                    if len(frames) == 0:
+                        continue
+                except FileNotFoundError as e:
+                    print(e)
+                    return None
+            image_sizes += frames[0].size * len(frames)
+            image_hashes += [hash(image)] * len(frames)
+            all_frames += frames
+
+            new_text_parts.append(text_parts[image_index])
+            if frames_to_process != 0:
+                new_text_parts.append(image_token * len(frames))
+                assert frames_to_process == len(frames)
+
+        new_text_parts.append(text_parts[-1])

+        input_text = "".join(new_text_parts)
+        return BaseImageProcessorOutput(
+            image_hashes, image_sizes, all_frames, input_text
+        )
+

 class DummyImageProcessor(BaseImageProcessor):
     def __init__(self):
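
The new load_images helper gives every processor the same frame budget: MAX_NUM_FRAMES is the leftover request length divided by NUM_TOKEN_PER_FRAME (capped at 100), and scaling_factor shrinks each video's frame count proportionally while a plain image keeps at least one frame. A standalone sketch of the arithmetic with made-up numbers:

    # Illustrative only: mirrors the budgeting logic in BaseImageProcessor.load_images.
    NUM_TOKEN_PER_FRAME = 330   # per-frame token estimate (model specific, see FIXME above)
    max_req_input_len = 4096    # hypothetical request limit
    prompt_len = 796            # hypothetical prompt length in tokens

    max_num_frames = min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)  # 10
    estimated_frames_list = [1, 48]  # one image plus a 48-frame video
    total_frame_count = sum(estimated_frames_list)                 # 49
    scaling_factor = min(1.0, max_num_frames / total_frame_count)  # ~0.20

    frames_per_input = [max(1, int(n * scaling_factor)) for n in estimated_frames_list]
    print(frames_per_input)  # [1, 9] -> at most 10 frames end up embedded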
@@ -248,9 +380,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             text=input_text, images=images, return_tensors="pt"
         )
         return {
-            "input_ids": result["input_ids"],
-            "pixel_values": result["pixel_values"],
-            "tgt_sizes": result["tgt_sizes"],
+            "input_ids": result.input_ids,
+            "pixel_values": result.pixel_values,
+            "tgt_sizes": result.tgt_sizes,
         }

     async def _process_images(self, images, input_text):
@@ -278,124 +410,20 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     ):
         if not image_data:
             return None
-
         if not isinstance(image_data, list):
            image_data = [image_data]

-        image_hashes, image_sizes = [], []
-        all_frames = []
-
-        # roughly calculate the max number of frames under the max_req_input_len limit
-        def calculate_max_num_frames() -> int:
-            # Model-specific
-            NUM_TOKEN_PER_FRAME = 330
-
-            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
-            return min(ret, 100)
-
-        MAX_NUM_FRAMES = calculate_max_num_frames()
-
-        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
-
-        def get_estimated_frames_list():
-            """
-            estimate the total frame count from all visual input
-            """
-            # Before processing inputs
-            estimated_frames_list = []
-            for image in image_data:
-                if isinstance(image, str) and image.startswith("video:"):
-                    path = image[len("video:") :]
-                    # Estimate frames for the video
-                    vr = VideoReader(path, ctx=cpu(0))
-                    num_frames = len(vr)
-                else:
-                    # For images, each contributes one frame
-                    num_frames = 1
-                estimated_frames_list.append(num_frames)
-
-            return estimated_frames_list
-
-        estimated_frames_list = get_estimated_frames_list()
-        total_frame_count = sum(estimated_frames_list)
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
-
-        def encode_video(video_path, frame_count_limit=None):
-            if not os.path.exists(video_path):
-                logger.error(f"Video {video_path} does not exist")
-                return []
-
-            if frame_count_limit == 0:
-                return []
-
-            def uniform_sample(l, n):
-                gap = len(l) / n
-                idxs = [int(i * gap + gap / 2) for i in range(n)]
-                return [l[i] for i in idxs]
-
-            vr = VideoReader(video_path, ctx=cpu(0))
-            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-            frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
-                frame_idx = uniform_sample(frame_idx, frame_count_limit)
-            frames = vr.get_batch(frame_idx).asnumpy()
-            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-            return frames
-
-        if isinstance(input_ids, list):
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
-        else:
-            input_text = input_ids
-        # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(self.IMAGE_TOKEN)
-        new_text_parts = []
-
-        # Process each input with allocated frames
-        for image_index, (image, estimated_frames) in enumerate(
-            zip(image_data, estimated_frames_list)
-        ):
-            if len(all_frames) >= MAX_NUM_FRAMES:
-                frames_to_process = 0
-            else:
-                frames_to_process = max(1, int(estimated_frames * scaling_factor))
-
-            if frames_to_process == 0:
-                frames = []
-            else:
-                try:
-                    if isinstance(image, str) and image.startswith("video:"):
-                        path = image[len("video:") :]
-                        frames = encode_video(path, frame_count_limit=frames_to_process)
-                    else:
-                        raw_image, _size = load_image(image)
-                        frames = [raw_image]
-                    if len(frames) == 0:
-                        continue
-                except FileNotFoundError as e:
-                    print(e)
-                    return None
-            image_sizes += frames[0].size * len(frames)
-            image_hashes += [hash(image)] * len(frames)
-            all_frames += frames
-
-            assert frames_to_process == len(frames)
-
-            new_text_parts.append(text_parts[image_index])
-
-            if frames_to_process != 0:
-                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
-
-        new_text_parts.append(text_parts[-1])
-
-        input_text = "".join(new_text_parts)
+        base_output = self.load_images(
+            max_req_input_len, input_ids, image_data, self.IMAGE_TOKEN
+        )
+        if base_output is None:
+            return None

-        if len(all_frames) == 0:
+        if len(base_output.all_frames) == 0:
             return None
-        res = await self._process_images(images=all_frames, input_text=input_text)
-        pixel_values = res["pixel_values"]
-        tgt_sizes = res["tgt_sizes"]
-        input_ids = res["input_ids"]
+        res = await self._process_images(
+            images=base_output.all_frames, input_text=base_output.input_text
+        )

         # Collect special token ids
         tokenizer = self._processor.tokenizer
@@ -405,10 +433,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         slice_start_id = [tokenizer.slice_start_id]
         slice_end_id = [tokenizer.slice_end_id]
         return {
-            "input_ids": input_ids.flatten().tolist(),
-            "pixel_values": pixel_values,
-            "tgt_sizes": tgt_sizes,
-            "image_hashes": image_hashes,
+            "input_ids": res["input_ids"].flatten().tolist(),
+            "pixel_values": res["pixel_values"],
+            "tgt_sizes": res["tgt_sizes"],
+            "image_hashes": base_output.image_hashes,
             "modalities": request_obj.modalities or ["image"],
             "im_start_id": im_start_id,
             "im_end_id": im_end_id,
@@ -536,13 +564,80 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         }


+class Qwen2_5VLImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
+        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
+        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
+        self.NUM_TOKEN_PER_FRAME = 770
+
+    @staticmethod
+    def _process_images_task(images, input_text):
+        result = global_processor.__call__(
+            text=input_text, images=images, return_tensors="pt"
+        )
+        return {
+            "input_ids": result.input_ids,
+            "pixel_values": result.pixel_values,
+            "image_grid_thws": result.image_grid_thw,
+        }
+
+    async def _process_images(self, images, input_text) -> dict:
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                Qwen2_5VLImageProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            return self._process_images_task(images, input_text)
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_ids,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        image_token = self.IMAGE_TOKEN
+        base_output = self.load_images(
+            max_req_input_len, input_ids, image_data, image_token
+        )
+
+        ret = await self._process_images(base_output.all_frames, base_output.input_text)
+
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "pixel_values": ret["pixel_values"],
+            "image_hashes": base_output.image_hashes,
+            "modalities": request_obj.modalities or ["image"],
+            "image_grid_thws": ret["image_grid_thws"],
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+        }
+
+
 def get_image_processor(
     hf_config, server_args: ServerArgs, processor
 ) -> BaseImageProcessor:
     if "MllamaForConditionalGeneration" in hf_config.architectures:
         return MllamaImageProcessor(hf_config, server_args, processor)
     elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
-        return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+
+        return Qwen2VLImageProcessor(hf_config, server_args, processor)
+    elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
+        return Qwen2_5VLImageProcessor(hf_config, server_args, processor)
+
     elif "MiniCPMV" in hf_config.architectures:
         return MiniCPMVImageProcessor(hf_config, server_args, processor)
     else:
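
For Qwen2.5-VL, every image or video in image_data is expected to be referenced in the prompt by the processor's IMAGE_TOKEN shown above; load_images splits the prompt on that token and re-expands it to one token per selected frame. A hedged end-to-end sketch with the offline engine (checkpoint and file names are placeholders):

    import sglang as sgl

    llm = sgl.Engine(model_path="Qwen/Qwen2.5-VL-7B-Instruct")  # placeholder checkpoint

    # One IMAGE_TOKEN per visual input in image_data.
    image_token = "<|vision_start|><|image_pad|><|vision_end|>"

    out = llm.generate(
        prompt=f"{image_token}What is happening in this clip?",
        image_data="video:clip.mp4",  # the "video:" prefix routes the input through encode_video
        sampling_params={"max_new_tokens": 48},
    )
    print(out["text"])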
@@ -263,7 +263,10 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            if model_runner.server_args.attention_backend != "torch_native":
+            if (
+                model_runner.server_args.attention_backend != "torch_native"
+                and model_runner.server_args.speculative_algorithm != "NEXTN"
+            ):
                 ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens