sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0

sglang/srt/managers/multimodal_processors/internvl.py
ADDED
@@ -0,0 +1,232 @@
+# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
+
+import numpy as np
+import torch
+from decord import VideoReader, cpu
+from numpy.distutils.cpuinfo import cpu
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.internvl import InternVLChatModel
+
+
+class InternVLImageProcessor(BaseMultimodalProcessor):
+    models = [InternVLChatModel]
+
+    def __init__(self, hf_config, server_args, _image_processor):
+        super().__init__(hf_config, server_args, _image_processor)
+        image_size = hf_config.force_image_size or hf_config.vision_config.image_size
+        patch_size = hf_config.vision_config.patch_size
+
+        self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+        self.IMG_START_TOKEN = "<img>"
+        self.IMG_END_TOKEN = "</img>"
+        self.IMG_TOKEN = "<image>"
+        self.num_image_token = int(
+            (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
+        )
+
+        tokenizer = self._processor
+        self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
+        self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
+        self.img_context_token_id = tokenizer.convert_tokens_to_ids(
+            self.IMG_CONTEXT_TOKEN
+        )
+
+    @staticmethod
+    def build_transform(input_size):
+        IMAGENET_MEAN = (0.485, 0.456, 0.406)
+        IMAGENET_STD = (0.229, 0.224, 0.225)
+
+        def resize_image(img, size):
+            return img.resize((size, size), Image.Resampling.BICUBIC)
+
+        def to_tensor(img):
+            # Convert PIL Image to numpy array
+            img_array = np.array(img).astype(np.float32) / 255.0
+            # Convert HWC to CHW format
+            img_array = img_array.transpose(2, 0, 1)
+            return torch.from_numpy(img_array)
+
+        def normalize(tensor, mean, std):
+            mean = torch.tensor(mean).view(-1, 1, 1)
+            std = torch.tensor(std).view(-1, 1, 1)
+            return (tensor - mean) / std
+
+        def transform(img):
+            img = img.convert("RGB") if img.mode != "RGB" else img
+            img = resize_image(img, input_size)
+            tensor = to_tensor(img)
+            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
+            return tensor
+
+        return transform
+
+    @staticmethod
+    def dynamic_preprocess(
+        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+    ):
+
+        def find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, width, height, image_size
+        ):
+            best_ratio_diff = float("inf")
+            best_ratio = (1, 1)
+            area = width * height
+            for ratio in target_ratios:
+                target_aspect_ratio = ratio[0] / ratio[1]
+                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+                if ratio_diff < best_ratio_diff:
+                    best_ratio_diff = ratio_diff
+                    best_ratio = ratio
+                elif ratio_diff == best_ratio_diff:
+                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                        best_ratio = ratio
+            return best_ratio
+
+        orig_width, orig_height = image.size
+        aspect_ratio = orig_width / orig_height
+
+        # calculate the existing image aspect ratio
+        target_ratios = set(
+            (i, j)
+            for n in range(min_num, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num and i * j >= min_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # find the closest aspect ratio to the target
+        target_aspect_ratio = find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, orig_width, orig_height, image_size
+        )
+
+        # calculate the target width and height
+        target_width = image_size * target_aspect_ratio[0]
+        target_height = image_size * target_aspect_ratio[1]
+        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+        # resize the image
+        resized_img = image.resize((target_width, target_height))
+        processed_images = []
+        for i in range(blocks):
+            box = (
+                (i % (target_width // image_size)) * image_size,
+                (i // (target_width // image_size)) * image_size,
+                ((i % (target_width // image_size)) + 1) * image_size,
+                ((i // (target_width // image_size)) + 1) * image_size,
+            )
+            # split the image
+            split_img = resized_img.crop(box)
+            processed_images.append(split_img)
+        assert len(processed_images) == blocks
+        if use_thumbnail and len(processed_images) != 1:
+            thumbnail_img = image.resize((image_size, image_size))
+            processed_images.append(thumbnail_img)
+        return processed_images
+
+    @staticmethod
+    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+        if bound:
+            start, end = bound[0], bound[1]
+        else:
+            start, end = -100000, 100000
+        start_idx = max(first_idx, round(start * fps))
+        end_idx = min(round(end * fps), max_frame)
+        seg_size = float(end_idx - start_idx) / num_segments
+        frame_indices = np.array(
+            [
+                int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+                for idx in range(num_segments)
+            ]
+        )
+        return frame_indices
+
+    @staticmethod
+    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        max_frame = len(vr) - 1
+        fps = float(vr.get_avg_fps())
+
+        pixel_values_list, num_patches_list = [], []
+        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        frame_indices = InternVLImageProcessor.get_index(
+            bound, fps, max_frame, first_idx=0, num_segments=num_segments
+        )
+        for frame_index in frame_indices:
+            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+            img = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(tile) for tile in img]
+            pixel_values = torch.stack(pixel_values)
+            num_patches_list.append(pixel_values.shape[0])
+            pixel_values_list.append(pixel_values)
+        pixel_values = torch.cat(pixel_values_list)
+        return pixel_values, num_patches_list
+
+    async def process_mm_data_async(
+        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
+            max_req_input_len=max_req_input_len,
+            discard_alpha_channel=True,
+        )
+
+        def process_image_internvl(image, input_size=448, max_num=12):
+            transform = InternVLImageProcessor.build_transform(input_size=input_size)
+            images = InternVLImageProcessor.dynamic_preprocess(
+                image, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(image) for image in images]
+            pixel_values = torch.stack(pixel_values)
+            return pixel_values
+
+        num_patches_list = []
+        pixel_values = []
+        # Process each input with allocated frames
+        for image_index, (image) in enumerate(base_output.images):
+            try:
+                # TODO: video input
+                raw_image = process_image_internvl(image)
+                pixel_value = [raw_image.to(torch.bfloat16).cuda()]
+                pixel_values += pixel_value
+                num_patches = raw_image.shape[0]
+                num_patches_list += [num_patches]
+
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+        pixel_values = torch.cat(pixel_values, dim=0)
+        items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
+
+        for idx, num_patches in enumerate(num_patches_list):
+            image_tokens = (
+                self.IMG_START_TOKEN
+                + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                + self.IMG_END_TOKEN
+            )
+            input_text = input_text.replace("<image>", image_tokens, 1)
+
+        tokenizer = self._processor
+        return {
+            "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
+            .flatten()
+            .tolist(),
+            "mm_items": items,
+            "im_start_id": self.img_start_token_id,
+            "im_end_id": self.img_end_token_id,
+            "im_token_id": self.img_context_token_id,
+        }
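
A note on the new processor: dynamic_preprocess searches every grid (i, j) with min_num <= i*j <= max_num for the aspect ratio closest to the input image, crops one image_size x image_size tile per cell, and optionally appends a whole-image thumbnail. Each tile then expands to num_image_token context tokens; by the formula in __init__, a typical InternVL config (image_size 448, patch_size 14, downsample_ratio 0.5) gives (448/14)^2 * 0.5^2 = 256 tokens per tile. Also note that the second import, from numpy.distutils.cpuinfo import cpu, shadows decord's cpu, which load_video later calls as cpu(0). Below is a standalone sketch of the grid selection, re-derived from the hunk above and ignoring its area-based tie-break:

    # Standalone sketch of the grid search in dynamic_preprocess (not imported
    # from sglang); picks the (cols, rows) grid whose aspect ratio is closest
    # to the input's. The real code also tie-breaks equal diffs by image area.
    def tile_grid(width: int, height: int, min_num: int = 1, max_num: int = 12):
        aspect = width / height
        candidates = sorted(
            {
                (i, j)
                for n in range(min_num, max_num + 1)
                for i in range(1, n + 1)
                for j in range(1, n + 1)
                if min_num <= i * j <= max_num
            },
            key=lambda r: r[0] * r[1],
        )
        return min(candidates, key=lambda r: abs(aspect - r[0] / r[1]))

    print(tile_grid(1280, 960))  # (4, 3): 12 tiles, 13 crops with the thumbnail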

sglang/srt/managers/schedule_batch.py
CHANGED
@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
 
+    # For multimodal inputs
+    multimodal_inputs: Optional[List] = None
+
     # The sum of all sequence lengths
     seq_lens_sum: int = None
 
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
+        multimodal_inputs = []
 
         for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
+            multimodal_inputs.append(req.multimodal_inputs)
+
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if input_embeds
             else None
         )
+        for mm_input in multimodal_inputs:
+            if mm_input is None:
+                continue
+            for mm_item in mm_input.mm_items:
+                pixel_values = getattr(mm_item, "pixel_values", None)
+                if isinstance(pixel_values, torch.Tensor):
+                    mm_item.pixel_values = pixel_values.to(
+                        self.device, non_blocking=True
+                    )
+        self.multimodal_inputs = multimodal_inputs
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
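
The new loop above stages each request's pixel_values onto the scheduler's device with non_blocking=True before the batch is built. A minimal sketch of the semantics, assuming a CUDA device is present: the copy only overlaps with compute when the host tensor is pinned; for ordinary pageable memory, non_blocking=True silently degrades to a synchronous copy.

    import torch

    if torch.cuda.is_available():
        host = torch.randn(3, 448, 448).pin_memory()  # pinned => copy can be async
        dev = host.to("cuda", non_blocking=True)      # enqueued on the current stream
        torch.cuda.current_stream().synchronize()     # finish before reusing host/dev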
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
+        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
+        self.multimodal_inputs.extend(other.multimodal_inputs)
 
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            multimodal_inputs=
+            multimodal_inputs=self.multimodal_inputs,
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,

sglang/srt/managers/schedule_policy.py
CHANGED
@@ -455,7 +455,10 @@ class PrefillAdder:
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens =
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)
 
         if total_tokens >= self.rem_total_tokens:
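
Both input_tokens hunks in this file (here and in the next hunk) use the same idiom: -(-x // p) is ceiling division for positive p, so the expression rounds extend_input_len up to a whole number of pages. A quick self-check:

    def round_up_to_page(x: int, page_size: int) -> int:
        # -(-x // p) == ceil(x / p) for x >= 0, p > 0
        return -(-x // page_size) * page_size

    assert round_up_to_page(100, 64) == 128
    assert round_up_to_page(128, 64) == 128
    assert round_up_to_page(0, 64) == 0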
@@ -477,7 +480,10 @@ class PrefillAdder:
             req.last_node_global, req.prefix_indices
         )
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-        input_tokens =
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)
 
         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
@@ -493,12 +499,12 @@ class PrefillAdder:
                 ),
             )
         else:
-            if self.rem_chunk_tokens == 0:
+            # Make sure at least one page is available
+            trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+            if trunc_len <= 0:
                 return AddReqResult.OTHER
 
             # Chunked prefill
-            trunc_len = self.rem_chunk_tokens
-
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
 
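
Per its comment, the rewritten truncation path guarantees at least one page: trunc_len = rem_chunk_tokens - page_size + 1 is positive exactly when a full page still fits, and rounding it up to a page multiple never exceeds the remaining budget. A standalone property check (not sglang code):

    import math

    for rem in range(1, 512):
        for page in (1, 2, 16, 64):
            trunc = rem - page + 1
            # positive exactly when at least one full page fits
            assert (trunc > 0) == (rem >= page)
            if trunc > 0:
                # page-rounded footprint stays within the remaining budget
                assert math.ceil(trunc / page) * page <= rem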
sglang/srt/managers/scheduler.py
CHANGED
@@ -52,7 +52,11 @@ from sglang.srt.disaggregation.utils import (
     TransferBackend,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
@@ -83,6 +87,8 @@ from sglang.srt.managers.io_struct import (
     RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -413,6 +419,8 @@ class Scheduler(
         self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
+        self.forward_sleep_time = None
+
         # Init metrics stats
         self.init_metrics()
 
@@ -435,6 +443,7 @@ class Scheduler(
             (GetWeightsByNameReqInput, self.get_weights_by_name),
             (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
             (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
+            (SlowDownReqInput, self.slow_down),
             (ProfileReq, self.profile),
             (GetInternalStateReq, self.get_internal_state),
             (SetInternalStateReq, self.set_internal_state),
@@ -451,17 +460,7 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args
 
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
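
All three managers in this diff (scheduler, tokenizer_manager, tp_worker) collapse the nine-argument ModelConfig(...) call into ModelConfig.from_server_args(...). The actual classmethod lives in sglang/srt/configs/model_config.py (+17 lines, not shown in this diff); the sketch below is a hypothetical reconstruction inferred from the three call sites, including the model_path= and is_draft_model= overrides used by TpModelWorker:

    class ModelConfig:  # trimmed stand-in; the real class has many more fields
        def __init__(self, model_path, **kwargs):
            self.model_path = model_path
            self.__dict__.update(kwargs)

        @classmethod
        def from_server_args(cls, server_args, model_path=None, **kwargs):
            # Forward the usual server_args fields; callers may override the
            # path (draft workers) or pass extras such as is_draft_model=True.
            return cls(
                model_path or server_args.model_path,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                context_length=server_args.context_length,
                model_override_args=server_args.json_model_override_args,
                is_embedding=server_args.is_embedding,
                enable_multimodal=server_args.enable_multimodal,
                dtype=server_args.dtype,
                quantization=server_args.quantization,
                **kwargs,
            )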
@@ -475,7 +474,7 @@ class Scheduler(
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
-            self.tokenizer = self.processor
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
@@ -498,6 +497,7 @@ class Scheduler(
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
             )
         else:
             if self.enable_hierarchical_cache:
@@ -920,6 +920,10 @@ class Scheduler(
         )
         custom_logit_processor = None
 
+        if recv_req.bootstrap_port is None:
+            # Use default bootstrap port
+            recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -1527,6 +1531,10 @@ class Scheduler(
         ):
             self.stop_profile()
 
+        if self.forward_sleep_time is not None:
+            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
+            time.sleep(self.forward_sleep_time)
+
         # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
@@ -2002,6 +2010,13 @@ class Scheduler(
             del self.stashed_model_static_state
         return ResumeMemoryOccupationReqOutput()
 
+    def slow_down(self, recv_req: SlowDownReqInput):
+        t = recv_req.forward_sleep_time
+        if t is not None and t <= 0:
+            t = None
+        self.forward_sleep_time = t
+        return SlowDownReqOutput()
+
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
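
Together with the run_batch hunk above, the new slow-down control reads: a positive forward_sleep_time delays every forward pass by that many seconds (presumably for load and disaggregation testing), while zero or negative values clear the throttle. Distilled into plain Python:

    def normalize_sleep_time(t):
        # mirrors Scheduler.slow_down: non-positive values disable the throttle
        return None if (t is not None and t <= 0) else t

    assert normalize_sleep_time(0.5) == 0.5  # sleep 0.5 s before each run_batch
    assert normalize_sleep_time(0) is None
    assert normalize_sleep_time(-1) is None
    assert normalize_sleep_time(None) is None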

sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -54,7 +54,11 @@ from sglang.srt.disaggregation.utils import (
     TransferBackend,
     get_kv_class,
 )
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -86,6 +90,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -161,17 +167,7 @@ class TokenizerManager:
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
 
         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
@@ -199,7 +195,7 @@ class TokenizerManager:
                 self.tokenizer = self.processor = None
             else:
                 self.processor = _processor
-                self.tokenizer = self.processor
+                self.tokenizer = get_tokenizer_from_processor(self.processor)
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
         else:
             self.mm_processor = get_dummy_processor()
@@ -265,6 +261,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.slow_down_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -318,6 +317,10 @@ class TokenizerManager:
                     ResumeMemoryOccupationReqOutput,
                     self.resume_memory_occupation_communicator.handle_recv,
                 ),
+                (
+                    SlowDownReqOutput,
+                    self.slow_down_communicator.handle_recv,
+                ),
                 (
                     FlushCacheReqOutput,
                     self.flush_cache_communicator.handle_recv,
@@ -876,6 +879,14 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         await self.resume_memory_occupation_communicator(obj)
 
+    async def slow_down(
+        self,
+        obj: SlowDownReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.slow_down_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
sglang/srt/managers/tp_worker.py
CHANGED
@@ -21,7 +21,11 @@ import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
@@ -61,20 +65,13 @@ class TpModelWorker:
         self.pp_rank = pp_rank
 
         # Init model and tokenizer
-        self.model_config = ModelConfig(
-            (
+        self.model_config = ModelConfig.from_server_args(
+            server_args,
+            model_path=(
                 server_args.model_path
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
             is_draft_model=is_draft_worker,
         )
 
@@ -102,7 +99,7 @@ class TpModelWorker:
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
-            self.tokenizer = self.processor
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,

sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = page_size
 
     def reset(self):
         pass

sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -374,9 +374,9 @@ class MHATokenToKVPool(KVCache):
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
             with self.device_module.stream(self.alt_stream):
-                self.k_buffer[layer_id - self.start_layer][loc] = cache_k
-            self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id - self.start_layer][loc] = cache_k