sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multimodal_processors/pixtral.py (new file):

```diff
@@ -0,0 +1,127 @@
+import asyncio
+import math
+from typing import List, Union
+
+from transformers.models.pixtral.image_processing_pixtral import (
+    _num_image_tokens as _get_pixtral_hf_num_image_tokens,
+)
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.pixtral import PixtralVisionModel
+
+
+class PixtralProcessor(BaseMultimodalProcessor):
+    models = [PixtralVisionModel]
+
+    PAD_TOKEN = "<pad>"
+    IMG_BREAK_TOKEN_ID = 12
+    IMG_END_TOKEN_ID = 13
+
+    def get_patch_grid_size(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> tuple[int, int]:
+        max_width = max_height = self.image_size
+        patch_width = patch_height = self.patch_size
+
+        ratio = max(image_width / max_width, image_height / max_height)
+
+        if ratio > 1:
+            image_width = int(math.floor(image_width / ratio))
+            image_height = int(math.floor(image_height / ratio))
+
+        nrows, ncols = _get_pixtral_hf_num_image_tokens(
+            (image_height, image_width),
+            (patch_height, patch_width),
+        )
+
+        return ncols, nrows
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.image_token_id = getattr(
+            hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
+        )
+        # Instantiate the patcher logic helper using the class defined above
+
+        self.vision_config = hf_config.vision_config
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_processor.image_token
+        )
+        _processor.tokenizer.add_special_tokens(
+            {
+                "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
+            }
+        )
+
+    async def _resize(self, image):
+        num_w_tokens, num_h_tokens = self.get_patch_grid_size(
+            image_width=image.size[0],
+            image_height=image.size[1],
+        )
+        new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
+        return image.resize(new_size)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.multimodal_tokens,
+            max_req_input_len=kwargs.get("max_req_input_len", 4096),
+            image_data=image_data,
+            return_text=True,
+        )
+
+        if mm_data.images:
+            resize_tasks = [self._resize(image) for image in mm_data.images]
+            mm_data.images = await asyncio.gather(*resize_tasks)
+
+        processor_output = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        if "pixel_values" in processor_output:
+            input_ids = processor_output["input_ids"].view(-1)
+            image_offsets = self.get_mm_items_offset(
+                input_ids=input_ids,
+                mm_token_id=self.image_token_id,
+            )
+            mm_items = [
+                MultimodalDataItem(
+                    pixel_values=processor_output["pixel_values"],
+                    image_sizes=processor_output["image_sizes"],
+                    modality=Modality.IMAGE,
+                    image_offsets=image_offsets,
+                )
+            ]
+
+            input_ids = input_ids.tolist()
+            processor_output.update(
+                input_ids=input_ids,
+                mm_items=mm_items,
+                # there's no im_start_id for pixtral, only im_token and im_end_token
+                im_end_id=self.IMG_END_TOKEN_ID,
+                im_token_id=self.image_token_id,
+            )
+        return processor_output
```
sglang/srt/managers/multimodal_processors/qwen_vl.py:

```diff
@@ -1,6 +1,7 @@
 import asyncio
 import math
-from typing import List, Union
+import re
+from typing import Dict, List, Union
 
 import torch
 from PIL import Image
@@ -23,7 +24,12 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
 
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        # The single, pre-expanded image token.
         self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
+        # The regex that matches expanded image tokens.
+        self.IMAGE_TOKEN_REGEX = re.compile(
+            r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
+        )
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
@@ -38,7 +44,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
 
     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, Dict]],
         input_text,
         request_obj,
         max_req_input_len,
@@ -48,11 +54,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         if isinstance(image_data, str):
             image_data = [image_data]
 
-        image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.IMAGE_TOKEN,
+                image_token_regex=self.IMAGE_TOKEN_REGEX,
+            ),
             max_req_input_len=max_req_input_len,
         )
 
@@ -117,26 +125,60 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)
 
-        if base_output.images:
+        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
+        if base_output.images and not images_are_preprocessed:
             resize_tasks = [resize_image_async(image) for image in base_output.images]
             base_output.images = await asyncio.gather(*resize_tasks)
 
         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=base_output.images,
+            images=None if images_are_preprocessed else base_output.images,
         )
-
+        input_ids = ret["input_ids"].flatten().tolist()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
+        )
+        image_grid_thw = None
+        video_grid_thw = None  # TODO
         items = []
 
-
-
+        if base_output.images:
+            if images_are_preprocessed:
+                image_grid_thw = torch.concat(
+                    [
+                        torch.as_tensor(item.image_grid_thws)
+                        for item in base_output.images
+                    ]
+                )
+                all_pixel_values = [
+                    item.pixel_values
+                    for item in base_output.images
+                    if item.pixel_values is not None
+                ]
+                all_precomputed_features = [
+                    item.precomputed_features
+                    for item in base_output.images
+                    if item.precomputed_features is not None
+                ]
+                pixel_values = (
+                    torch.concat(all_pixel_values) if all_pixel_values else None
+                )
+                precomputed_features = (
+                    torch.concat(all_precomputed_features)
+                    if all_precomputed_features
+                    else None
+                )
+            else:
+                image_grid_thw = ret["image_grid_thw"]
+                pixel_values = ret["pixel_values"]
+                precomputed_features = None
         items += [
             MultimodalDataItem(
-                pixel_values=
-                image_grid_thws=
-
-
-
+                pixel_values=pixel_values,
+                image_grid_thws=image_grid_thw,
+                video_grid_thws=video_grid_thw,
+                precomputed_features=precomputed_features,
+                image_offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ]
@@ -151,8 +193,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
             input_ids=torch.tensor(input_ids).unsqueeze(0),
-            image_grid_thw=
-            video_grid_thw=
+            image_grid_thw=image_grid_thw,
+            video_grid_thw=video_grid_thw,
             second_per_grid_ts=ret.get("second_per_grid_ts", None),
         )
         mrope_positions = mrope_positions.squeeze(1)
```
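The Qwen2.5-VL change teaches the processor to find image placeholders by regex rather than exact string match, so prompts whose `<|image_pad|>` run was already expanded by an upstream processor still split at the right places. A self-contained sketch of the matching behavior, using the same pattern the diff adds:

```python
import re

IMAGE_TOKEN_REGEX = re.compile(
    r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)

# Both the un-expanded placeholder and a pre-expanded run of pad tokens
# match, since (?:<\|image_pad\|>)+ accepts one or more pads.
prompts = [
    "Describe <|vision_start|><|image_pad|><|vision_end|> please",
    "Describe <|vision_start|><|image_pad|><|image_pad|><|image_pad|><|vision_end|> please",
]
for p in prompts:
    print([m.span() for m in IMAGE_TOKEN_REGEX.finditer(p)])
```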
sglang/srt/managers/schedule_batch.py:

```diff
@@ -1,8 +1,5 @@
 from __future__ import annotations
 
-import hashlib
-from enum import Enum, auto
-
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +27,16 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
 It contains low-level tensor data. Most of the data consists of GPU tensors.
+
+TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
 """
 
 import copy
 import dataclasses
+import hashlib
 import logging
 import threading
+from enum import Enum, auto
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -47,10 +48,14 @@ from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
-from sglang.srt.disaggregation.
+from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
+    ScheduleBatchDisaggregationDecodeMixin,
+)
+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
```
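schedule_batch.py now imports the decode-side disaggregation methods from the new `decode_schedule_batch_mixin` module (+142 lines in the file list above) and mixes them into `ScheduleBatch`. A minimal sketch of the pattern, with hypothetical class and method names standing in for the real mixin's API:

```python
class DecodeMixin:
    # Decode-specific behavior lives in its own module and is grafted onto
    # the host class via inheritance, keeping the host module smaller and
    # avoiding an import cycle with the decode scheduler.
    def prepare_for_decode_transfer(self):
        return f"prepared {len(self.reqs)} request(s)"


class Batch(DecodeMixin):
    def __init__(self, reqs):
        self.reqs = reqs


print(Batch(["r0", "r1"]).prepare_for_decode_transfer())  # prepared 2 request(s)
```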
```diff
@@ -73,17 +78,21 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "deepep_config": ServerArgs.deepep_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
     "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
+    "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
     "sampling_backend": ServerArgs.sampling_backend,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "torchao_config": ServerArgs.torchao_config,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+    "ep_num_redundant_experts": ServerArgs.ep_num_redundant_experts,
 }
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -134,9 +143,9 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message=
+    def __init__(self, message=None, status_code=None, err_type=None):
         super().__init__(is_error=True)
-        self.message = message
+        self.message = message or "Aborted"
         self.status_code = status_code
         self.err_type = err_type
 
@@ -174,10 +183,10 @@ class MultimodalDataItem:
     image_offsets: Optional[list] = None
 
     # the real data, pixel_values or audio_features
-    # data: Union[List[torch.Tensor], List[np.
-    pixel_values: Union[torch.Tensor, np.
-    image_grid_thws: Union[torch.Tensor, np.
-    video_grid_thws: Union[torch.Tensor, np.
+    # data: Union[List[torch.Tensor], List[np.ndarray]]
+    pixel_values: Union[torch.Tensor, np.ndarray] = None
+    image_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    video_grid_thws: Union[torch.Tensor, np.ndarray] = None
 
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
@@ -186,8 +195,11 @@ class MultimodalDataItem:
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None
 
-    audio_features: Union[torch.Tensor, np.
+    audio_features: Union[torch.Tensor, np.ndarray] = None
     audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
 
     @staticmethod
     def is_empty_list(l):
@@ -216,7 +228,8 @@ class MultimodalDataItem:
                 for x in tensor_list
             ]
             tensor = torch.concat(tensor_list)
-
+            if tensor.is_cuda:
+                return gpu_tensor_hash(tensor)
             tensor = tensor.detach().contiguous()
 
             if tensor.dtype == torch.bfloat16:
@@ -246,7 +259,9 @@ class MultimodalDataItem:
                 return tensor_hash([f])
             return data_hash(f)
 
-        if self.
+        if self.precomputed_features is not None:
+            self.hash = hash_feature(self.precomputed_features)
+        elif self.is_audio():
             self.hash = hash_feature(self.audio_features)
         else:
             self.hash = hash_feature(self.pixel_values)
@@ -255,19 +270,24 @@ class MultimodalDataItem:
         self.pad_value = self.hash % (1 << 30)
 
     def is_audio(self):
-        return (
-            self.
-
+        return (self.modality == Modality.AUDIO) and (
+            self.precomputed_features is not None
+            or not MultimodalDataItem.is_empty_list(self.audio_features)
+        )
 
     def is_image(self):
         return (
             self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
-        ) and
+        ) and (
+            self.precomputed_features is not None
+            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+        )
 
     def is_video(self):
-        return (
-            self.
-
+        return (self.modality == Modality.VIDEO) and (
+            self.precomputed_features is not None
+            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+        )
 
     def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()
```
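Two things changed in the content-hash path: `precomputed_features` now wins over raw pixels/audio, and CUDA tensors short-circuit to `gpu_tensor_hash` (from the new sglang/srt/layers/multimodal.py) instead of being copied to the host first. A sketch of the dispatch order under those assumptions; the hash bodies below are stand-ins, not sglang's kernels:

```python
import hashlib
import torch

def cpu_tensor_hash(t: torch.Tensor) -> int:
    t = t.detach().contiguous()
    if t.dtype == torch.bfloat16:
        t = t.float()  # bfloat16 has no NumPy view to take bytes from
    digest = hashlib.sha256(t.cpu().numpy().tobytes()).digest()
    return int.from_bytes(digest[:8], "big")

def item_hash(precomputed_features, audio_features, pixel_values, modality) -> int:
    # Dispatch mirrors the diff: precomputed features win; otherwise audio
    # items hash their audio features and everything else hashes pixels.
    if precomputed_features is not None:
        feature = precomputed_features
    elif modality == "AUDIO":
        feature = audio_features
    else:
        feature = pixel_values
    if feature.is_cuda:
        # sglang returns gpu_tensor_hash(feature) here to avoid a
        # device-to-host copy; this sketch just falls through to CPU.
        pass
    return cpu_tensor_hash(feature)

h = item_hash(None, None, torch.randn(3, 16, 16), "IMAGE")
print(h % (1 << 30))  # the same modulus the diff uses to derive pad_value
```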
```diff
@@ -276,6 +296,16 @@ class MultimodalDataItem:
         ...
         # TODO
 
+    @staticmethod
+    def from_dict(obj: dict):
+        kwargs = dict(obj)
+        modality = kwargs.pop("modality")
+        if isinstance(modality, str):
+            modality = Modality[modality]
+        ret = MultimodalDataItem(modality=modality, **kwargs)
+        ret.validate()
+        return ret
+
 
 @dataclasses.dataclass
 class MultimodalInputs:
```
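`from_dict` gives preprocessed multimodal items a plain-dict wire format, with the `Modality` enum round-tripping through its string name. A small usage sketch with toy tensor shapes, assuming an sglang install:

```python
import torch

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

# Rebuild an item from a plain dict; "modality" may arrive as the enum's
# string name and is mapped back through Modality[...].
item = MultimodalDataItem.from_dict(
    {
        "modality": "IMAGE",
        "pixel_values": torch.randn(1, 3, 16, 16),
        "image_grid_thws": torch.tensor([[1, 4, 4]]),
    }
)
assert item.modality is Modality.IMAGE
assert item.is_image()
```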
```diff
@@ -301,8 +331,9 @@ class MultimodalInputs:
     video_token_id: Optional[int] = None
 
     # audio
-    audio_start_id: Optional[int] = None
-    audio_end_id: Optional[int] = None
+    audio_token_id: Optional[int] = None
+    audio_start_id: Optional[int] = None
+    audio_end_id: Optional[int] = None
 
     @staticmethod
     def from_dict(obj: dict):
@@ -326,6 +357,7 @@ class MultimodalInputs:
             "slice_end_id",
             "audio_start_id",
             "audio_end_id",
+            "audio_token_id",
         ]
         for arg in optional_args:
             if arg in obj:
@@ -434,6 +466,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+        self.lora_path = lora_path
 
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -441,11 +474,13 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # Whether this request has finished output
+        self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         self.to_abort = False
         # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str =
+        self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids
 
@@ -483,6 +518,13 @@ class Req:
         # For retraction
         self.is_retracted = False
 
+        # Incremental streamining
+        self.send_token_offset: int = 0
+        self.send_decode_id_offset: int = 0
+        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
+        # because the decode server does not have the first output token logprobs
+        self.send_output_token_logprobs_offset: int = 0
+
         # Logprobs (arguments)
         self.return_logprob = return_logprob
         # Start index to compute logprob from.
@@ -492,11 +534,9 @@ class Req:
         self.temp_scaled_logprobs = False
         self.top_p_normalized_logprobs = False
 
-        # Latency Breakdown
-        self.queue_time_start = None
-        self.queue_time_end = None
-
         # Logprobs (return values)
+        # True means the input logprob has been already sent to detokenizer.
+        self.input_logprob_sent: bool = False
         self.input_token_logprobs_val: Optional[List[float]] = None
         self.input_token_logprobs_idx: Optional[List[int]] = None
         self.input_top_logprobs_val: Optional[List[float]] = None
@@ -511,8 +551,10 @@ class Req:
         self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
+            # shape: (bs, 1)
             self.output_token_logprobs_val = []
             self.output_token_logprobs_idx = []
+            # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
             self.output_token_ids_logprobs_val = []
@@ -530,6 +572,7 @@ class Req:
 
         # Constrained decoding
         self.grammar: Optional[BaseGrammarObject] = None
+        self.grammar_wait_ct = 0
 
         # The number of cached tokens that were already cached in the KV cache
         self.cached_tokens = 0
@@ -538,7 +581,12 @@ class Req:
         # The number of verification forward passes in the speculative decoding.
         # This is used to compute the average acceptance length per request.
         self.spec_verify_ct = 0
-
+
+        # For metrics
+        self.time_stats: TimeStats = TimeStats()
+        self.has_log_time_stats: bool = False
+        self.queue_time_start = None
+        self.queue_time_end = None
 
         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -546,8 +594,6 @@ class Req:
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None
 
-        # used for warmup because we don't have a pair yet when init
-        self.skip_kv_transfer: bool = False
         # the start index of the sent kv cache
         # We want to send it chunk by chunk for chunked prefill.
         # After every chunk forward, we do the following:
@@ -555,14 +601,11 @@ class Req:
         # start_send_idx = len(req.fill_ids)
         self.start_send_idx: int = 0
 
-        self.metadata_buffer_index: int = -1
-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
         # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
         # This is because kv is not ready in `process_prefill_chunk`.
         # We use `tmp_end_idx` to store the end index of the kv cache to send.
         self.tmp_end_idx: int = -1
+        self.metadata_buffer_index: int = -1
 
     @property
     def seqlen(self):
@@ -653,6 +696,11 @@ class Req:
             )
             return
 
+        if self.grammar is not None:
+            if self.grammar.is_terminated():
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
+                return
+
         last_token_id = self.output_ids[-1]
 
         if not self.sampling_params.ignore_eos:
```
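The new `send_token_offset` / `send_decode_id_offset` / `send_output_token_logprobs_offset` fields make streaming incremental: each flush ships only the tokens past the offset, then advances it. A standalone sketch of that bookkeeping (not the actual scheduler loop):

```python
class StreamState:
    """Track what has already been sent so each flush emits only new tokens."""

    def __init__(self):
        self.output_ids = []
        self.send_token_offset = 0

    def flush(self):
        delta = self.output_ids[self.send_token_offset:]
        self.send_token_offset = len(self.output_ids)
        return delta


s = StreamState()
s.output_ids += [11, 12, 13]
print(s.flush())  # [11, 12, 13]
s.output_ids += [14]
print(s.flush())  # [14] -- only the unsent tail
```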
```diff
@@ -697,13 +745,41 @@ class Req:
         self.req_pool_idx = None
         self.already_computed = 0
 
+    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
+
+    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
+        token_indices = req_to_token_pool.req_to_token[
+            self.req_pool_idx, : self.seqlen - 1
+        ]
+        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
+        del self.kv_cache_cpu
+
+    def log_time_stats(self):
+        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
+        if self.has_log_time_stats is True:
+            return
+
+        if self.bootstrap_room is not None:
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        else:
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+        logger.info(f"{prefix}: {self.time_stats}")
+        self.has_log_time_stats = True
+
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
-            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}
+            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
+            f"{self.grammar=}, "
+            f"{self.sampling_params=})"
         )
 
 
+# Batch id
 bid = 0
```
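`offload_kv_cache` / `load_kv_cache` round-trip a request's KV entries through host memory using the allocator's `get_cpu_copy` / `load_cpu_copy`. A toy allocator showing the same round trip; the class below is a stand-in, not sglang's TokenToKVPoolAllocator:

```python
import torch

class ToyAllocator:
    def __init__(self, num_tokens, dim):
        self.kv = torch.zeros(num_tokens, dim)

    def get_cpu_copy(self, indices):
        return self.kv[indices].clone()  # snapshot the request's KV entries

    def load_cpu_copy(self, cpu_copy, indices):
        self.kv[indices] = cpu_copy      # restore them into the pool


alloc = ToyAllocator(num_tokens=8, dim=4)
idx = torch.tensor([0, 1, 2])
alloc.kv[idx] = torch.randn(3, 4)
snapshot = alloc.get_cpu_copy(idx)       # offload
alloc.kv[idx] = 0                        # pool slots reused elsewhere
alloc.load_cpu_copy(snapshot, idx)       # reload before resuming decode
assert torch.equal(alloc.kv[idx], snapshot)
```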
```diff
@@ -862,7 +938,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"{phase_str} out of memory. Try to lower your batch size.\n"
             f"Try to allocate {num_tokens} tokens.\n"
-            f"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
         )
         logger.error(error_msg)
         if self.tree_cache is not None:
@@ -903,7 +979,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"Prefill out of memory. Try to lower your batch size.\n"
             f"Try to allocate {extend_num_tokens} tokens.\n"
-            f"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             f"{self.token_to_kv_pool_allocator.available_size()=}\n"
             f"{self.tree_cache.evictable_size()=}\n"
         )
@@ -938,7 +1014,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         error_msg = (
             f"Decode out of memory. Try to lower your batch size.\n"
             f"Try to allocate {len(seq_lens)} tokens.\n"
-            f"
+            f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
             f"{self.token_to_kv_pool_allocator.available_size()=}\n"
             f"{self.tree_cache.evictable_size()=}\n"
         )
@@ -1019,7 +1095,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
 
-        assert
+        assert (
+            len(self.out_cache_loc) == self.extend_num_tokens
+        ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"
 
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
```
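All three out-of-memory messages now lead with one headline figure: free slots in the KV pool plus whatever the radix cache could still evict. A sketch of that accounting with stub objects (names illustrative):

```python
class Allocator:
    def available_size(self):
        return 1024  # free KV pool slots

class TreeCache:
    def evictable_size(self):
        return 4096  # cached tokens that could be evicted on demand

# What the new OOM messages report as "Available tokens".
available_tokens = Allocator().available_size() + TreeCache().evictable_size()
print(available_tokens)  # 5120
```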
```diff
@@ -1447,7 +1525,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             i
             for i in range(len(self.reqs))
             if not self.reqs[i].finished()
-            and
+            and self.reqs[i] not in chunked_req_to_exclude
         ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -1468,7 +1546,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
-
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1517,7 +1596,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
-        self.multimodal_inputs
+        if self.multimodal_inputs is not None:
+            self.multimodal_inputs.extend(other.multimodal_inputs)
 
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
```
sglang/srt/managers/schedule_policy.py:

```diff
@@ -22,11 +22,7 @@ from typing import Dict, List, Optional, Set, Union
 
 import torch
 
-from sglang.srt.managers.schedule_batch import (
-    Req,
-    ScheduleBatch,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
```
|