sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +17 -11
- sglang/bench_one_batch.py +14 -6
- sglang/bench_serving.py +47 -44
- sglang/lang/chat_template.py +31 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
- sglang/srt/entrypoints/engine.py +5 -2
- sglang/srt/entrypoints/http_server.py +24 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
- sglang/srt/layers/attention/vision.py +243 -40
- sglang/srt/layers/dp_attention.py +3 -1
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +24 -9
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -12
- sglang/srt/layers/moe/fused_moe_native.py +17 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
- sglang/srt/layers/parameter.py +16 -7
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +11 -1
- sglang/srt/layers/rotary_embedding.py +34 -13
- sglang/srt/layers/sampler.py +33 -10
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/image_processor.py +77 -38
- sglang/srt/managers/io_struct.py +36 -5
- sglang/srt/managers/schedule_batch.py +31 -25
- sglang/srt/managers/scheduler.py +78 -38
- sglang/srt/managers/tokenizer_manager.py +4 -0
- sglang/srt/mem_cache/base_prefix_cache.py +4 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +30 -1
- sglang/srt/model_executor/cuda_graph_runner.py +23 -25
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +7 -4
- sglang/srt/model_loader/loader.py +75 -0
- sglang/srt/model_loader/weight_utils.py +91 -5
- sglang/srt/models/commandr.py +14 -2
- sglang/srt/models/dbrx.py +9 -1
- sglang/srt/models/deepseek_v2.py +3 -3
- sglang/srt/models/gemma2.py +9 -1
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/minicpmv.py +129 -76
- sglang/srt/models/mllama.py +16 -56
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_vl.py +18 -8
- sglang/srt/models/torch_native_llama.py +17 -4
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +5 -4
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +4 -14
- sglang/srt/server.py +2 -2
- sglang/srt/server_args.py +26 -1
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +62 -67
- sglang/test/test_programs.py +1 -0
- sglang/test/test_utils.py +81 -22
- sglang/utils.py +42 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from vllm import _custom_ops as ops
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.utils import is_cuda_available
+
+_is_cuda_available = is_cuda_available()
+if _is_cuda_available:
+    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(dtype)
+        # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
+        if not _is_cuda_available:
+            cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        … (11 removed lines, truncated in this view)
+        if _is_cuda_available:
+            apply_rope_with_cos_sin_cache_inplace(
+                positions=positions,
+                query=query,
+                key=key,
+                head_size=self.head_size,
+                cos_sin_cache=self.cos_sin_cache,
+                is_neox=self.is_neox_style,
+            )
+        else:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+            ops.rotary_embedding(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
         return query, key
 
     def forward_xpu(
@@ -1018,7 +1034,12 @@ def get_rope(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
     else:
-        …
+        if "rope_type" in rope_scaling:
+            scaling_type = rope_scaling["rope_type"]
+        elif "type" in rope_scaling:
+            scaling_type = rope_scaling["type"]
+        else:
+            raise ValueError("Unknown RoPE scaling type")
 
         if scaling_type == "llama3":
             scaling_factor = rope_scaling["factor"]
sglang/srt/layers/sampler.py
CHANGED
@@ -1,17 +1,19 @@
 import logging
-from typing import …
+from typing import List
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
+from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
 
-if is_flashinfer_available():
-    from flashinfer.sampling import (
+if is_cuda_available():
+    from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
         top_k_top_p_sampling_from_probs,
@@ -21,11 +23,17 @@ if is_flashinfer_available():
 
 logger = logging.getLogger(__name__)
 
+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+
 
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+
+        if global_server_args_dict["enable_dp_attention"]:
+            self.tp_sync_group = get_attention_tp_group().device_group
 
     def forward(
         self,
@@ -64,9 +72,11 @@ class Sampler(nn.Module):
                 # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                 # https://github.com/flashinfer-ai/flashinfer/issues/708
                 # so we use the torch implementation.
+
+                # clamp to avoid -inf
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                )
+                ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
@@ -101,16 +111,15 @@ class Sampler(nn.Module):
                 sampling_info.need_min_p_sampling,
             )
             if return_logprob:
+                # clamp to avoid -inf
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                )
+                ).clamp(min=torch.finfo(probs.dtype).min)
         else:
             raise ValueError(
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )
 
-        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
             if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +133,21 @@ class Sampler(nn.Module):
                 batch_next_token_ids,
             ]
 
-        return batch_next_token_ids
+        if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+            # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+            # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+            # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+            # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+            # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+            # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=self.tp_sync_group,
+            )
+
+        return batch_next_token_ids.to(torch.int32)
 
     def _apply_custom_logit_processor(
         self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
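
The two `.clamp(min=torch.finfo(probs.dtype).min)` additions above bound the log-probabilities produced after top-p renormalization. A small illustration of why: probabilities zeroed out by top-p become `-inf` under `torch.log`, and clamping to the dtype's most negative finite value keeps downstream logprob handling finite. The tensor values below are made up:

```python
import torch

# Hypothetical post-top-p distribution with two masked-out tokens.
probs = torch.tensor([0.7, 0.3, 0.0, 0.0])

logprobs = torch.log(probs)  # last two entries are -inf
clamped = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

print(logprobs)  # tensor([-0.3567, -1.2040,    -inf,    -inf])
print(clamped)   # -inf replaced by roughly -3.4e38 (float32 minimum)
```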
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -5,6 +5,7 @@ Common utilities for torchao.
 import logging
 import os
 import pwd
+from typing import Callable, Optional
 
 import torch
 
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
         return True
 
 
+def proj_filter(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    """Filter function for quantizing projection layers."""
+    return "proj" in fqn
+
+
 def apply_torchao_config_to_model(
-    model: torch.nn.Module, …
+    model: torch.nn.Module,
+    torchao_config: str,
+    filter_fn: Optional[Callable] = proj_filter,
 ):
     """Quantize a modelwith torchao quantization specified by torchao_config
 
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    if filter_fn is None:
-
-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
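
The new module-level `proj_filter` replaces the old inline default and decides, by fully-qualified name, which submodules torchao should quantize. A standalone sketch of how that selection behaves; the toy model below is hypothetical and only demonstrates the filter, not how torchao applies it:

```python
import torch.nn as nn


def proj_filter(module: nn.Module, fqn: str) -> bool:
    # Quantize only modules whose fully-qualified name contains "proj".
    return "proj" in fqn


model = nn.ModuleDict(
    {
        "q_proj": nn.Linear(16, 16),
        "k_proj": nn.Linear(16, 16),
        "mlp": nn.Linear(16, 16),
    }
)
selected = [name for name, mod in model.named_modules() if proj_filter(mod, name)]
print(selected)  # ['q_proj', 'k_proj'] -- "mlp" would be left unquantized
```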
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -201,6 +201,7 @@ class DetokenizerManager:
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
+                spec_verify_ct=recv_obj.spec_verify_ct,
                 input_token_logprobs_val=recv_obj.input_token_logprobs_val,
                 input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
                 output_token_logprobs_val=recv_obj.output_token_logprobs_val,
sglang/srt/managers/image_processor.py
CHANGED
@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
 class MiniCPMVImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "(<image>./</image>)"
 
     @staticmethod
     def _process_images_task(images, input_text):
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     async def process_images_async(
         self,
         image_data: List[Union[str, bytes]],
-        …
+        input_ids,
         request_obj,
         max_req_input_len,
     ):
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             image_data = [image_data]
 
         image_hashes, image_sizes = [], []
-        …
-        IMAGE_TOKEN = "(<image>./</image>)"
+        all_frames = []
 
-        # roughly calculate the max number of frames
-        # TODO: the process should be applied to all the visual inputs
+        # roughly calculate the max number of frames under the max_req_input_len limit
        def calculate_max_num_frames() -> int:
             # Model-specific
             NUM_TOKEN_PER_FRAME = 330
 
-            ret = (max_req_input_len - len(…
+            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
             return min(ret, 100)
 
-        # if cuda OOM set a smaller number
         MAX_NUM_FRAMES = calculate_max_num_frames()
-        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
 
-        …
+        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def get_estimated_frames_list():
+            """
+            estimate the total frame count from all visual input
+            """
+            # Before processing inputs
+            estimated_frames_list = []
+            for image in image_data:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    # Estimate frames for the video
+                    vr = VideoReader(path, ctx=cpu(0))
+                    num_frames = len(vr)
+                else:
+                    # For images, each contributes one frame
+                    num_frames = 1
+                estimated_frames_list.append(num_frames)
+
+            return estimated_frames_list
+
+        estimated_frames_list = get_estimated_frames_list()
+        total_frame_count = sum(estimated_frames_list)
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+
+        def encode_video(video_path, frame_count_limit=None):
             if not os.path.exists(video_path):
                 logger.error(f"Video {video_path} does not exist")
                 return []
 
-            if …
+            if frame_count_limit == 0:
                 return []
 
             def uniform_sample(l, n):
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             vr = VideoReader(video_path, ctx=cpu(0))
             sample_fps = round(vr.get_avg_fps() / 1)  # FPS
             frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if len(frame_idx) > …
-                frame_idx = uniform_sample(frame_idx, …
+            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
+                frame_idx = uniform_sample(frame_idx, frame_count_limit)
             frames = vr.get_batch(frame_idx).asnumpy()
             frames = [Image.fromarray(v.astype("uint8")) for v in frames]
             return frames
 
-        if isinstance(…
-            assert len(…
-            input_text = self._processor.tokenizer.decode(…
-        …
+        if isinstance(input_ids, list):
+            assert len(input_ids) and isinstance(input_ids[0], int)
+            input_text = self._processor.tokenizer.decode(input_ids)
+        else:
+            input_text = input_ids
         # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(IMAGE_TOKEN)
+        text_parts = input_text.split(self.IMAGE_TOKEN)
         new_text_parts = []
 
-        … (17 removed lines, truncated in this view)
+        # Process each input with allocated frames
+        for image_index, (image, estimated_frames) in enumerate(
+            zip(image_data, estimated_frames_list)
+        ):
+            if len(all_frames) >= MAX_NUM_FRAMES:
+                frames_to_process = 0
+            else:
+                frames_to_process = max(1, int(estimated_frames * scaling_factor))
+
+            if frames_to_process == 0:
+                frames = []
+            else:
+                try:
+                    if isinstance(image, str) and image.startswith("video:"):
+                        path = image[len("video:") :]
+                        frames = encode_video(path, frame_count_limit=frames_to_process)
+                    else:
+                        raw_image, _size = load_image(image)
+                        frames = [raw_image]
+                    if len(frames) == 0:
+                        continue
+                except FileNotFoundError as e:
+                    print(e)
+                    return None
+                image_sizes += frames[0].size * len(frames)
+                image_hashes += [hash(image)] * len(frames)
+                all_frames += frames
+
+                assert frames_to_process == len(frames)
+
             new_text_parts.append(text_parts[image_index])
-            …
+
+            if frames_to_process != 0:
+                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
 
         new_text_parts.append(text_parts[-1])
+
         input_text = "".join(new_text_parts)
-        …
+
+        if len(all_frames) == 0:
             return None
-        res = await self._process_images(images=…
+        res = await self._process_images(images=all_frames, input_text=input_text)
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
         input_ids = res["input_ids"]
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         if tokenizer.slice_start_id:
             slice_start_id = [tokenizer.slice_start_id]
             slice_end_id = [tokenizer.slice_end_id]
-
         return {
             "input_ids": input_ids.flatten().tolist(),
             "pixel_values": pixel_values,
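
The new MiniCPMV path estimates a frame count per visual input, then scales every input down so the total stays within the token budget derived from `max_req_input_len`. A simplified sketch of that budgeting arithmetic; all numbers below are made up, and the early exit once the budget is already exhausted is omitted:

```python
NUM_TOKEN_PER_FRAME = 330

# Hypothetical request: 512 prompt tokens in a 32768-token context.
max_req_input_len, prompt_len = 32768, 512
MAX_NUM_FRAMES = min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)

# Hypothetical inputs: two videos and one still image.
estimated_frames_list = [240, 1, 60]
total_frame_count = sum(estimated_frames_list)
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)

frames_to_process = [max(1, int(n * scaling_factor)) for n in estimated_frames_list]
print(MAX_NUM_FRAMES, round(scaling_factor, 3), frames_to_process)
# 97 0.322 [77, 1, 19]  -- 97 frames total, within the budget
```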
sglang/srt/managers/io_struct.py
CHANGED
@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Optional, Union
 
@@ -69,8 +69,10 @@ class GenerateReqInput:
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
-    # Custom logit processor
-    …
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
+    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -248,8 +250,9 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
-    # Custom logit processor
-    # …
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[str] = None
 
 
@@ -351,10 +354,13 @@ class BatchTokenIDOut:
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     no_stop_trim: List[bool]
+
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
+
     # Logprobs
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -379,6 +385,7 @@ class BatchStrOut:
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -533,3 +540,27 @@ class CloseSessionReqInput:
 class OpenSessionReqOutput:
     session_id: Optional[str]
     success: bool
+
+
+@dataclass
+class Function:
+    description: Optional[str] = None
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+@dataclass
+class Tool:
+    function: Function
+    type: Optional[str] = "function"
+
+
+@dataclass
+class FunctionCallReqInput:
+    text: str  # The text to parse.
+    tools: List[Tool] = field(
+        default_factory=list
+    )  # A list of available function tools (name, parameters, etc.).
+    tool_call_parser: Optional[str] = (
+        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
+    )
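
The three dataclasses appended to `io_struct.py` describe the payload for the new function-call parsing support (see `sglang/srt/function_call_parser.py` in the file list above). A sketch reproducing them so the request shape is easy to see; the instantiation at the end is a hypothetical example, not taken from the diff:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class FunctionCallReqInput:
    text: str  # The text to parse.
    tools: List[Tool] = field(default_factory=list)
    tool_call_parser: Optional[str] = None  # e.g. 'llama3', 'qwen25', or 'mistral'


# Hypothetical request: parse a model completion against one declared tool.
req = FunctionCallReqInput(
    text='{"name": "get_weather", "arguments": {"city": "Paris"}}',
    tools=[Tool(function=Function(name="get_weather", parameters={"type": "object"}))],
    tool_call_parser="llama3",
)
print(req.tools[0].function.name)  # get_weather
```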
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -247,12 +247,12 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
         # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+        self.fill_ids = None
         self.session_id = session_id
         self.input_embeds = input_embeds
 
         # Sampling info
         self.sampling_params = sampling_params
-        self.lora_path = lora_path
         self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
@@ -300,7 +300,7 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
 
-        # Logprobs (return …
+        # Logprobs (return values)
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -329,8 +329,14 @@ class Req:
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
 
-        # The number of cached tokens
+        # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
+        self.already_computed = 0
+
+        # The number of verification forward passes in the speculative decoding.
+        # This is used to compute the average acceptance length per request.
+        self.spec_verify_ct = 0
+        self.lora_path = lora_path
 
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
@@ -550,13 +556,13 @@ class ScheduleBatch:
     next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -750,13 +756,6 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
@@ -772,15 +771,20 @@ class ScheduleBatch:
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-            … (7 removed lines, truncated in this view)
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.…
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)
 
@@ -1026,7 +1030,7 @@ class ScheduleBatch:
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.…
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
         self.extend_num_tokens = 0
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
@@ -1112,6 +1116,8 @@ class ScheduleBatch:
         self.has_grammar = any(req.grammar for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
+        if self.spec_info:
+            self.spec_info.filter_batch(new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because