sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/image_processor.py CHANGED

```diff
@@ -9,6 +9,8 @@ from typing import List, Optional, Union
 
 import numpy as np
 import transformers
+from decord import VideoReader, cpu
+from PIL import Image
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
+        self.server_args = server_args
 
         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
         )
 
     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], input_text, request_obj
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
```
```diff
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
         return image_inputs
 
 
+class MiniCPMVImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_images_task(images, input_text):
+        result = global_processor.__call__(
+            text=input_text, images=images, return_tensors="pt"
+        )
+        return {
+            "input_ids": result["input_ids"],
+            "pixel_values": result["pixel_values"],
+            "tgt_sizes": result["tgt_sizes"],
+        }
+
+    async def _process_images(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                MiniCPMVImageProcessor._process_images_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                images=images, text=input_text, return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_images_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+    ):
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        image_hashes, image_sizes = [], []
+        raw_images = []
+        IMAGE_TOKEN = "(<image>./</image>)"
+
+        # roughly calculate the max number of frames
+        # TODO: the process should be applied to all the visual inputs
+        def calculate_max_num_frames() -> int:
+            # Model-specific
+            NUM_TOKEN_PER_FRAME = 330
+
+            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            return min(ret, 100)
+
+        # if cuda OOM set a smaller number
+        MAX_NUM_FRAMES = calculate_max_num_frames()
+        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def encode_video(video_path):
+            if not os.path.exists(video_path):
+                logger.error(f"Video {video_path} does not exist")
+                return []
+
+            if MAX_NUM_FRAMES == 0:
+                return []
+
+            def uniform_sample(l, n):
+                gap = len(l) / n
+                idxs = [int(i * gap + gap / 2) for i in range(n)]
+                return [l[i] for i in idxs]
+
+            vr = VideoReader(video_path, ctx=cpu(0))
+            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+            frame_idx = [i for i in range(0, len(vr), sample_fps)]
+            if len(frame_idx) > MAX_NUM_FRAMES:
+                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            frames = vr.get_batch(frame_idx).asnumpy()
+            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
+            return frames
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        # MiniCPMV requires each frame of video as a single image token
+        text_parts = input_text.split(IMAGE_TOKEN)
+        new_text_parts = []
+
+        for image_index, image in enumerate(image_data):
+            try:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    frames = encode_video(path)
+                else:
+                    raw_image, size = load_image(image)
+                    frames = [raw_image]
+                if len(frames) == 0:
+                    continue
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+            image_sizes += frames[0].size * len(frames)
+            image_hashes += [hash(image)] * len(frames)
+            raw_images += frames
+            new_text_parts.append(text_parts[image_index])
+            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+        new_text_parts.append(text_parts[-1])
+        input_text = "".join(new_text_parts)
+        if len(raw_images) == 0:
+            return None
+        res = await self._process_images(images=raw_images, input_text=input_text)
+        pixel_values = res["pixel_values"]
+        tgt_sizes = res["tgt_sizes"]
+        input_ids = res["input_ids"]
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        im_start_id = [tokenizer.im_start_id]
+        im_end_id = [tokenizer.im_end_id]
+        if tokenizer.slice_start_id:
+            slice_start_id = [tokenizer.slice_start_id]
+            slice_end_id = [tokenizer.slice_end_id]
+
+        return {
+            "input_ids": input_ids.flatten().tolist(),
+            "pixel_values": pixel_values,
+            "tgt_sizes": tgt_sizes,
+            "image_hashes": image_hashes,
+            "modalities": request_obj.modalities or ["image"],
+            "im_start_id": im_start_id,
+            "im_end_id": im_end_id,
+            "slice_start_id": slice_start_id,
+            "slice_end_id": slice_end_id,
+        }
+
+
 class Qwen2VLImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _image_processor):
         self.hf_config = hf_config
```
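The frame budget in the new `MiniCPMVImageProcessor` is plain arithmetic: the unused portion of the request's input window is divided by a fixed per-frame token cost (330 for this model) and capped at 100, then an oversampled clip is thinned uniformly. A standalone sketch of that math; the 8192-token window and 512-character prompt are invented for the example:

```python
# Illustrative sketch of the frame-budget and uniform-sampling logic above.
NUM_TOKEN_PER_FRAME = 330  # model-specific token cost per video frame

def max_num_frames(max_req_input_len: int, prompt_len: int) -> int:
    return min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)

def uniform_sample(items, n):
    # Pick n entries spread evenly across the list, sampling at bin centers.
    gap = len(items) / n
    return [items[int(i * gap + gap / 2)] for i in range(n)]

budget = max_num_frames(8192, 512)       # (8192 - 512) // 330 = 23
frame_idx = list(range(0, 900, 30))      # e.g. a 30 fps clip sampled at 1 fps
if len(frame_idx) > budget:
    frame_idx = uniform_sample(frame_idx, budget)
print(budget, len(frame_idx))            # 23 23
```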
```diff
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
         return self._process_single_image_task(image_data)
 
     async def process_images_async(
-        self, image_data: List[Union[str, bytes]], input_text, request_obj
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
     ):
         if not image_data:
             return None
@@ -350,6 +504,8 @@ def get_image_processor(
         return MllamaImageProcessor(hf_config, server_args, processor)
     elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
         return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
+    elif "MiniCPMV" in hf_config.architectures:
+        return MiniCPMVImageProcessor(hf_config, server_args, processor)
     else:
         return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
 
```
sglang/srt/managers/io_struct.py CHANGED

```diff
@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional,
-
-import torch
+from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -61,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -68,6 +69,8 @@ class GenerateReqInput:
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
+    # Custom logit processor (serialized function)
+    custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     def normalize_batch_and_arguments(self):
         if (
```
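The new `custom_logit_processor` field carries a serialized callable rather than a live object, since the request crosses process boundaries on its way to the scheduler. The exact contract lives in the newly added `sglang/srt/sampling/custom_logit_processor.py` (see the file list), which is not shown in this section, so the round trip below is a hedged sketch: the `BanTokenProcessor` class and the dill-based transport are illustrative assumptions, not the package's confirmed API.

```python
# Hedged sketch: a callable that masks specific token ids, serialized to a
# string so it could travel inside GenerateReqInput.custom_logit_processor.
# The use of dill here is an assumption about the transport, not a confirmed
# detail of this diff.
import dill
import torch

class BanTokenProcessor:
    """Sets banned token logits to -inf before sampling."""

    def __init__(self, banned_ids):
        self.banned_ids = banned_ids

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        logits[..., self.banned_ids] = float("-inf")
        return logits

serialized = dill.dumps(BanTokenProcessor([13, 198])).hex()  # str payload
processor = dill.loads(bytes.fromhex(serialized))            # worker side
logits = processor(torch.zeros(1, 32000))
assert torch.isinf(logits[0, 13])
```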
```diff
@@ -182,6 +185,13 @@ class GenerateReqInput:
         else:
             assert self.parallel_sample_num == 1
 
+        if self.custom_logit_processor is None:
+            self.custom_logit_processor = [None] * num
+        elif not isinstance(self.custom_logit_processor, list):
+            self.custom_logit_processor = [self.custom_logit_processor] * num
+        else:
+            assert self.parallel_sample_num == 1
+
     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
         return self.rid
@@ -198,8 +208,14 @@ class GenerateReqInput:
             top_logprobs_num=self.top_logprobs_num[i],
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
+            log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            custom_logit_processor=(
+                self.custom_logit_processor[i]
+                if self.custom_logit_processor is not None
+                else None
+            ),
         )
 
 
```
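The normalization above follows the same broadcast rule as the other per-request fields: a single value fans out to one entry per request in the batch, while an explicit list is only accepted when there is no parallel-sampling fan-out. A toy restatement of that rule, detached from the actual class:

```python
# Toy version of the broadcast rule used by normalize_batch_and_arguments.
def broadcast(value, num, parallel_sample_num=1):
    if value is None:
        return [None] * num
    if not isinstance(value, list):
        return [value] * num
    assert parallel_sample_num == 1  # lists are only valid without fan-out
    return value

print(broadcast(None, 3))        # [None, None, None]
print(broadcast("proc", 3))      # ['proc', 'proc', 'proc']
print(broadcast(["a", "b"], 2))  # ['a', 'b']
```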
```diff
@@ -232,6 +248,10 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
+    # Custom logit processor (serialized function)
+    # TODO (hpguo): Add an example and update doc string here
+    custom_logit_processor: Optional[str] = None
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -245,6 +265,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
```
```diff
@@ -323,9 +345,7 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when
-    origin_input_ids: Optional[List[int]]
-    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
+    # Only used when `--skip-tokenizer-init` is on
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -344,7 +364,6 @@ class BatchTokenIDOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -356,14 +375,7 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]
 
-    # The token ids
-    origin_input_ids: Optional[List[int]]
-    output_ids: Optional[List[int]]
-
     # Token counts
-    # real input and output tokens can be get from
-    # origin_input_ids and output_ids by enabling --return_token_ids
-    # TODO (Shuai): Rename this to clarify the meaning.
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
@@ -377,7 +389,6 @@ class BatchStrOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
```
```diff
@@ -468,6 +479,26 @@ class GetWeightsByNameReqOutput:
     parameter: list
 
 
+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
@@ -479,6 +510,14 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2
 
 
+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    log_requests_level: Optional[int] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
```
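The empty release/resume dataclasses are pure control signals: temporarily giving up GPU memory occupation (e.g. so a co-located trainer can use the device) and reclaiming it later. The diff also adds `sglang/srt/torch_memory_saver_adapter.py` and a new `sglang/srt/entrypoints/engine.py`, whose bodies are not shown in this section, so the call pattern below is a sketch under the assumption that the engine exposes methods matching these message names:

```python
# Hedged sketch of how the paired ReqInput/ReqOutput messages might be driven.
# The release/resume method names on the engine are assumptions inferred from
# the dataclass names above, not confirmed signatures from this diff.
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

engine.release_memory_occupation()   # -> ReleaseMemoryOccupationReqInput
# ... another process trains / uses the GPU here ...
engine.resume_memory_occupation()    # -> ResumeMemoryOccupationReqInput

print(engine.generate("The capital of France is"))
```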
sglang/srt/managers/schedule_batch.py CHANGED

```diff
@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
 if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
 
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -65,9 +64,9 @@ global_server_args_dict = {
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }
 
-
 logger = logging.getLogger(__name__)
 
 
```
```diff
@@ -116,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error"):
+    def __init__(self, message="Unknown error", status_code=None, err_type=None):
         super().__init__(is_error=True)
         self.message = message
+        self.status_code = status_code
+        self.err_type = err_type
 
     def to_json(self):
         return {
             "type": "abort",
             "message": self.message,
+            "status_code": self.status_code,
+            "err_type": self.err_type,
         }
 
 
```
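With the extra fields, an abort can now carry an HTTP-style status and error type back to the caller instead of a bare message. A quick check of the shape it serializes to, with illustrative values:

```python
# What the extended FINISH_ABORT serializes to (values invented for the example).
from sglang.srt.managers.schedule_batch import FINISH_ABORT

abort = FINISH_ABORT(
    message="Input length exceeds the maximum allowed",
    status_code=400,
    err_type="BadRequestError",
)
print(abort.to_json())
# {'type': 'abort', 'message': 'Input length exceeds the maximum allowed',
#  'status_code': 400, 'err_type': 'BadRequestError'}
```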
```diff
@@ -148,6 +151,15 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
+    # MiniCPMV related
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    im_start_id: Optional[torch.Tensor] = None
+    im_end_id: Optional[torch.Tensor] = None
+    slice_start_id: Optional[torch.Tensor] = None
+    slice_end_id: Optional[torch.Tensor] = None
+    tgt_sizes: Optional[list] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -167,6 +179,11 @@ class ImageInputs:
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
+            "im_start_id",
+            "im_end_id",
+            "slice_start_id",
+            "slice_end_id",
+            "tgt_sizes",
         ]
         for arg in optional_args:
             if arg in obj:
```
```diff
@@ -215,6 +232,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        custom_logit_processor: Optional[str] = None,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -226,14 +244,16 @@ class Req:
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-
-        self.
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
         self.session_id = session_id
         self.input_embeds = input_embeds
 
         # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
         self.req_pool_idx = None
@@ -265,6 +285,7 @@ class Req:
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None
 
@@ -280,11 +301,10 @@ class Req:
         self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
-        self.
-        self.
-        self.
-        self.
-        self.input_top_logprobs_idx = None
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
             self.output_token_logprobs_val = []
@@ -344,9 +364,6 @@ class Req:
         max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            if self.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                max_prefix_len = min(max_prefix_len, input_len - 2)
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
         max_prefix_len = max(max_prefix_len, 0)
```
```diff
@@ -578,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
 
+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -588,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -601,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )
 
     def batch_size(self):
@@ -656,7 +678,7 @@ class ScheduleBatch:
             or len(req.prefix_indices) >= im.num_image_tokens
         )
 
-        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -690,7 +712,7 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -766,10 +788,10 @@ class ScheduleBatch:
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.input_embeds = (
@@ -1002,11 +1024,16 @@ class ScheduleBatch:
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+            enable_overlap_schedule=self.enable_overlap,
+        )
 
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
@@ -1067,7 +1094,7 @@ class ScheduleBatch:
         self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
@@ -1121,7 +1148,7 @@ class ScheduleBatch:
             self.spec_info.merge_batch(other.spec_info)
 
     def get_model_worker_batch(self):
-        if self.forward_mode.
+        if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
@@ -1136,7 +1163,6 @@ class ScheduleBatch:
 
         global bid
         bid += 1
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1180,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )
 
     def __str__(self):
```
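A recurring change in this file is `torch.int32` to `torch.int64` for `seq_lens`, `req_pool_indices`, `encoder_lens`, and kept-index tensors. These values feed address arithmetic over the KV pool, where products can exceed 2**31 on large pools and int32 then wraps silently. A minimal demonstration of the failure mode; the pool sizes are invented for the example:

```python
import torch

# Flat address = pool row * row stride. 70_000 * 40_000 = 2.8e9 > 2**31 - 1,
# so the int32 product wraps around to a negative offset on typical builds.
row = torch.tensor([70_000])
print((row.to(torch.int32) * 40_000).item())  # -1494967296 (wrapped)
print((row.to(torch.int64) * 40_000).item())  # 2800000000
```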
sglang/srt/managers/schedule_policy.py CHANGED

```diff
@@ -24,6 +24,7 @@ import torch
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
@@ -250,23 +251,24 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
+        token_to_kv_pool: BaseTokenToKVPool,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
-        rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
+        self.token_to_kv_pool = token_to_kv_pool
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
-        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
-        self.
+        self.rem_total_token_offset = mixed_with_decode_tokens
+        self.cur_rem_token_offset = mixed_with_decode_tokens
 
         self.req_states = None
         self.can_run_list = []
@@ -275,8 +277,7 @@ class PrefillAdder:
         self.log_input_tokens = 0
 
         if running_batch is not None:
-
-            self.rem_total_tokens -= sum(
+            self.rem_total_token_offset += sum(
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
@@ -287,6 +288,22 @@ class PrefillAdder:
                 ]
             )
 
+    @property
+    def rem_total_tokens(self):
+        return (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+            - self.rem_total_token_offset
+        )
+
+    @property
+    def cur_rem_tokens(self):
+        return (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+            - self.cur_rem_token_offset
+        )
+
     def budget_state(self):
         if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
             return AddReqResult.NO_TOKEN
@@ -301,8 +318,8 @@ class PrefillAdder:
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
-        self.
-        self.
+        self.rem_total_token_offset += extend_input_len + max_new_tokens
+        self.cur_rem_token_offset += extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -332,12 +349,10 @@ class PrefillAdder:
     @contextmanager
     def _lock_node(self, last_node: TreeNode):
         try:
-            delta = self.tree_cache.inc_lock_ref(last_node)
-            self.rem_total_tokens += delta
+            self.tree_cache.inc_lock_ref(last_node)
             yield None
         finally:
-            delta = self.tree_cache.dec_lock_ref(last_node)
-            self.rem_total_tokens += delta
+            self.tree_cache.dec_lock_ref(last_node)
 
     def add_one_req_ignore_eos(self, req: Req):
         def add_req_state(r, insert_sort=False):
@@ -433,7 +448,6 @@ class PrefillAdder:
             or input_tokens <= self.rem_chunk_tokens
             or (
                 req.return_logprob
-                and req.normalized_prompt_logprob is None
                 and req.logprob_start_len != len(req.origin_input_ids) - 1
             )
         ):
```
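The `PrefillAdder` rewrite replaces a mutable `rem_total_tokens` counter with properties derived from live pool state: the budget is whatever the KV pool can currently hand out, plus what the radix cache could evict, minus the tokens this adder has already promised to admitted requests. That keeps the estimate from drifting when other components allocate or free concurrently. A toy restatement of the accounting, with `Pool` and `Cache` as stand-ins for `BaseTokenToKVPool` and the tree cache:

```python
# Toy model of the new budget accounting in PrefillAdder.
class Pool:
    def __init__(self, free):
        self.free = free

    def available_size(self):
        return self.free

class Cache:
    def __init__(self, evictable):
        self.evictable = evictable

    def evictable_size(self):
        return self.evictable

class Adder:
    def __init__(self, pool, cache):
        self.pool, self.cache = pool, cache
        self.rem_total_token_offset = 0  # tokens promised to admitted reqs

    @property
    def rem_total_tokens(self):
        # Recomputed from live pool state on every read.
        return (self.pool.available_size()
                + self.cache.evictable_size()
                - self.rem_total_token_offset)

adder = Adder(Pool(10_000), Cache(2_000))
adder.rem_total_token_offset += 4_096 + 512  # admit one req: prefill + max_new
print(adder.rem_total_tokens)                # 7392, tracking the pool live
```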