sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  12. sglang/srt/layers/attention/vision.py +243 -40
  13. sglang/srt/layers/dp_attention.py +3 -1
  14. sglang/srt/layers/layernorm.py +5 -5
  15. sglang/srt/layers/linear.py +24 -9
  16. sglang/srt/layers/logits_processor.py +1 -1
  17. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  18. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  22. sglang/srt/layers/parameter.py +16 -7
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/fp8.py +11 -1
  33. sglang/srt/layers/rotary_embedding.py +34 -13
  34. sglang/srt/layers/sampler.py +33 -10
  35. sglang/srt/layers/torchao_utils.py +12 -6
  36. sglang/srt/managers/detokenizer_manager.py +1 -0
  37. sglang/srt/managers/image_processor.py +77 -38
  38. sglang/srt/managers/io_struct.py +36 -5
  39. sglang/srt/managers/schedule_batch.py +31 -25
  40. sglang/srt/managers/scheduler.py +78 -38
  41. sglang/srt/managers/tokenizer_manager.py +4 -0
  42. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  43. sglang/srt/mem_cache/chunk_cache.py +3 -0
  44. sglang/srt/mem_cache/radix_cache.py +30 -1
  45. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  46. sglang/srt/model_executor/forward_batch_info.py +5 -7
  47. sglang/srt/model_executor/model_runner.py +7 -4
  48. sglang/srt/model_loader/loader.py +75 -0
  49. sglang/srt/model_loader/weight_utils.py +91 -5
  50. sglang/srt/models/commandr.py +14 -2
  51. sglang/srt/models/dbrx.py +9 -1
  52. sglang/srt/models/deepseek_v2.py +3 -3
  53. sglang/srt/models/gemma2.py +9 -1
  54. sglang/srt/models/grok.py +1 -0
  55. sglang/srt/models/minicpm3.py +3 -3
  56. sglang/srt/models/minicpmv.py +129 -76
  57. sglang/srt/models/mllama.py +16 -56
  58. sglang/srt/models/qwen2.py +4 -1
  59. sglang/srt/models/qwen2_vl.py +18 -8
  60. sglang/srt/models/torch_native_llama.py +17 -4
  61. sglang/srt/openai_api/adapter.py +139 -37
  62. sglang/srt/openai_api/protocol.py +5 -4
  63. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  64. sglang/srt/sampling/sampling_batch_info.py +4 -14
  65. sglang/srt/server.py +2 -2
  66. sglang/srt/server_args.py +26 -1
  67. sglang/srt/speculative/eagle_utils.py +37 -15
  68. sglang/srt/speculative/eagle_worker.py +11 -13
  69. sglang/srt/utils.py +62 -67
  70. sglang/test/test_programs.py +1 -0
  71. sglang/test/test_utils.py +81 -22
  72. sglang/utils.py +42 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
  75. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
  76. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py

@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from vllm import _custom_ops as ops
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.utils import is_cuda_available
+
+_is_cuda_available = is_cuda_available()
+if _is_cuda_available:
+    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(dtype)
+        # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
+        if not _is_cuda_available:
+            cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from vllm import _custom_ops as ops
-
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-        ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
+        if _is_cuda_available:
+            apply_rope_with_cos_sin_cache_inplace(
+                positions=positions,
+                query=query,
+                key=key,
+                head_size=self.head_size,
+                cos_sin_cache=self.cos_sin_cache,
+                is_neox=self.is_neox_style,
+            )
+        else:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+            ops.rotary_embedding(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
         return query, key
 
     def forward_xpu(
@@ -1018,7 +1034,12 @@ def get_rope(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
     else:
-        scaling_type = rope_scaling["rope_type"]
+        if "rope_type" in rope_scaling:
+            scaling_type = rope_scaling["rope_type"]
+        elif "type" in rope_scaling:
+            scaling_type = rope_scaling["type"]
+        else:
+            raise ValueError("Unknown RoPE scaling type")
 
        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
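
The get_rope change above accepts both the newer "rope_type" key and the legacy "type" key that older Hugging Face configs still use. Illustrative only, a self-contained sketch of the same key resolution (the config dicts are made up):

def resolve_scaling_type(rope_scaling: dict) -> str:
    # Mirrors the fallback order introduced in the hunk above.
    if "rope_type" in rope_scaling:
        return rope_scaling["rope_type"]
    elif "type" in rope_scaling:
        return rope_scaling["type"]
    raise ValueError("Unknown RoPE scaling type")

print(resolve_scaling_type({"rope_type": "llama3", "factor": 8.0}))  # llama3
print(resolve_scaling_type({"type": "dynamic", "factor": 2.0}))      # dynamic
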
sglang/srt/layers/sampler.py

@@ -1,17 +1,19 @@
 import logging
-from typing import Dict, List
+from typing import List
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
+from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
 
-if is_flashinfer_available():
-    from flashinfer.sampling import (
+if is_cuda_available():
+    from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
         top_k_top_p_sampling_from_probs,
@@ -21,11 +23,17 @@ if is_flashinfer_available():
 
 logger = logging.getLogger(__name__)
 
+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+
 
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+
+        if global_server_args_dict["enable_dp_attention"]:
+            self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
         self,
@@ -64,9 +72,11 @@ class Sampler(nn.Module):
                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                    # https://github.com/flashinfer-ai/flashinfer/issues/708
                    # so we use the torch implementation.
+
+                   # clamp to avoid -inf
                    logprobs = torch.log(
                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                   )
+                   ).clamp(min=torch.finfo(probs.dtype).min)
 
                max_top_k_round, batch_size = 32, probs.shape[0]
                uniform_samples = torch.rand(
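
The new .clamp(...) matters because top-p renormalization zeroes out the probabilities of filtered tokens, and torch.log(0) is -inf; clamping to the smallest finite value of the dtype keeps the returned logprobs finite. A standalone demonstration (not sglang code):

import torch

probs = torch.tensor([0.7, 0.3, 0.0])  # last token zeroed by top-p renormalization
print(torch.log(probs))                # tensor([-0.3567, -1.2040,    -inf])
print(torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)[-1])
# tensor(-3.4028e+38) -- finite and representable in float32
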
@@ -101,16 +111,15 @@ class Sampler(nn.Module):
                    sampling_info.need_min_p_sampling,
                )
                if return_logprob:
+                   # clamp to avoid -inf
                    logprobs = torch.log(
                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                   )
+                   ).clamp(min=torch.finfo(probs.dtype).min)
            else:
                raise ValueError(
                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                )
 
-        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
        # Attach logprobs to logits_output (in-place modification)
        if return_logprob:
            if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +133,21 @@ class Sampler(nn.Module):
                batch_next_token_ids,
            ]
 
-        return batch_next_token_ids
+        if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+            # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+            # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+            # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+            # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+            # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+            # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=self.tp_sync_group,
+            )
+
+        return batch_next_token_ids.to(torch.int32)
 
    def _apply_custom_logit_processor(
        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
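
The comment block above explains the trade-off: skipping the sync saves one all-reduce per step, but any non-determinism across TP ranks can desynchronize them. A minimal sketch of opting into the extra synchronization; the flag is read once at module import (the module-level get_bool_env_var call above), so it must be in the environment before the server or engine processes are spawned, e.g. exported in the launching shell:

import os

# Set before launching the server / creating the engine (assumption: the parent
# process environment is inherited by the scheduler and TP worker processes).
os.environ["SYNC_TOKEN_IDS_ACROSS_TP"] = "1"

Independently of the flag, the same ReduceOp.MIN all-reduce also runs whenever a grammar is attached to the batch, per the "or sampling_info.grammars" condition.
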
sglang/srt/layers/torchao_utils.py

@@ -5,6 +5,7 @@ Common utilities for torchao.
 import logging
 import os
 import pwd
+from typing import Callable, Optional
 
 import torch
 
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
    return True
 
 
+def proj_filter(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    """Filter function for quantizing projection layers."""
+    return "proj" in fqn
+
+
 def apply_torchao_config_to_model(
-    model: torch.nn.Module, torchao_config: str, filter_fn=None
+    model: torch.nn.Module,
+    torchao_config: str,
+    filter_fn: Optional[Callable] = proj_filter,
 ):
    """Quantize a modelwith torchao quantization specified by torchao_config
 
@@ -49,11 +60,6 @@
    )
    from torchao.quantization.observer import PerRow, PerTensor
 
-    if filter_fn is None:
-
-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
    if torchao_config == "" or torchao_config is None:
        return model
    elif "int8wo" in torchao_config:
sglang/srt/managers/detokenizer_manager.py

@@ -201,6 +201,7 @@ class DetokenizerManager:
                    prompt_tokens=recv_obj.prompt_tokens,
                    completion_tokens=recv_obj.completion_tokens,
                    cached_tokens=recv_obj.cached_tokens,
+                    spec_verify_ct=recv_obj.spec_verify_ct,
                    input_token_logprobs_val=recv_obj.input_token_logprobs_val,
                    input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
                    output_token_logprobs_val=recv_obj.output_token_logprobs_val,
sglang/srt/managers/image_processor.py

@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
 class MiniCPMVImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "(<image>./</image>)"
 
    @staticmethod
    def _process_images_task(images, input_text):
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
-        input_text,
+        input_ids,
        request_obj,
        max_req_input_len,
    ):
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
            image_data = [image_data]
 
        image_hashes, image_sizes = [], []
-        raw_images = []
-        IMAGE_TOKEN = "(<image>./</image>)"
+        all_frames = []
 
-        # roughly calculate the max number of frames
-        # TODO: the process should be applied to all the visual inputs
+        # roughly calculate the max number of frames under the max_req_input_len limit
        def calculate_max_num_frames() -> int:
            # Model-specific
            NUM_TOKEN_PER_FRAME = 330
 
-            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
            return min(ret, 100)
 
-        # if cuda OOM set a smaller number
        MAX_NUM_FRAMES = calculate_max_num_frames()
-        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
 
-        def encode_video(video_path):
+        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def get_estimated_frames_list():
+            """
+            estimate the total frame count from all visual input
+            """
+            # Before processing inputs
+            estimated_frames_list = []
+            for image in image_data:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    # Estimate frames for the video
+                    vr = VideoReader(path, ctx=cpu(0))
+                    num_frames = len(vr)
+                else:
+                    # For images, each contributes one frame
+                    num_frames = 1
+                estimated_frames_list.append(num_frames)
+
+            return estimated_frames_list
+
+        estimated_frames_list = get_estimated_frames_list()
+        total_frame_count = sum(estimated_frames_list)
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+
+        def encode_video(video_path, frame_count_limit=None):
            if not os.path.exists(video_path):
                logger.error(f"Video {video_path} does not exist")
                return []
 
-            if MAX_NUM_FRAMES == 0:
+            if frame_count_limit == 0:
                return []
 
            def uniform_sample(l, n):
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
            vr = VideoReader(video_path, ctx=cpu(0))
            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
            frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if len(frame_idx) > MAX_NUM_FRAMES:
-                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
+                frame_idx = uniform_sample(frame_idx, frame_count_limit)
            frames = vr.get_batch(frame_idx).asnumpy()
            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
            return frames
 
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
+        if isinstance(input_ids, list):
+            assert len(input_ids) and isinstance(input_ids[0], int)
+            input_text = self._processor.tokenizer.decode(input_ids)
+        else:
+            input_text = input_ids
        # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(IMAGE_TOKEN)
+        text_parts = input_text.split(self.IMAGE_TOKEN)
        new_text_parts = []
 
-        for image_index, image in enumerate(image_data):
-            try:
-                if isinstance(image, str) and image.startswith("video:"):
-                    path = image[len("video:") :]
-                    frames = encode_video(path)
-                else:
-                    raw_image, size = load_image(image)
-                    frames = [raw_image]
-                if len(frames) == 0:
-                    continue
-            except FileNotFoundError as e:
-                print(e)
-                return None
-
-            image_sizes += frames[0].size * len(frames)
-            image_hashes += [hash(image)] * len(frames)
-            raw_images += frames
+        # Process each input with allocated frames
+        for image_index, (image, estimated_frames) in enumerate(
+            zip(image_data, estimated_frames_list)
+        ):
+            if len(all_frames) >= MAX_NUM_FRAMES:
+                frames_to_process = 0
+            else:
+                frames_to_process = max(1, int(estimated_frames * scaling_factor))
+
+            if frames_to_process == 0:
+                frames = []
+            else:
+                try:
+                    if isinstance(image, str) and image.startswith("video:"):
+                        path = image[len("video:") :]
+                        frames = encode_video(path, frame_count_limit=frames_to_process)
+                    else:
+                        raw_image, _size = load_image(image)
+                        frames = [raw_image]
+                    if len(frames) == 0:
+                        continue
+                except FileNotFoundError as e:
+                    print(e)
+                    return None
+                image_sizes += frames[0].size * len(frames)
+                image_hashes += [hash(image)] * len(frames)
+                all_frames += frames
+
+            assert frames_to_process == len(frames)
+
            new_text_parts.append(text_parts[image_index])
-            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+            if frames_to_process != 0:
+                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
 
        new_text_parts.append(text_parts[-1])
+
        input_text = "".join(new_text_parts)
-        if len(raw_images) == 0:
+
+        if len(all_frames) == 0:
            return None
-        res = await self._process_images(images=raw_images, input_text=input_text)
+        res = await self._process_images(images=all_frames, input_text=input_text)
        pixel_values = res["pixel_values"]
        tgt_sizes = res["tgt_sizes"]
        input_ids = res["input_ids"]
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
        if tokenizer.slice_start_id:
            slice_start_id = [tokenizer.slice_start_id]
            slice_end_id = [tokenizer.slice_end_id]
-
        return {
            "input_ids": input_ids.flatten().tolist(),
            "pixel_values": pixel_values,
sglang/srt/managers/io_struct.py

@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Optional, Union
 
@@ -69,8 +69,10 @@ class GenerateReqInput:
 
    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None
-    # Custom logit processor (serialized function)
-    custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
+    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
    def normalize_batch_and_arguments(self):
        if (
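
The updated comment points at CustomLogitProcessor and its to_str() serializer. A sketch of producing the string; the base class and to_str() come from the path named above, but the exact __call__ signature is an assumption, as is the masked token id:

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class DisallowTokenProcessor(CustomLogitProcessor):
    """Mask out one token id before sampling (illustrative)."""

    def __call__(self, logits, custom_param_list=None):
        logits[..., 128009] = -float("inf")  # hypothetical token id
        return logits

serialized = DisallowTokenProcessor().to_str()
# Pass `serialized` as the custom_logit_processor field of GenerateReqInput
# (or of the /generate request payload).
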
@@ -248,8 +250,9 @@ class TokenizedGenerateReqInput:
    # Session info for continual prompting
    session_params: Optional[SessionParams] = None
 
-    # Custom logit processor (serialized function)
-    # TODO (hpguo): Add an example and update doc string here
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
    custom_logit_processor: Optional[str] = None
 
 
@@ -351,10 +354,13 @@ class BatchTokenIDOut:
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    no_stop_trim: List[bool]
+
    # Token counts
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
+    spec_verify_ct: List[int]
+
    # Logprobs
    input_token_logprobs_val: List[float]
    input_token_logprobs_idx: List[int]
@@ -379,6 +385,7 @@ class BatchStrOut:
    prompt_tokens: List[int]
    completion_tokens: List[int]
    cached_tokens: List[int]
+    spec_verify_ct: List[int]
 
    # Logprobs
    input_token_logprobs_val: List[float]
@@ -533,3 +540,27 @@ class CloseSessionReqInput:
 class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool
+
+
+@dataclass
+class Function:
+    description: Optional[str] = None
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+@dataclass
+class Tool:
+    function: Function
+    type: Optional[str] = "function"
+
+
+@dataclass
+class FunctionCallReqInput:
+    text: str  # The text to parse.
+    tools: List[Tool] = field(
+        default_factory=list
+    )  # A list of available function tools (name, parameters, etc.).
+    tool_call_parser: Optional[str] = (
+        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
+    )
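
These dataclasses back the new function-call parsing support (see also sglang/srt/function_call_parser.py and the http_server.py additions in the file list). An illustrative construction; the tool schema and the model output text are made up:

from sglang.srt.managers.io_struct import Function, FunctionCallReqInput, Tool

weather_tool = Tool(
    function=Function(
        name="get_weather",
        description="Look up the current weather for a city.",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    )
)

req = FunctionCallReqInput(
    text='{"name": "get_weather", "parameters": {"city": "Paris"}}',
    tools=[weather_tool],
    tool_call_parser="llama3",  # or None to try all registered parsers
)
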
sglang/srt/managers/schedule_batch.py

@@ -247,12 +247,12 @@ class Req:
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds
 
        # Sampling info
        self.sampling_params = sampling_params
-        self.lora_path = lora_path
        self.custom_logit_processor = custom_logit_processor
 
        # Memory pool info
@@ -300,7 +300,7 @@ class Req:
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
 
-        # Logprobs (return value)
+        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,8 +329,14 @@ class Req:
        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
 
-        # The number of cached tokens, that were already cached in the KV cache
+        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
+        self.already_computed = 0
+
+        # The number of verification forward passes in the speculative decoding.
+        # This is used to compute the average acceptance length per request.
+        self.spec_verify_ct = 0
+        self.lora_path = lora_path
 
    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
@@ -550,13 +556,13 @@ class ScheduleBatch:
    next_batch_sampling_info: SamplingBatchInfo = None
 
    # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32
 
    # The sum of all sequence lengths
    seq_lens_sum: int = None
@@ -750,13 +756,6 @@ class ScheduleBatch:
 
        pt = 0
        for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
@@ -772,15 +771,20 @@
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            # Compute the relative logprob_start_len in an extend batch
-            if req.logprob_start_len >= pre_len:
-                extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len, req.extend_input_len - 1
-                )
-            else:
-                extend_logprob_start_len = req.extend_input_len - 1
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.extend_logprob_start_len = extend_logprob_start_len
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
 
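
With already_computed tracked per request, cached_tokens now counts each reused prefix token exactly once, even when a prompt is prefilled in several chunks. A walk-through with made-up numbers (how prefix_indices grows between passes is assumed for illustration):

# 100-token prompt, 30-token radix-cache hit, chunked prefill of 50 tokens per pass
cached_tokens, already_computed = 0, 0

pre_len, seq_len = 30, 80                     # pass 1: 30 reused + 50 computed
cached_tokens += pre_len - already_computed   # += 30 -> 30
already_computed = seq_len                    # 80

pre_len, seq_len = 80, 100                    # pass 2: previous chunk is now cached
cached_tokens += pre_len - already_computed   # += 0 -> still 30
already_computed = seq_len                    # 100

print(cached_tokens)  # 30 -- only tokens genuinely reused from the KV cache
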
@@ -1026,7 +1030,7 @@ class ScheduleBatch:
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
@@ -1112,6 +1116,8 @@ class ScheduleBatch:
        self.has_grammar = any(req.grammar for req in self.reqs)
 
        self.sampling_info.filter_batch(keep_indices, new_indices)
+        if self.spec_info:
+            self.spec_info.filter_batch(new_indices)
 
    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because