sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/image_processor.py
CHANGED
@@ -9,6 +9,8 @@ from typing import List, Optional, Union
 
 import numpy as np
 import transformers
+from decord import VideoReader, cpu
+from PIL import Image
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
+        self.server_args = server_args
 
         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
            )
 
     async def process_images_async(
-        self,
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
         return image_inputs
 
 
+class MiniCPMVImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_images_task(images, input_text):
+        result = global_processor.__call__(
+            text=input_text, images=images, return_tensors="pt"
+        )
+        return {
+            "input_ids": result["input_ids"],
+            "pixel_values": result["pixel_values"],
+            "tgt_sizes": result["tgt_sizes"],
+        }
+
+    async def _process_images(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                MiniCPMVImageProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                images=images, text=input_text, return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+    ):
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        image_hashes, image_sizes = [], []
+        raw_images = []
+        IMAGE_TOKEN = "(<image>./</image>)"
+
+        # roughly calculate the max number of frames
+        # TODO: the process should be applied to all the visual inputs
+        def calculate_max_num_frames() -> int:
+            # Model-specific
+            NUM_TOKEN_PER_FRAME = 330
+
+            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            return min(ret, 100)
+
+        # if cuda OOM set a smaller number
+        MAX_NUM_FRAMES = calculate_max_num_frames()
+        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def encode_video(video_path):
+            if not os.path.exists(video_path):
+                logger.error(f"Video {video_path} does not exist")
+                return []
+
+            if MAX_NUM_FRAMES == 0:
+                return []
+
+            def uniform_sample(l, n):
+                gap = len(l) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [l[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+            frame_idx = [i for i in range(0, len(vr), sample_fps)]
+            if len(frame_idx) > MAX_NUM_FRAMES:
+                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            frames = vr.get_batch(frame_idx).asnumpy()
+            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+            return frames
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        # MiniCPMV requires each frame of video as a single image token
+        text_parts = input_text.split(IMAGE_TOKEN)
+        new_text_parts = []
+
+        for image_index, image in enumerate(image_data):
+            try:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    frames = encode_video(path)
+                else:
+                    raw_image, size = load_image(image)
+                    frames = [raw_image]
+                if len(frames) == 0:
+                    continue
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+            image_sizes += frames[0].size * len(frames)
+            image_hashes += [hash(image)] * len(frames)
+            raw_images += frames
+            new_text_parts.append(text_parts[image_index])
+            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+        new_text_parts.append(text_parts[-1])
+        input_text = "".join(new_text_parts)
+        if len(raw_images) == 0:
+            return None
+        res = await self._process_images(images=raw_images, input_text=input_text)
+        pixel_values = res["pixel_values"]
+        tgt_sizes = res["tgt_sizes"]
+        input_ids = res["input_ids"]
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        im_start_id = [tokenizer.im_start_id]
+        im_end_id = [tokenizer.im_end_id]
+        if tokenizer.slice_start_id:
+            slice_start_id = [tokenizer.slice_start_id]
+            slice_end_id = [tokenizer.slice_end_id]
+
+        return {
+            "input_ids": input_ids.flatten().tolist(),
+            "pixel_values": pixel_values,
+            "tgt_sizes": tgt_sizes,
+            "image_hashes": image_hashes,
+            "modalities": request_obj.modalities or ["image"],
+            "im_start_id": im_start_id,
+            "im_end_id": im_end_id,
+            "slice_start_id": slice_start_id,
+            "slice_end_id": slice_end_id,
+        }
+
+
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _image_processor):
         self.hf_config = hf_config
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
            return self._process_single_image_task(image_data)
 
     async def process_images_async(
-        self,
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -350,6 +504,8 @@ def get_image_processor(
         return MllamaImageProcessor(hf_config, server_args, processor)
     elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
         return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+    elif "MiniCPMV" in hf_config.architectures:
+        return MiniCPMVImageProcessor(hf_config, server_args, processor)
     else:
         return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
 
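The MiniCPMVImageProcessor added above budgets video frames against the remaining prompt space (roughly 330 tokens per frame, capped at 100 frames) and expands each sampled frame into one "(<image>./</image>)" token in the prompt. Below is a minimal standalone sketch of that budgeting and token-expansion logic; the numbers come from the diff, while the function names are illustrative and not part of the sglang package.

# Illustrative sketch only: NUM_TOKEN_PER_FRAME (330) and the 100-frame cap come
# from the diff above; the function names are hypothetical, not sglang APIs.
IMAGE_TOKEN = "(<image>./</image>)"
NUM_TOKEN_PER_FRAME = 330  # rough, model-specific per-frame token cost


def max_num_frames(max_req_input_len: int, input_text: str) -> int:
    # Frames must fit into the space left after the text prompt.
    ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
    return max(0, min(ret, 100))


def expand_image_tokens(input_text: str, frames_per_image: list) -> str:
    # Each sampled video frame becomes one image token in the prompt.
    parts = input_text.split(IMAGE_TOKEN)
    out = []
    for i, n_frames in enumerate(frames_per_image):
        out.append(parts[i])
        out.append(IMAGE_TOKEN * n_frames)
    out.append(parts[-1])
    return "".join(out)


if __name__ == "__main__":
    prompt = "Describe the clip: " + IMAGE_TOKEN
    print(max_num_frames(4096, prompt))      # 12 with these numbers
    print(expand_image_tokens(prompt, [3]))  # prompt with 3 image tokens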
sglang/srt/managers/io_struct.py
CHANGED
@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Dict, List, Optional, Union
 
@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -66,6 +69,10 @@ class GenerateReqInput:
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
+    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -180,6 +187,13 @@ class GenerateReqInput:
            else:
                assert self.parallel_sample_num == 1
 
+            if self.custom_logit_processor is None:
+                self.custom_logit_processor = [None] * num
+            elif not isinstance(self.custom_logit_processor, list):
+                self.custom_logit_processor = [self.custom_logit_processor] * num
+            else:
+                assert self.parallel_sample_num == 1
+
     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
         return self.rid
@@ -196,8 +210,14 @@ class GenerateReqInput:
            top_logprobs_num=self.top_logprobs_num[i],
            return_text_in_logprobs=self.return_text_in_logprobs,
            stream=self.stream,
+            log_metrics=self.log_metrics,
            modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            custom_logit_processor=(
+                self.custom_logit_processor[i]
+                if self.custom_logit_processor is not None
+                else None
+            ),
        )
 
 
@@ -230,6 +250,11 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
+    # Custom logit processor for advanced sampling control. Must be a serialized instance
+    # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    # Use the processor's `to_str()` method to generate the serialized string.
+    custom_logit_processor: Optional[str] = None
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -243,6 +268,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -327,10 +354,13 @@ class BatchTokenIDOut:
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     no_stop_trim: List[bool]
+
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
+
     # Logprobs
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -340,7 +370,6 @@ class BatchTokenIDOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -356,6 +385,7 @@ class BatchStrOut:
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+    spec_verify_ct: List[int]
 
     # Logprobs
     input_token_logprobs_val: List[float]
@@ -366,7 +396,6 @@ class BatchStrOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -491,6 +520,7 @@ class ProfileReq(Enum):
 @dataclass
 class ConfigureLoggingReq:
     log_requests: Optional[bool] = None
+    log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
     dump_requests_threshold: Optional[int] = None
 
@@ -510,3 +540,27 @@ class CloseSessionReqInput:
 class OpenSessionReqOutput:
     session_id: Optional[str]
     success: bool
+
+
+@dataclass
+class Function:
+    description: Optional[str] = None
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+@dataclass
+class Tool:
+    function: Function
+    type: Optional[str] = "function"
+
+
+@dataclass
+class FunctionCallReqInput:
+    text: str  # The text to parse.
+    tools: List[Tool] = field(
+        default_factory=list
+    )  # A list of available function tools (name, parameters, etc.).
+    tool_call_parser: Optional[str] = (
+        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
+    )
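As the comments added above state, custom_logit_processor carries a serialized CustomLogitProcessor (produced with its to_str() method) and can be supplied either as a single string for the whole batch or as a per-request list; normalize_batch_and_arguments then broadcasts it to one entry per request. Below is a standalone sketch of that broadcasting rule, assuming nothing beyond what the diff shows; the helper name is illustrative and not part of sglang's API.

from typing import List, Optional, Union


def broadcast_custom_logit_processor(
    value: Optional[Union[List[Optional[str]], str]], num: int
) -> List[Optional[str]]:
    # Mirrors the normalization added to GenerateReqInput: a missing value
    # becomes [None] * num, a single serialized processor string is repeated
    # for every request, and an explicit list is kept as-is (the diff only
    # allows that when parallel_sample_num == 1).
    if value is None:
        return [None] * num
    if not isinstance(value, list):
        return [value] * num
    return value


if __name__ == "__main__":
    print(broadcast_custom_logit_processor(None, 3))          # [None, None, None]
    print(broadcast_custom_logit_processor("serialized", 2))  # ['serialized', 'serialized']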
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
 if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
 
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -65,9 +64,9 @@ global_server_args_dict = {
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -116,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error"):
+    def __init__(self, message="Unknown error", status_code=None, err_type=None):
         super().__init__(is_error=True)
         self.message = message
+        self.status_code = status_code
+        self.err_type = err_type
 
     def to_json(self):
         return {
             "type": "abort",
             "message": self.message,
+            "status_code": self.status_code,
+            "err_type": self.err_type,
         }
 
 
@@ -148,6 +151,15 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
+    # MiniCPMV related
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    im_start_id: Optional[torch.Tensor] = None
+    im_end_id: Optional[torch.Tensor] = None
+    slice_start_id: Optional[torch.Tensor] = None
+    slice_end_id: Optional[torch.Tensor] = None
+    tgt_sizes: Optional[list] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -167,6 +179,11 @@ class ImageInputs:
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
+            "im_start_id",
+            "im_end_id",
+            "slice_start_id",
+            "slice_end_id",
+            "tgt_sizes",
        ]
        for arg in optional_args:
            if arg in obj:
@@ -215,6 +232,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        custom_logit_processor: Optional[str] = None,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -226,14 +244,16 @@ class Req:
            else origin_input_ids  # Before image padding
        )
         self.origin_input_ids = origin_input_ids
-
-        self.
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
+        self.fill_ids = None
         self.session_id = session_id
         self.input_embeds = input_embeds
 
         # Sampling info
         self.sampling_params = sampling_params
-        self.
+        self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
         self.req_pool_idx = None
@@ -265,6 +285,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None
 
@@ -279,12 +300,11 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = top_logprobs_num
 
-        # Logprobs (return
-        self.
-        self.
-        self.
-        self.
-        self.input_top_logprobs_idx = None
+        # Logprobs (return values)
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
             self.output_token_logprobs_val = []
@@ -309,8 +329,14 @@ class Req:
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
 
-        # The number of cached tokens
+        # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
+        self.already_computed = 0
+
+        # The number of verification forward passes in the speculative decoding.
+        # This is used to compute the average acceptance length per request.
+        self.spec_verify_ct = 0
+        self.lora_path = lora_path
 
     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
@@ -344,9 +370,6 @@ class Req:
         max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            if self.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                max_prefix_len = min(max_prefix_len, input_len - 2)
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
         max_prefix_len = max(max_prefix_len, 0)
@@ -533,13 +556,13 @@ class ScheduleBatch:
     next_batch_sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    input_embeds: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
+    input_ids: torch.Tensor = None  # shape: [b], int32
+    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    req_pool_indices: torch.Tensor = None  # shape: [b], int32
+    seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
-    out_cache_loc: torch.Tensor = None
-    output_ids: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None  # shape: [b], int32
+    output_ids: torch.Tensor = None  # shape: [b], int32
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
@@ -578,6 +601,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
 
+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -588,6 +614,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -601,6 +628,7 @@ class ScheduleBatch:
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
        )
 
     def batch_size(self):
@@ -656,7 +684,7 @@ class ScheduleBatch:
                or len(req.prefix_indices) >= im.num_image_tokens
            )
 
-        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -690,7 +718,7 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -728,13 +756,6 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
@@ -750,15 +771,20 @@ class ScheduleBatch:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
-
-
-
-
-
-
-
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
 
-            req.
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
             req.is_retracted = False
             pre_lens.append(pre_len)
 
@@ -766,10 +792,10 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.input_embeds = (
@@ -1002,11 +1028,16 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.seq_lens = torch.empty(0, dtype=torch.
+        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+            enable_overlap_schedule=self.enable_overlap,
+        )
 
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
@@ -1067,7 +1098,7 @@ class ScheduleBatch:
         self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.
+        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
@@ -1085,6 +1116,8 @@ class ScheduleBatch:
         self.has_grammar = any(req.grammar for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
+        if self.spec_info:
+            self.spec_info.filter_batch(new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1121,7 +1154,7 @@ class ScheduleBatch:
            self.spec_info.merge_batch(other.spec_info)
 
     def get_model_worker_batch(self):
-        if self.forward_mode.
+        if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
@@ -1136,7 +1169,6 @@ class ScheduleBatch:
 
         global bid
         bid += 1
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1180,6 +1212,7 @@ class ScheduleBatch:
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )
 
     def __str__(self):
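The reworked prefix accounting above replaces the batch-level already_computed calculation with two per-request counters: cached_tokens grows by the portion of the cached prefix that has not been counted yet, and already_computed records the sequence length processed so far, so a later chunk of the same request is not double-counted as a cache hit. Below is a toy walk-through of that accounting under chunked prefill; FakeReq is purely illustrative and not an sglang class.

class FakeReq:
    # Minimal stand-in for the two counters updated in prepare_for_extend.
    def __init__(self):
        self.cached_tokens = 0
        self.already_computed = 0

    def account(self, pre_len: int, seq_len: int):
        # pre_len: length of the cached prefix for this chunk (len(prefix_indices))
        # seq_len: total tokens filled so far after this chunk (len(fill_ids))
        self.cached_tokens += pre_len - self.already_computed
        self.already_computed = seq_len


# Request A: two chunks of its own prompt, nothing reused from other requests.
a = FakeReq()
a.account(pre_len=0, seq_len=512)     # chunk 1: no cached prefix
a.account(pre_len=512, seq_len=1024)  # chunk 2: its own chunk-1 tokens form the prefix
print(a.cached_tokens)                # 0 -> own chunks are not counted as cache hits

# Request B: its first chunk starts on a 128-token prefix shared via the radix cache.
b = FakeReq()
b.account(pre_len=128, seq_len=640)
print(b.cached_tokens)                # 128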