sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/image_processor.py

@@ -9,6 +9,8 @@ from typing import List, Optional, Union

 import numpy as np
 import transformers
+from decord import VideoReader, cpu
+from PIL import Image

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
+        self.server_args = server_args

         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
         )

     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], input_text, request_obj
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
         return image_inputs


+class MiniCPMVImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_images_task(images, input_text):
+        result = global_processor.__call__(
+            text=input_text, images=images, return_tensors="pt"
+        )
+        return {
+            "input_ids": result["input_ids"],
+            "pixel_values": result["pixel_values"],
+            "tgt_sizes": result["tgt_sizes"],
+        }
+
+    async def _process_images(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                MiniCPMVImageProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                images=images, text=input_text, return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+    ):
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        image_hashes, image_sizes = [], []
+        raw_images = []
+        IMAGE_TOKEN = "(<image>./</image>)"
+
+        # roughly calculate the max number of frames
+        # TODO: the process should be applied to all the visual inputs
+        def calculate_max_num_frames() -> int:
+            # Model-specific
+            NUM_TOKEN_PER_FRAME = 330
+
+            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            return min(ret, 100)
+
+        # if cuda OOM set a smaller number
+        MAX_NUM_FRAMES = calculate_max_num_frames()
+        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def encode_video(video_path):
+            if not os.path.exists(video_path):
+                logger.error(f"Video {video_path} does not exist")
+                return []
+
+            if MAX_NUM_FRAMES == 0:
+                return []
+
+            def uniform_sample(l, n):
+                gap = len(l) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [l[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+            frame_idx = [i for i in range(0, len(vr), sample_fps)]
+            if len(frame_idx) > MAX_NUM_FRAMES:
+                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            frames = vr.get_batch(frame_idx).asnumpy()
+            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+            return frames
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        # MiniCPMV requires each frame of video as a single image token
+        text_parts = input_text.split(IMAGE_TOKEN)
+        new_text_parts = []
+
+        for image_index, image in enumerate(image_data):
+            try:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    frames = encode_video(path)
+                else:
+                    raw_image, size = load_image(image)
+                    frames = [raw_image]
+                if len(frames) == 0:
+                    continue
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+            image_sizes += frames[0].size * len(frames)
+            image_hashes += [hash(image)] * len(frames)
+            raw_images += frames
+            new_text_parts.append(text_parts[image_index])
+            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+        new_text_parts.append(text_parts[-1])
+        input_text = "".join(new_text_parts)
+        if len(raw_images) == 0:
+            return None
+        res = await self._process_images(images=raw_images, input_text=input_text)
+        pixel_values = res["pixel_values"]
+        tgt_sizes = res["tgt_sizes"]
+        input_ids = res["input_ids"]
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        im_start_id = [tokenizer.im_start_id]
+        im_end_id = [tokenizer.im_end_id]
+        if tokenizer.slice_start_id:
+            slice_start_id = [tokenizer.slice_start_id]
+            slice_end_id = [tokenizer.slice_end_id]
+
+        return {
+            "input_ids": input_ids.flatten().tolist(),
+            "pixel_values": pixel_values,
+            "tgt_sizes": tgt_sizes,
+            "image_hashes": image_hashes,
+            "modalities": request_obj.modalities or ["image"],
+            "im_start_id": im_start_id,
+            "im_end_id": im_end_id,
+            "slice_start_id": slice_start_id,
+            "slice_end_id": slice_end_id,
+        }
+
+
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _image_processor):
         self.hf_config = hf_config
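A minimal standalone sketch of the frame-budget arithmetic used in the MiniCPMV hunk above (the 330-tokens-per-frame constant and the cap of 100 come from the diff; the context length, prompt length, video length, and fps below are illustrative numbers only):

# Illustrative sketch of the frame budget and uniform frame sampling above.
NUM_TOKEN_PER_FRAME = 330

def max_num_frames(max_req_input_len: int, prompt_len: int) -> int:
    return min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)

def uniform_sample(items, n):
    gap = len(items) / n
    return [items[int(i * gap + gap / 2)] for i in range(n)]

budget = max_num_frames(max_req_input_len=8192, prompt_len=200)  # -> 24 frames fit
frame_idx = list(range(0, 900, 30))  # one frame per second of a 30 fps clip
if len(frame_idx) > budget:
    frame_idx = uniform_sample(frame_idx, budget)  # 24 evenly spaced frame indices
print(budget, len(frame_idx))  # 24 24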
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             return self._process_single_image_task(image_data)

     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], input_text, request_obj
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -350,6 +504,8 @@ def get_image_processor(
         return MllamaImageProcessor(hf_config, server_args, processor)
     elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
         return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+    elif "MiniCPMV" in hf_config.architectures:
+        return MiniCPMVImageProcessor(hf_config, server_args, processor)
     else:
         return LlavaImageProcessor(hf_config, server_args, processor.image_processor)

sglang/srt/managers/io_struct.py

@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 """

 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Optional, Union

@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -66,6 +69,10 @@

     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
+    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

     def normalize_batch_and_arguments(self):
         if (
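A rough usage sketch for the new field. Only the existence of CustomLogitProcessor, its to_str() serializer, and the custom_logit_processor request field are stated by this diff; the callback signature below and the server option that enables the feature are assumptions (see sglang/srt/sampling/custom_logit_processor.py and server_args.py in the file list):

# Hypothetical example of attaching a serialized custom logit processor to a request.
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class BanTokenProcessor(CustomLogitProcessor):
    # Assumed hook signature; check custom_logit_processor.py for the actual interface.
    def __call__(self, logits, custom_param_list=None):
        logits[..., 13] = float("-inf")  # never sample token id 13
        return logits


payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 8},
    # Serialized processor string, as described in the comment above.
    "custom_logit_processor": BanTokenProcessor().to_str(),
}
# The payload would be POSTed to /generate of a server started with the
# custom-logit-processor option enabled (ScheduleBatch.enable_custom_logit_processor below).
print(sorted(payload.keys()))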
@@ -180,6 +187,13 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1

+            if self.custom_logit_processor is None:
+                self.custom_logit_processor = [None] * num
+            elif not isinstance(self.custom_logit_processor, list):
+                self.custom_logit_processor = [self.custom_logit_processor] * num
+            else:
+                assert self.parallel_sample_num == 1
+
     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
         return self.rid
@@ -196,8 +210,14 @@ class GenerateReqInput:
             top_logprobs_num=self.top_logprobs_num[i],
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
+            log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            custom_logit_processor=(
+                self.custom_logit_processor[i]
+                if self.custom_logit_processor is not None
+                else None
+            ),
         )


@@ -230,6 +250,11 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None

+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
+    custom_logit_processor: Optional[str] = None
+

 @dataclass
 class EmbeddingReqInput:
@@ -243,6 +268,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -327,10 +354,13 @@ class BatchTokenIDOut:
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     no_stop_trim: List[bool]
+
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
+
     # Logprobs
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -340,7 +370,6 @@
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]


 @dataclass
@@ -356,6 +385,7 @@ class BatchStrOut:
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]

     # Logprobs
     input_token_logprobs_val: List[float]
@@ -366,7 +396,6 @@
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]


 @dataclass
@@ -491,6 +520,7 @@ class ProfileReq(Enum):
 @dataclass
 class ConfigureLoggingReq:
     log_requests: Optional[bool] = None
+    log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
     dump_requests_threshold: Optional[int] = None

@@ -510,3 +540,27 @@ class CloseSessionReqInput:
 class OpenSessionReqOutput:
     session_id: Optional[str]
     success: bool
+
+
+@dataclass
+class Function:
+    description: Optional[str] = None
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+@dataclass
+class Tool:
+    function: Function
+    type: Optional[str] = "function"
+
+
+@dataclass
+class FunctionCallReqInput:
+    text: str  # The text to parse.
+    tools: List[Tool] = field(
+        default_factory=list
+    )  # A list of available function tools (name, parameters, etc.).
+    tool_call_parser: Optional[str] = (
+        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
+    )
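These dataclasses pair with the new sglang/srt/function_call_parser.py listed above (file 31). A hedged construction example that uses only the fields defined in this hunk; the weather-tool schema and the model output text are made up:

# Hypothetical request object for the function-call parsing path.
from sglang.srt.managers.io_struct import Function, FunctionCallReqInput, Tool

tools = [
    Tool(
        function=Function(
            name="get_weather",
            description="Look up the current weather for a city.",
            parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        )
    )
]

req = FunctionCallReqInput(
    text='{"name": "get_weather", "parameters": {"city": "Paris"}}',
    tools=tools,
    tool_call_parser="llama3",  # or "qwen25" / "mistral"; None tries all parsers
)
print(req.tool_call_parser, req.tools[0].function.name)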
sglang/srt/managers/schedule_batch.py

@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
 if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm

-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
@@ -65,9 +64,9 @@ global_server_args_dict = {
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }

-
 logger = logging.getLogger(__name__)


@@ -116,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):


 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error"):
+    def __init__(self, message="Unknown error", status_code=None, err_type=None):
         super().__init__(is_error=True)
         self.message = message
+        self.status_code = status_code
+        self.err_type = err_type

     def to_json(self):
         return {
             "type": "abort",
             "message": self.message,
+            "status_code": self.status_code,
+            "err_type": self.err_type,
         }


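With the new fields, an aborted request's finish reason carries an HTTP-style status code and error type alongside the message. A small illustrative sketch (the message, code, and error type below are made-up values; only the dict shape comes from the hunk above):

# Illustrative only; values depend on how the scheduler constructs FINISH_ABORT.
from sglang.srt.managers.schedule_batch import FINISH_ABORT

reason = FINISH_ABORT(
    message="Input length exceeds the maximum allowed",
    status_code=400,
    err_type="BadRequestError",
)
print(reason.to_json())
# {'type': 'abort', 'message': 'Input length exceeds the maximum allowed',
#  'status_code': 400, 'err_type': 'BadRequestError'}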
@@ -148,6 +151,15 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None

+    # MiniCPMV related
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    im_start_id: Optional[torch.Tensor] = None
+    im_end_id: Optional[torch.Tensor] = None
+    slice_start_id: Optional[torch.Tensor] = None
+    slice_end_id: Optional[torch.Tensor] = None
+    tgt_sizes: Optional[list] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -167,6 +179,11 @@ class ImageInputs:
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
+            "im_start_id",
+            "im_end_id",
+            "slice_start_id",
+            "slice_end_id",
+            "tgt_sizes",
         ]
         for arg in optional_args:
             if arg in obj:
@@ -215,6 +232,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        custom_logit_processor: Optional[str] = None,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -226,14 +244,16 @@ class Req:
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-        self.output_ids = []  # Each decode stage's output ids
-        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+        self.fill_ids = None
         self.session_id = session_id
         self.input_embeds = input_embeds

         # Sampling info
         self.sampling_params = sampling_params
-        self.lora_path = lora_path
+        self.custom_logit_processor = custom_logit_processor

         # Memory pool info
         self.req_pool_idx = None
@@ -265,6 +285,7 @@ class ScheduleBatch:
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None

@@ -279,12 +300,11 @@
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num

-        # Logprobs (return value)
-        self.normalized_prompt_logprob = None
-        self.input_token_logprobs_val = None
-        self.input_token_logprobs_idx = None
-        self.input_top_logprobs_val = None
-        self.input_top_logprobs_idx = None
+        # Logprobs (return values)
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None

         if return_logprob:
             self.output_token_logprobs_val = []
@@ -309,8 +329,14 @@
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None

-        # The number of cached tokens, that were already cached in the KV cache
+        # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
+        self.already_computed = 0
+
+        # The number of verification forward passes in the speculative decoding.
+        # This is used to compute the average acceptance length per request.
+        self.spec_verify_ct = 0
+        self.lora_path = lora_path

     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
@@ -344,9 +370,6 @@
         max_prefix_len = min(max_prefix_len, input_len - 1)

         if self.return_logprob:
-            if self.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                max_prefix_len = min(max_prefix_len, input_len - 2)
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)

         max_prefix_len = max(max_prefix_len, 0)
@@ -533,13 +556,13 @@ class ScheduleBatch:
     next_batch_sampling_info: SamplingBatchInfo = None

     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32

     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -578,6 +601,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None

+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
    @classmethod
     def init_new(
         cls,
@@ -588,6 +614,7 @@
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -601,6 +628,7 @@
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )

     def batch_size(self):
@@ -656,7 +684,7 @@ class ScheduleBatch:
                     or len(req.prefix_indices) >= im.num_image_tokens
                 )

-        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
             self.device, non_blocking=True
         )

@@ -690,7 +718,7 @@
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )

@@ -728,13 +756,6 @@

         pt = 0
         for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
@@ -750,15 +771,20 @@
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

-            # Compute the relative logprob_start_len in an extend batch
-            if req.logprob_start_len >= pre_len:
-                extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len, req.extend_input_len - 1
-                )
-            else:
-                extend_logprob_start_len = req.extend_input_len - 1
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len

-            req.extend_logprob_start_len = extend_logprob_start_len
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)

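A small worked example of the per-request bookkeeping in the hunk above, with made-up numbers for a request whose prefill is chunked into two rounds:

# Illustrative numbers only; mirrors the cached_tokens / already_computed /
# extend_logprob_start_len updates shown above.
# Round 1: 6 tokens hit the radix cache, 10 more are prefilled.
pre_len, seq_len = 6, 16              # len(prefix_indices), len(fill_ids)
cached_tokens, already_computed = 0, 0
cached_tokens += pre_len - already_computed   # 6 tokens served from cache
already_computed = seq_len                    # 16 tokens materialized so far

# Round 2 (chunked prefill continues): the 16 computed tokens reappear as prefix.
pre_len, seq_len = 16, 24
cached_tokens += pre_len - already_computed   # += 0, chunking is not double counted
already_computed = seq_len                    # 24

# Relative logprob start for this extend chunk (only when return_logprob is set).
logprob_start_len, extend_input_len = 18, seq_len - pre_len  # 8 new tokens
extend_logprob_start_len = min(logprob_start_len - pre_len, extend_input_len - 1)
print(cached_tokens, already_computed, extend_logprob_start_len)  # 6 24 2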
@@ -766,10 +792,10 @@
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.input_embeds = (
@@ -1002,11 +1028,16 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+            enable_overlap_schedule=self.enable_overlap,
+        )

     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
@@ -1067,7 +1098,7 @@ class ScheduleBatch:
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
@@ -1085,6 +1116,8 @@
         self.has_grammar = any(req.grammar for req in self.reqs)

         self.sampling_info.filter_batch(keep_indices, new_indices)
+        if self.spec_info:
+            self.spec_info.filter_batch(new_indices)

     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1121,7 +1154,7 @@ class ScheduleBatch:
             self.spec_info.merge_batch(other.spec_info)

     def get_model_worker_batch(self):
-        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
+        if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
@@ -1136,7 +1169,6 @@

         global bid
         bid += 1
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1180,6 +1212,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )

     def __str__(self):