sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/image_processor.py
@@ -9,6 +9,8 @@ from typing import List, Optional, Union
 
 import numpy as np
 import transformers
+from decord import VideoReader, cpu
+from PIL import Image
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
+        self.server_args = server_args
 
         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
             )
 
     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], input_text, request_obj
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
         return image_inputs
 
 
+class MiniCPMVImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_images_task(images, input_text):
+        result = global_processor.__call__(
+            text=input_text, images=images, return_tensors="pt"
+        )
+        return {
+            "input_ids": result["input_ids"],
+            "pixel_values": result["pixel_values"],
+            "tgt_sizes": result["tgt_sizes"],
+        }
+
+    async def _process_images(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                MiniCPMVImageProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                images=images, text=input_text, return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+    ):
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        image_hashes, image_sizes = [], []
+        raw_images = []
+        IMAGE_TOKEN = "(<image>./</image>)"
+
+        # roughly calculate the max number of frames
+        # TODO: the process should be applied to all the visual inputs
+        def calculate_max_num_frames() -> int:
+            # Model-specific
+            NUM_TOKEN_PER_FRAME = 330
+
+            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            return min(ret, 100)
+
+        # if cuda OOM set a smaller number
+        MAX_NUM_FRAMES = calculate_max_num_frames()
+        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def encode_video(video_path):
+            if not os.path.exists(video_path):
+                logger.error(f"Video {video_path} does not exist")
+                return []
+
+            if MAX_NUM_FRAMES == 0:
+                return []
+
+            def uniform_sample(l, n):
+                gap = len(l) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [l[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+            frame_idx = [i for i in range(0, len(vr), sample_fps)]
+            if len(frame_idx) > MAX_NUM_FRAMES:
+                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            frames = vr.get_batch(frame_idx).asnumpy()
+            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+            return frames
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        # MiniCPMV requires each frame of video as a single image token
+        text_parts = input_text.split(IMAGE_TOKEN)
+        new_text_parts = []
+
+        for image_index, image in enumerate(image_data):
+            try:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    frames = encode_video(path)
+                else:
+                    raw_image, size = load_image(image)
+                    frames = [raw_image]
+                if len(frames) == 0:
+                    continue
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+            image_sizes += frames[0].size * len(frames)
+            image_hashes += [hash(image)] * len(frames)
+            raw_images += frames
+            new_text_parts.append(text_parts[image_index])
+            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+        new_text_parts.append(text_parts[-1])
+        input_text = "".join(new_text_parts)
+        if len(raw_images) == 0:
+            return None
+        res = await self._process_images(images=raw_images, input_text=input_text)
+        pixel_values = res["pixel_values"]
+        tgt_sizes = res["tgt_sizes"]
+        input_ids = res["input_ids"]
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        im_start_id = [tokenizer.im_start_id]
+        im_end_id = [tokenizer.im_end_id]
+        if tokenizer.slice_start_id:
+            slice_start_id = [tokenizer.slice_start_id]
+            slice_end_id = [tokenizer.slice_end_id]
+
+        return {
+            "input_ids": input_ids.flatten().tolist(),
+            "pixel_values": pixel_values,
+            "tgt_sizes": tgt_sizes,
+            "image_hashes": image_hashes,
+            "modalities": request_obj.modalities or ["image"],
+            "im_start_id": im_start_id,
+            "im_end_id": im_end_id,
+            "slice_start_id": slice_start_id,
+            "slice_end_id": slice_end_id,
+        }
+
+
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _image_processor):
         self.hf_config = hf_config
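Note on the new MiniCPMVImageProcessor above: it budgets video frames against the remaining request length (roughly 330 prompt tokens per frame, capped at 100 frames) and thins the decoded frames by uniform sampling. A minimal standalone sketch of that budget-and-sample logic, detached from decord and the HF processor (illustrative only, not the packaged code path):

# Illustrative sketch of the frame budgeting used by MiniCPMVImageProcessor;
# the names mirror the diff, but this is not the sglang implementation itself.
def calculate_max_num_frames(max_req_input_len: int, prompt_len: int,
                             num_token_per_frame: int = 330) -> int:
    # Each frame costs roughly `num_token_per_frame` tokens; never take more than 100.
    return min((max_req_input_len - prompt_len) // num_token_per_frame, 100)

def uniform_sample(items, n):
    # Keep n items spread evenly across the list (midpoint of each bucket).
    gap = len(items) / n
    return [items[int(i * gap + gap / 2)] for i in range(n)]

frame_idx = list(range(0, 300, 6))            # e.g. every 6th frame of a 300-frame clip (50 candidates)
budget = calculate_max_num_frames(4096, 512)  # 10 frames fit in the remaining context
if len(frame_idx) > budget:
    frame_idx = uniform_sample(frame_idx, budget)
print(len(frame_idx))                         # 10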
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             return self._process_single_image_task(image_data)
 
     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], input_text, request_obj
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -350,6 +504,8 @@ def get_image_processor(
         return MllamaImageProcessor(hf_config, server_args, processor)
     elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
         return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+    elif "MiniCPMV" in hf_config.architectures:
+        return MiniCPMVImageProcessor(hf_config, server_args, processor)
     else:
         return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
 
sglang/srt/managers/io_struct.py
@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
+from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -61,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -68,6 +69,8 @@ class GenerateReqInput:
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
+    # Custom logit processor (serialized function)
+    custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -182,6 +185,13 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1
 
+            if self.custom_logit_processor is None:
+                self.custom_logit_processor = [None] * num
+            elif not isinstance(self.custom_logit_processor, list):
+                self.custom_logit_processor = [self.custom_logit_processor] * num
+            else:
+                assert self.parallel_sample_num == 1
+
     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
         return self.rid
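Note: `custom_logit_processor` travels through GenerateReqInput and TokenizedGenerateReqInput as a serialized string and is re-attached to each Req, gated by the new `enable_custom_logit_processor` switch visible in the ScheduleBatch and server_args changes (see also `sglang/srt/sampling/custom_logit_processor.py` in the file list). A hedged sketch of the general idea, using plain pickle + base64 as a stand-in wire format rather than sglang's actual serializer or API:

# Illustrative only: shows how a per-request logit processor could be shipped as a
# string and applied to logits before sampling. Not sglang's serializer or API.
import base64
import pickle

import torch

class BanTokenProcessor:
    """Force one token id to never be sampled."""
    def __init__(self, banned_id: int):
        self.banned_id = banned_id

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        logits[..., self.banned_id] = float("-inf")
        return logits

serialized = base64.b64encode(pickle.dumps(BanTokenProcessor(42))).decode()  # travels in the request
processor = pickle.loads(base64.b64decode(serialized))                       # restored server-side
logits = torch.zeros(1, 128)
print(processor(logits)[0, 42].item())  # -inf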
@@ -198,8 +208,14 @@ class GenerateReqInput:
             top_logprobs_num=self.top_logprobs_num[i],
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
+            log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            custom_logit_processor=(
+                self.custom_logit_processor[i]
+                if self.custom_logit_processor is not None
+                else None
+            ),
         )
 
 
@@ -232,6 +248,10 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
+    # Custom logit processor (serialized function)
+    # TODO (hpguo): Add an example and update doc string here
+    custom_logit_processor: Optional[str] = None
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -245,6 +265,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -323,9 +345,7 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when `--return-token-ids` is set
-    origin_input_ids: Optional[List[int]]
-    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
+    # Only used when `--skip-tokenizer-init` is on
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -344,7 +364,6 @@ class BatchTokenIDOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -356,14 +375,7 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]
 
-    # The token ids
-    origin_input_ids: Optional[List[int]]
-    output_ids: Optional[List[int]]
-
     # Token counts
-    # real input and output tokens can be get from
-    # origin_input_ids and output_ids by enabling --return_token_ids
-    # TODO (Shuai): Rename this to clarify the meaning.
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
@@ -377,7 +389,6 @@ class BatchStrOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -468,6 +479,26 @@ class GetWeightsByNameReqOutput:
     parameter: list
 
 
+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
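Note: the empty ReleaseMemoryOccupationReqInput/ResumeMemoryOccupationReqInput pairs follow the same pattern as the other io_struct messages: every control action is a typed request/response dataclass that is forwarded between the tokenizer manager and the scheduler process and matched by type. A minimal sketch of that dispatch pattern (the handler below is hypothetical, not the sglang scheduler API; only the dataclass shapes come from the diff):

# Sketch of the typed request/response message pattern used throughout io_struct.py.
from dataclasses import dataclass

@dataclass
class ReleaseMemoryOccupationReqInput:
    pass

@dataclass
class ReleaseMemoryOccupationReqOutput:
    pass

def handle(msg):
    # Dispatch on the concrete message type; an empty dataclass still carries intent.
    if isinstance(msg, ReleaseMemoryOccupationReqInput):
        # ... release weights / KV-cache memory here ...
        return ReleaseMemoryOccupationReqOutput()
    raise TypeError(f"unhandled message: {type(msg).__name__}")

print(handle(ReleaseMemoryOccupationReqInput()))  # ReleaseMemoryOccupationReqOutput()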
@@ -479,6 +510,14 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2
 
 
+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    log_requests_level: Optional[int] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
sglang/srt/managers/schedule_batch.py
@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
 if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
 
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -65,9 +64,9 @@ global_server_args_dict = {
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -116,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error"):
+    def __init__(self, message="Unknown error", status_code=None, err_type=None):
         super().__init__(is_error=True)
         self.message = message
+        self.status_code = status_code
+        self.err_type = err_type
 
     def to_json(self):
         return {
             "type": "abort",
             "message": self.message,
+            "status_code": self.status_code,
+            "err_type": self.err_type,
         }
 
 
@@ -148,6 +151,15 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
+    # MiniCPMV related
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    im_start_id: Optional[torch.Tensor] = None
+    im_end_id: Optional[torch.Tensor] = None
+    slice_start_id: Optional[torch.Tensor] = None
+    slice_end_id: Optional[torch.Tensor] = None
+    tgt_sizes: Optional[list] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -167,6 +179,11 @@ class ImageInputs:
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
+            "im_start_id",
+            "im_end_id",
+            "slice_start_id",
+            "slice_end_id",
+            "tgt_sizes",
         ]
         for arg in optional_args:
             if arg in obj:
@@ -215,6 +232,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        custom_logit_processor: Optional[str] = None,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -226,14 +244,16 @@ class Req:
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-        self.output_ids = []  # Each decode stage's output ids
-        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
         self.session_id = session_id
         self.input_embeds = input_embeds
 
         # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
         self.req_pool_idx = None
@@ -265,6 +285,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None
 
@@ -280,11 +301,10 @@ class Req:
         self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
-        self.normalized_prompt_logprob = None
-        self.input_token_logprobs_val = None
-        self.input_token_logprobs_idx = None
-        self.input_top_logprobs_val = None
-        self.input_top_logprobs_idx = None
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
             self.output_token_logprobs_val = []
@@ -344,9 +364,6 @@ class Req:
             max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            if self.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                max_prefix_len = min(max_prefix_len, input_len - 2)
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
         max_prefix_len = max(max_prefix_len, 0)
@@ -578,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
 
+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -588,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -601,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )
 
     def batch_size(self):
@@ -656,7 +678,7 @@ class ScheduleBatch:
                     or len(req.prefix_indices) >= im.num_image_tokens
                 )
 
-        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
 
@@ -690,7 +712,7 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -766,10 +788,10 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.input_embeds = (
@@ -1002,11 +1024,16 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+            enable_overlap_schedule=self.enable_overlap,
+        )
 
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
@@ -1067,7 +1094,7 @@ class ScheduleBatch:
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
@@ -1121,7 +1148,7 @@ class ScheduleBatch:
             self.spec_info.merge_batch(other.spec_info)
 
     def get_model_worker_batch(self):
-        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
+        if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
@@ -1136,7 +1163,6 @@ class ScheduleBatch:
 
         global bid
         bid += 1
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1180,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )
 
     def __str__(self):
24
24
 
25
25
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
26
26
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
27
+ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
27
28
  from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
28
29
 
29
30
  # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
@@ -250,23 +251,24 @@ class PrefillAdder:
250
251
  def __init__(
251
252
  self,
252
253
  tree_cache: BasePrefixCache,
254
+ token_to_kv_pool: BaseTokenToKVPool,
253
255
  running_batch: ScheduleBatch,
254
256
  new_token_ratio: float,
255
- rem_total_tokens: int,
256
257
  rem_input_tokens: int,
257
258
  rem_chunk_tokens: Optional[int],
258
259
  mixed_with_decode_tokens: int = 0,
259
260
  ):
260
261
  self.tree_cache = tree_cache
262
+ self.token_to_kv_pool = token_to_kv_pool
261
263
  self.running_batch = running_batch
262
264
  self.new_token_ratio = new_token_ratio
263
- self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
264
265
  self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
265
266
  self.rem_chunk_tokens = rem_chunk_tokens
266
267
  if self.rem_chunk_tokens is not None:
267
268
  self.rem_chunk_tokens -= mixed_with_decode_tokens
268
269
 
269
- self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
270
+ self.rem_total_token_offset = mixed_with_decode_tokens
271
+ self.cur_rem_token_offset = mixed_with_decode_tokens
270
272
 
271
273
  self.req_states = None
272
274
  self.can_run_list = []
@@ -275,8 +277,7 @@ class PrefillAdder:
275
277
  self.log_input_tokens = 0
276
278
 
277
279
  if running_batch is not None:
278
- # Pre-remove the tokens which will be occupied by the running requests
279
- self.rem_total_tokens -= sum(
280
+ self.rem_total_token_offset += sum(
280
281
  [
281
282
  min(
282
283
  (r.sampling_params.max_new_tokens - len(r.output_ids)),
@@ -287,6 +288,22 @@ class PrefillAdder:
287
288
  ]
288
289
  )
289
290
 
291
+ @property
292
+ def rem_total_tokens(self):
293
+ return (
294
+ self.token_to_kv_pool.available_size()
295
+ + self.tree_cache.evictable_size()
296
+ - self.rem_total_token_offset
297
+ )
298
+
299
+ @property
300
+ def cur_rem_tokens(self):
301
+ return (
302
+ self.token_to_kv_pool.available_size()
303
+ + self.tree_cache.evictable_size()
304
+ - self.cur_rem_token_offset
305
+ )
306
+
290
307
  def budget_state(self):
291
308
  if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
292
309
  return AddReqResult.NO_TOKEN
@@ -301,8 +318,8 @@ class PrefillAdder:
301
318
  def _prefill_one_req(
302
319
  self, prefix_len: int, extend_input_len: int, max_new_tokens: int
303
320
  ):
304
- self.rem_total_tokens -= extend_input_len + max_new_tokens
305
- self.cur_rem_tokens -= extend_input_len
321
+ self.rem_total_token_offset += extend_input_len + max_new_tokens
322
+ self.cur_rem_token_offset += extend_input_len
306
323
  self.rem_input_tokens -= extend_input_len
307
324
  if self.rem_chunk_tokens is not None:
308
325
  self.rem_chunk_tokens -= extend_input_len
@@ -332,12 +349,10 @@ class PrefillAdder:
332
349
  @contextmanager
333
350
  def _lock_node(self, last_node: TreeNode):
334
351
  try:
335
- delta = self.tree_cache.inc_lock_ref(last_node)
336
- self.rem_total_tokens += delta
352
+ self.tree_cache.inc_lock_ref(last_node)
337
353
  yield None
338
354
  finally:
339
- delta = self.tree_cache.dec_lock_ref(last_node)
340
- self.rem_total_tokens += delta
355
+ self.tree_cache.dec_lock_ref(last_node)
341
356
 
342
357
  def add_one_req_ignore_eos(self, req: Req):
343
358
  def add_req_state(r, insert_sort=False):
@@ -433,7 +448,6 @@ class PrefillAdder:
433
448
  or input_tokens <= self.rem_chunk_tokens
434
449
  or (
435
450
  req.return_logprob
436
- and req.normalized_prompt_logprob is None
437
451
  and req.logprob_start_len != len(req.origin_input_ids) - 1
438
452
  )
439
453
  ):