sglang 0.4.2.post4__py3-none-any.whl → 0.4.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/global_config.py +2 -0
  2. sglang/lang/backend/openai.py +5 -0
  3. sglang/lang/chat_template.py +22 -7
  4. sglang/lang/ir.py +1 -0
  5. sglang/srt/configs/__init__.py +6 -3
  6. sglang/srt/configs/model_config.py +2 -0
  7. sglang/srt/configs/qwen2_5_vl_config.py +1003 -0
  8. sglang/srt/entrypoints/engine.py +18 -3
  9. sglang/srt/hf_transformers_utils.py +2 -3
  10. sglang/srt/layers/attention/flashinfer_backend.py +235 -110
  11. sglang/srt/layers/attention/triton_backend.py +358 -72
  12. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  13. sglang/srt/layers/linear.py +12 -5
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  23. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -2
  24. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  25. sglang/srt/layers/moe/topk.py +1 -1
  26. sglang/srt/layers/quantization/__init__.py +51 -5
  27. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  29. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  45. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  46. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  47. sglang/srt/managers/detokenizer_manager.py +1 -0
  48. sglang/srt/managers/image_processor.py +217 -122
  49. sglang/srt/managers/io_struct.py +4 -0
  50. sglang/srt/managers/schedule_batch.py +16 -3
  51. sglang/srt/managers/scheduler.py +29 -0
  52. sglang/srt/managers/tokenizer_manager.py +6 -0
  53. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  54. sglang/srt/model_executor/cuda_graph_runner.py +12 -1
  55. sglang/srt/model_executor/forward_batch_info.py +4 -1
  56. sglang/srt/model_executor/model_runner.py +12 -2
  57. sglang/srt/models/deepseek_nextn.py +295 -0
  58. sglang/srt/models/deepseek_v2.py +21 -8
  59. sglang/srt/models/llava.py +2 -1
  60. sglang/srt/models/qwen2_5_vl.py +722 -0
  61. sglang/srt/models/qwen2_vl.py +2 -1
  62. sglang/srt/openai_api/adapter.py +17 -3
  63. sglang/srt/server_args.py +26 -4
  64. sglang/srt/speculative/eagle_worker.py +35 -10
  65. sglang/srt/speculative/spec_info.py +11 -1
  66. sglang/srt/utils.py +7 -0
  67. sglang/utils.py +99 -19
  68. sglang/version.py +1 -1
  69. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/METADATA +5 -4
  70. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/RECORD +73 -55
  71. sglang/srt/configs/qwen2vl.py +0 -130
  72. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/image_processor.py
+++ b/sglang/srt/managers/image_processor.py
@@ -1,6 +1,7 @@
 # TODO: also move pad_input_ids into this module
 import asyncio
 import concurrent.futures
+import dataclasses
 import logging
 import multiprocessing as mp
 import os
@@ -8,6 +9,7 @@ from abc import ABC, abstractmethod
 from typing import List, Optional, Union
 
 import numpy as np
+import PIL
 import transformers
 from decord import VideoReader, cpu
 from PIL import Image
@@ -34,11 +36,22 @@ def init_global_processor(server_args: ServerArgs):
     )
 
 
+@dataclasses.dataclass
+class BaseImageProcessorOutput:
+    image_hashes: list[int]
+    image_sizes: list[int]
+    all_frames: [PIL.Image]
+    # input_text, with each frame of video/image represented with a image_token
+    input_text: str
+
+
 class BaseImageProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
         self.server_args = server_args
+        # FIXME: not accurate, model and image specific
+        self.NUM_TOKEN_PER_FRAME = 330
 
         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
@@ -48,9 +61,128 @@ class BaseImageProcessor(ABC):
         )
 
     @abstractmethod
-    async def process_images_async(self, image_data, input_text, **kwargs):
+    async def process_images_async(
+        self, image_data, input_text, max_req_input_len, **kwargs
+    ):
         pass
 
+    def get_estimated_frames_list(self, image_data):
+        """
+        estimate the total frame count from all visual input
+        """
+        # Before processing inputs
+        estimated_frames_list = []
+        for image in image_data:
+            if isinstance(image, str) and image.startswith("video:"):
+                path = image[len("video:") :]
+                # Estimate frames for the video
+                vr = VideoReader(path, ctx=cpu(0))
+                num_frames = len(vr)
+            else:
+                # For images, each contributes one frame
+                num_frames = 1
+            estimated_frames_list.append(num_frames)
+
+        return estimated_frames_list
+
+    def encode_video(self, video_path, frame_count_limit=None):
+        if not os.path.exists(video_path):
+            logger.error(f"Video {video_path} does not exist")
+            return []
+
+        if frame_count_limit == 0:
+            return []
+
+        def uniform_sample(l, n):
+            gap = len(l) / n
+            idxs = [int(i * gap + gap / 2) for i in range(n)]
+            return [l[i] for i in idxs]
+
+        vr = VideoReader(video_path, ctx=cpu(0))
+        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+        frame_idx = [i for i in range(0, len(vr), sample_fps)]
+        if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
+            frame_idx = uniform_sample(frame_idx, frame_count_limit)
+        frames = vr.get_batch(frame_idx).asnumpy()
+        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+        return frames
+
+    def load_images(
+        self,
+        max_req_input_len: int,
+        input_ids: list,
+        image_data,
+        image_token: str,
+    ) -> BaseImageProcessorOutput:
+        """
+        Each frame of video/image will be replaced by a single image token
+        """
+        image_hashes, image_sizes = [], []
+        all_frames = []
+        new_text_parts = []
+
+        if isinstance(input_ids, list):
+            assert len(input_ids) and isinstance(input_ids[0], int)
+            input_text = self._processor.tokenizer.decode(input_ids)
+        else:
+            input_text = input_ids
+
+        text_parts = input_text.split(image_token)
+
+        # roughly calculate the max number of frames under the max_req_input_len limit
+        def calculate_max_num_frames() -> int:
+            ret = (max_req_input_len - len(input_ids)) // self.NUM_TOKEN_PER_FRAME
+            return min(ret, 100)
+
+        MAX_NUM_FRAMES = calculate_max_num_frames()
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+
+        # Process each input with allocated frames
+        for image_index, (image, estimated_frames) in enumerate(
+            zip(image_data, estimated_frames_list)
+        ):
+            if len(all_frames) >= MAX_NUM_FRAMES:
+                frames_to_process = 0
+            else:
+                frames_to_process = max(1, int(estimated_frames * scaling_factor))
+
+            if frames_to_process == 0:
+                frames = []
+            else:
+                try:
+                    if isinstance(image, str) and image.startswith("video:"):
+                        path = image[len("video:") :]
+                        frames = self.encode_video(
+                            path, frame_count_limit=frames_to_process
+                        )
+                    else:
+                        raw_image, _size = load_image(image)
+                        frames = [raw_image]
+                    if len(frames) == 0:
+                        continue
+                except FileNotFoundError as e:
+                    print(e)
+                    return None
+                image_sizes += frames[0].size * len(frames)
+                image_hashes += [hash(image)] * len(frames)
+                all_frames += frames
+
+            new_text_parts.append(text_parts[image_index])
+            if frames_to_process != 0:
+                new_text_parts.append(image_token * len(frames))
+                assert frames_to_process == len(frames)
+
+        new_text_parts.append(text_parts[-1])
+
+        input_text = "".join(new_text_parts)
+        return BaseImageProcessorOutput(
+            image_hashes, image_sizes, all_frames, input_text
+        )
+
 
 class DummyImageProcessor(BaseImageProcessor):
     def __init__(self):
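
The new shared load_images path budgets frames across all visual inputs: a per-frame token cost (NUM_TOKEN_PER_FRAME, 330 for MiniCPMV above, 770 for Qwen2.5-VL below) caps the total frame count, and a scaling factor spreads that budget over the inputs. A standalone sketch of the arithmetic, with made-up numbers:

```python
# Standalone sketch of the frame-budget math in load_images (numbers are hypothetical).
max_req_input_len = 16384          # request token limit
prompt_len = 2048                  # len(input_ids)
NUM_TOKEN_PER_FRAME = 330          # MiniCPMV value from the diff above

# MAX_NUM_FRAMES: how many frames fit in the remaining budget, capped at 100
max_num_frames = min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)

estimated_frames_list = [90, 1]    # e.g. one 90-frame video and one still image
total_frame_count = sum(estimated_frames_list)

# fraction of each input's frames that can be embedded (capped at 1.0)
scaling_factor = min(1.0, max_num_frames / total_frame_count)

frames_to_process = [max(1, int(est * scaling_factor)) for est in estimated_frames_list]
print(max_num_frames, round(scaling_factor, 2), frames_to_process)
# -> 43 0.47 [42, 1]
```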
@@ -248,9 +380,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             text=input_text, images=images, return_tensors="pt"
         )
         return {
-            "input_ids": result["input_ids"],
-            "pixel_values": result["pixel_values"],
-            "tgt_sizes": result["tgt_sizes"],
+            "input_ids": result.input_ids,
+            "pixel_values": result.pixel_values,
+            "tgt_sizes": result.tgt_sizes,
         }
 
     async def _process_images(self, images, input_text):
@@ -278,124 +410,20 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     ):
         if not image_data:
             return None
-
         if not isinstance(image_data, list):
             image_data = [image_data]
 
-        image_hashes, image_sizes = [], []
-        all_frames = []
-
-        # roughly calculate the max number of frames under the max_req_input_len limit
-        def calculate_max_num_frames() -> int:
-            # Model-specific
-            NUM_TOKEN_PER_FRAME = 330
-
-            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
-            return min(ret, 100)
-
-        MAX_NUM_FRAMES = calculate_max_num_frames()
-
-        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
-
-        def get_estimated_frames_list():
-            """
-            estimate the total frame count from all visual input
-            """
-            # Before processing inputs
-            estimated_frames_list = []
-            for image in image_data:
-                if isinstance(image, str) and image.startswith("video:"):
-                    path = image[len("video:") :]
-                    # Estimate frames for the video
-                    vr = VideoReader(path, ctx=cpu(0))
-                    num_frames = len(vr)
-                else:
-                    # For images, each contributes one frame
-                    num_frames = 1
-                estimated_frames_list.append(num_frames)
-
-            return estimated_frames_list
-
-        estimated_frames_list = get_estimated_frames_list()
-        total_frame_count = sum(estimated_frames_list)
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
-
-        def encode_video(video_path, frame_count_limit=None):
-            if not os.path.exists(video_path):
-                logger.error(f"Video {video_path} does not exist")
-                return []
-
-            if frame_count_limit == 0:
-                return []
-
-            def uniform_sample(l, n):
-                gap = len(l) / n
-                idxs = [int(i * gap + gap / 2) for i in range(n)]
-                return [l[i] for i in idxs]
-
-            vr = VideoReader(video_path, ctx=cpu(0))
-            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-            frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
-                frame_idx = uniform_sample(frame_idx, frame_count_limit)
-            frames = vr.get_batch(frame_idx).asnumpy()
-            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-            return frames
-
-        if isinstance(input_ids, list):
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
-        else:
-            input_text = input_ids
-        # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(self.IMAGE_TOKEN)
-        new_text_parts = []
-
-        # Process each input with allocated frames
-        for image_index, (image, estimated_frames) in enumerate(
-            zip(image_data, estimated_frames_list)
-        ):
-            if len(all_frames) >= MAX_NUM_FRAMES:
-                frames_to_process = 0
-            else:
-                frames_to_process = max(1, int(estimated_frames * scaling_factor))
-
-            if frames_to_process == 0:
-                frames = []
-            else:
-                try:
-                    if isinstance(image, str) and image.startswith("video:"):
-                        path = image[len("video:") :]
-                        frames = encode_video(path, frame_count_limit=frames_to_process)
-                    else:
-                        raw_image, _size = load_image(image)
-                        frames = [raw_image]
-                    if len(frames) == 0:
-                        continue
-                except FileNotFoundError as e:
-                    print(e)
-                    return None
-                image_sizes += frames[0].size * len(frames)
-                image_hashes += [hash(image)] * len(frames)
-                all_frames += frames
-
-                assert frames_to_process == len(frames)
-
-            new_text_parts.append(text_parts[image_index])
-
-            if frames_to_process != 0:
-                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
-
-        new_text_parts.append(text_parts[-1])
-
-        input_text = "".join(new_text_parts)
+        base_output = self.load_images(
+            max_req_input_len, input_ids, image_data, self.IMAGE_TOKEN
+        )
+        if base_output is None:
+            return None
 
-        if len(all_frames) == 0:
+        if len(base_output.all_frames) == 0:
             return None
-        res = await self._process_images(images=all_frames, input_text=input_text)
-        pixel_values = res["pixel_values"]
-        tgt_sizes = res["tgt_sizes"]
-        input_ids = res["input_ids"]
+        res = await self._process_images(
+            images=base_output.all_frames, input_text=base_output.input_text
+        )
 
         # Collect special token ids
         tokenizer = self._processor.tokenizer
@@ -405,10 +433,10 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         slice_start_id = [tokenizer.slice_start_id]
         slice_end_id = [tokenizer.slice_end_id]
         return {
-            "input_ids": input_ids.flatten().tolist(),
-            "pixel_values": pixel_values,
-            "tgt_sizes": tgt_sizes,
-            "image_hashes": image_hashes,
+            "input_ids": res["input_ids"].flatten().tolist(),
+            "pixel_values": res["pixel_values"],
+            "tgt_sizes": res["tgt_sizes"],
+            "image_hashes": base_output.image_hashes,
             "modalities": request_obj.modalities or ["image"],
             "im_start_id": im_start_id,
             "im_end_id": im_end_id,
@@ -536,13 +564,80 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         }
 
 
+class Qwen2_5VLImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
+        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
+        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
+        self.NUM_TOKEN_PER_FRAME = 770
+
+    @staticmethod
+    def _process_images_task(images, input_text):
+        result = global_processor.__call__(
+            text=input_text, images=images, return_tensors="pt"
+        )
+        return {
+            "input_ids": result.input_ids,
+            "pixel_values": result.pixel_values,
+            "image_grid_thws": result.image_grid_thw,
+        }
+
+    async def _process_images(self, images, input_text) -> dict:
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                Qwen2_5VLImageProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            return self._process_images_task(images, input_text)
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_ids,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        image_token = self.IMAGE_TOKEN
+        base_output = self.load_images(
+            max_req_input_len, input_ids, image_data, image_token
+        )
+
+        ret = await self._process_images(base_output.all_frames, base_output.input_text)
+
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "pixel_values": ret["pixel_values"],
+            "image_hashes": base_output.image_hashes,
+            "modalities": request_obj.modalities or ["image"],
+            "image_grid_thws": ret["image_grid_thws"],
+            "im_start_id": self.IM_START_TOKEN_ID,
+            "im_end_id": self.IM_END_TOKEN_ID,
+        }
+
+
 def get_image_processor(
     hf_config, server_args: ServerArgs, processor
 ) -> BaseImageProcessor:
     if "MllamaForConditionalGeneration" in hf_config.architectures:
         return MllamaImageProcessor(hf_config, server_args, processor)
     elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
-        return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+
+        return Qwen2VLImageProcessor(hf_config, server_args, processor)
+    elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
+        return Qwen2_5VLImageProcessor(hf_config, server_args, processor)
+
     elif "MiniCPMV" in hf_config.architectures:
         return MiniCPMVImageProcessor(hf_config, server_args, processor)
     else:
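
For the new Qwen2_5VLImageProcessor, the shared load_images helper splits the prompt on the processor's IMAGE_TOKEN and re-expands each visual input into one token per kept frame. A standalone sketch of that prompt rewriting (frame counts are illustrative):

```python
# Standalone sketch of the prompt rewriting performed by load_images (illustrative values).
IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
prompt = f"Describe this clip {IMAGE_TOKEN} and this photo {IMAGE_TOKEN} please."
frames_per_input = [4, 1]          # e.g. 4 sampled video frames, 1 still image

parts = prompt.split(IMAGE_TOKEN)  # text segments between image tokens
rebuilt = []
for text, n_frames in zip(parts, frames_per_input):
    rebuilt.append(text)
    rebuilt.append(IMAGE_TOKEN * n_frames)   # one token per kept frame
rebuilt.append(parts[-1])
print("".join(rebuilt))
```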
--- a/sglang/srt/managers/io_struct.py
+++ b/sglang/srt/managers/io_struct.py
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
 
+    output_hidden_states: List[List[float]]
+
 
 @dataclass
 class BatchStrOut:
@@ -397,6 +399,8 @@ class BatchStrOut:
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
 
+    output_hidden_states: List[List[float]]
+
 
 @dataclass
 class BatchEmbeddingOut:
--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -65,6 +65,7 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
+    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
 }
 
 logger = logging.getLogger(__name__)
@@ -315,6 +316,7 @@ class Req:
         self.output_token_logprobs_val = self.output_token_logprobs_idx = (
             self.output_top_logprobs_val
         ) = self.output_top_logprobs_idx = None
+        self.hidden_states = []
 
         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
@@ -604,6 +606,9 @@ class ScheduleBatch:
     # Enable custom logit processor
     enable_custom_logit_processor: bool = False
 
+    # Return hidden states
+    return_hidden_states: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -615,6 +620,7 @@ class ScheduleBatch:
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
@@ -629,6 +635,7 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
+            return_hidden_states=return_hidden_states,
         )
 
     def batch_size(self):
@@ -1196,9 +1203,15 @@ class ScheduleBatch:
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
             capture_hidden_mode=(
-                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
-                if self.spec_info
-                else CaptureHiddenMode.NULL
+                CaptureHiddenMode.FULL
+                if self.return_hidden_states
+                else (
+                    getattr(
+                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+                    )
+                    if self.spec_info
+                    else CaptureHiddenMode.NULL
+                )
             ),
         )
 
--- a/sglang/srt/managers/scheduler.py
+++ b/sglang/srt/managers/scheduler.py
@@ -997,6 +997,7 @@ class Scheduler:
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
+            self.server_args.return_hidden_states,
         )
         new_batch.prepare_for_extend()
 
@@ -1156,6 +1157,8 @@ class Scheduler:
                 logits_output.input_token_logprobs.tolist()
             )
 
+        hidden_state_offset = 0
+
         # Check finish conditions
         logprob_pt = 0
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -1182,6 +1185,21 @@ class Scheduler:
                     i, req, logprob_pt, next_token_ids, logits_output
                 )
 
+            if (
+                self.server_args.return_hidden_states
+                and logits_output.hidden_states is not None
+            ):
+                req.hidden_states.append(
+                    logits_output.hidden_states[
+                        hidden_state_offset : (
+                            hidden_state_offset := hidden_state_offset
+                            + len(req.origin_input_ids)
+                        )
+                    ]
+                    .cpu()
+                    .clone()
+                )
+
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()
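
During prefill, all requests' prompt tokens are packed into one hidden_states tensor, and the walrus-operator slice above advances hidden_state_offset so each request takes exactly its own rows. A standalone sketch of the same pattern (shapes are hypothetical):

```python
# Standalone sketch of the offset/walrus slicing used above (hypothetical shapes).
import torch

prompt_lens = [3, 5, 2]            # len(req.origin_input_ids) for each request
hidden_size = 4
packed = torch.randn(sum(prompt_lens), hidden_size)  # all prefill tokens, concatenated

offset = 0
per_req = []
for n in prompt_lens:
    # same pattern as hidden_state_offset : (hidden_state_offset := hidden_state_offset + n)
    per_req.append(packed[offset:(offset := offset + n)].cpu().clone())

assert [t.shape[0] for t in per_req] == prompt_lens
```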
@@ -1275,6 +1293,12 @@ class Scheduler:
                     logits_output.next_token_top_logprobs_idx[i]
                 )
 
+            if (
+                self.server_args.return_hidden_states
+                and logits_output.hidden_states is not None
+            ):
+                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
+
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
                 req.grammar.finished = req.finished()
@@ -1398,6 +1422,7 @@ class Scheduler:
            completion_tokens = []
            cached_tokens = []
            spec_verify_ct = []
+            hidden_states = []
 
            if return_logprob:
                input_token_logprobs_val = []
@@ -1464,6 +1489,8 @@ class Scheduler:
                        output_top_logprobs_val.append(req.output_top_logprobs_val)
                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
 
+                hidden_states.append(req.hidden_states)
+
            # Send to detokenizer
            if rids:
                self.send_to_detokenizer.send_pyobj(
@@ -1490,6 +1517,7 @@ class Scheduler:
                        input_top_logprobs_idx,
                        output_top_logprobs_val,
                        output_top_logprobs_idx,
+                        hidden_states,
                    )
                )
        else:  # embedding or reward model
@@ -1553,6 +1581,7 @@ class Scheduler:
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
+            self.server_args.return_hidden_states,
        )
        idle_batch.prepare_for_idle()
        return idle_batch
--- a/sglang/srt/managers/tokenizer_manager.py
+++ b/sglang/srt/managers/tokenizer_manager.py
@@ -796,6 +796,12 @@ class TokenizerManager:
                    }
                )
 
+            if (
+                hasattr(recv_obj, "output_hidden_states")
+                and len(recv_obj.output_hidden_states[i]) > 0
+            ):
+                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
+
            if isinstance(recv_obj, BatchStrOut):
                out_dict = {
                    "text": recv_obj.output_strs[i],
--- a/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -156,6 +156,10 @@ class TpModelWorkerClient:
                logits_output.input_token_logprobs = (
                    logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                )
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states.to(
+                    "cpu", non_blocking=True
+                )
            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
            copy_done.record()
 
--- a/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,6 +33,9 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -129,6 +132,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
        if bs <= model_runner.req_to_token_pool.size
        and bs <= server_args.cuda_graph_max_bs
    ]
+    if is_hip_:
+        capture_bs += [i * 8 for i in range(21, 33)]
    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
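
On ROCm (is_hip_), the capture list gains the larger multiples of 8; spelled out, the extra graph batch sizes are:

```python
# The extra CUDA-graph batch sizes added on ROCm by the change above.
extra_rocm_bs = [i * 8 for i in range(21, 33)]
print(extra_rocm_bs)   # [168, 176, 184, ..., 248, 256]
```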
@@ -349,7 +354,13 @@ class CudaGraphRunner:
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=(
-                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+                CaptureHiddenMode.FULL
+                if self.model_runner.server_args.return_hidden_states
+                else (
+                    spec_info.capture_hidden_mode
+                    if spec_info
+                    else CaptureHiddenMode.NULL
+                )
            ),
        )
 
--- a/sglang/srt/model_executor/forward_batch_info.py
+++ b/sglang/srt/model_executor/forward_batch_info.py
@@ -263,7 +263,10 @@ class ForwardBatch:
            ret.extend_prefix_lens = torch.tensor(
                batch.extend_prefix_lens, dtype=torch.int32
            ).to(device, non_blocking=True)
-            if model_runner.server_args.attention_backend != "torch_native":
+            if (
+                model_runner.server_args.attention_backend != "torch_native"
+                and model_runner.server_args.speculative_algorithm != "NEXTN"
+            ):
                ret.extend_num_tokens = batch.extend_num_tokens
                positions, ret.extend_start_loc = compute_position_triton(
                    ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
--- a/sglang/srt/model_executor/model_runner.py
+++ b/sglang/srt/model_executor/model_runner.py
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
+    set_cuda_arch,
 )
 
 logger = logging.getLogger(__name__)
@@ -110,8 +111,14 @@ class ModelRunner:
        ):
            # TODO: add MLA optimization on CPU
            if self.server_args.device != "cpu":
-                logger.info("MLA optimization is turned on. Use triton backend.")
-                self.server_args.attention_backend = "triton"
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                    )
+                    self.server_args.attention_backend = "flashinfer"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    self.server_args.attention_backend = "triton"
 
        if self.server_args.enable_double_sparsity:
            logger.info(
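
With enable_flashinfer_mla set, DeepSeek-style MLA models are routed to the flashinfer attention backend instead of triton. A hedged sketch of turning it on through the offline Engine API, assuming the new ServerArgs field is accepted through Engine kwargs (model id and tp_size are illustrative; server_args.py presumably also exposes it as a --enable-flashinfer-mla launch flag):

```python
# Hedged sketch, not taken from the diff: enabling the FlashInfer MLA path.
import sglang as sgl

engine = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",   # illustrative MLA model
    tp_size=8,                              # illustrative parallelism
    trust_remote_code=True,
    enable_flashinfer_mla=True,             # new flag added in this release
)
```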
@@ -169,6 +176,7 @@ class ModelRunner:
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
                "device": server_args.device,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
            }
        )
 
@@ -292,6 +300,8 @@ class ModelRunner:
            if torch.cuda.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")
 
+        set_cuda_arch()
+
        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,