sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +394 -76
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/rotary_embedding.py +0 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +59 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from enum import Enum, auto
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -70,14 +72,16 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
+    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
+    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
 }
 
 logger = logging.getLogger(__name__)
@@ -143,165 +147,185 @@ class FINISH_ABORT(BaseFinishReason):
 }
 
 
+class Modality(Enum):
+    IMAGE = auto()
+    MULTI_IMAGES = auto()
+    VIDEO = auto()
+    AUDIO = auto()
+
+
 @dataclasses.dataclass
-class
-"""
+class MultimodalDataItem:
+    """
+    A single multimodal data, from a single image/video/audio or other
+    """
+
+    modality: Modality
+
+    hash: int = None
+    pad_value: int = None
 
-
-
-
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None
+
+    # the real data, pixel_values or audio_features
+    # data: Union[List[torch.Tensor], List[np.array]]
+    pixel_values: Union[torch.Tensor, np.array] = None
+    image_grid_thws: Union[torch.Tensor, np.array] = None
+    video_grid_thws: Union[torch.Tensor, np.array] = None
+
+    image_emb_mask: Optional[torch.Tensor] = None
+    image_spatial_crop: Optional[torch.Tensor] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
+
+    # [num_images, (n, w, h)]
+    tgt_size: Tuple[int, int] = None
+
+    audio_features: Union[torch.Tensor, np.array] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def is_empty_list(l):
+        if l is None:
+            return True
+        return len([item for item in flatten_nested_list(l) if item is not None]) == 0
+
+    def set_pad_value(self):
+        """
+        Set the pad value after first hashign the data
+        """
+
+        def hash_feature(f):
+            if isinstance(f, list):
+                return hash(tuple(flatten_nested_list(f)))
+            elif isinstance(f, np.ndarray):
+                arr = np.ascontiguousarray(f)
+                arr_bytes = arr.tobytes()
+                return hash(arr_bytes)
+            return hash(f)
+
+        if self.is_audio():
+            self.hash = hash_feature(self.audio_features)
+        else:
+            self.hash = hash_feature(self.pixel_values)
+
+        assert self.hash is not None
+        self.pad_value = self.hash % (1 << 30)
+
+    def is_audio(self):
+        return (
+            self.modality == Modality.AUDIO
+        ) and not MultimodalDataItem.is_empty_list(self.audio_features)
+
+    def is_image(self):
+        return (
+            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def is_video(self):
+        return (
+            self.modality == Modality.VIDEO
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def validate(self):
+        ...
+        # TODO
+
+
+@dataclasses.dataclass
+class MultimodalInputs:
+    """The multimodal data related inputs."""
+
+    # items of data
+    mm_items: List[MultimodalDataItem]
     image_pad_len: Optional[list] = None
-    pad_values: Optional[list] = None
-    modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    # Llava related
-    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     # QWen2-VL related
-    # [num_of_images, t, h, w]
-    image_grid_thws: torch.Tensor = None
     mrope_position_delta: Optional[torch.Tensor] = None
-    # Qwen2-VL video related
-    video_token_id: Optional[int] = None
-    video_grid_thws: List[Tuple[int, int, int]] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
 
-    #
-    images_emb_mask: Optional[List[torch.Tensor]] = None
-    image_spatial_crop: Optional[List[torch.Tensor]] = None
-
-    # The id of the single-image placeholder token
+    # image
     im_token_id: Optional[torch.Tensor] = None
-
-    # All the images in the batch should share the same special image
-    # bound token ids.
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
     slice_end_id: Optional[int] = None
-
-
+
+    # video
+    video_token_id: Optional[int] = None
 
     # audio
     audio_start_id: Optional[torch.Tensor] = None
     audio_end_id: Optional[torch.Tensor] = None
-    audio_features: Optional[List[torch.Tensor]] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
 
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
-
-            data_hashes=obj["data_hashes"],
+            mm_items=obj["mm_items"],
         )
 
+        assert isinstance(ret.mm_items, list)
+        ret.mm_items = [
+            item
+            for item in ret.mm_items
+            if item.is_audio() or item.is_image() or item.is_video()
+        ]
+
+        assert len(ret.mm_items) != 0
+
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-
+        for item in ret.mm_items:
+            item.set_pad_value()
 
         optional_args = [
-            "image_sizes",
             "modalities",
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "image_grid_thws",
-            "images_emb_mask",
-            "image_spatial_crop",
             "im_token_id",
             "im_start_id",
             "im_end_id",
             "slice_start_id",
             "slice_end_id",
-            "tgt_sizes",
             "audio_start_id",
             "audio_end_id",
-            "audio_features",
-            "audio_feature_lens",
         ]
         for arg in optional_args:
             if arg in obj:
                 setattr(ret, arg, obj[arg])
 
-        # validate
-        assert (
-            isinstance(ret.pixel_values, torch.Tensor)
-            or isinstance(ret.pixel_values, np.ndarray)
-            or isinstance(ret.pixel_values, list)
-        )
-
-        assert ret.audio_features is None or isinstance(ret.audio_features, list)
-
         return ret
 
     def contains_image_inputs(self) -> bool:
         """ """
-        return
+        return any(item.is_image() for item in self.mm_items)
 
     def contains_audio_inputs(self) -> bool:
         """ """
-        return
+        return any(item.is_audio() for item in self.mm_items)
+
+    def collect_image_inputs(self) -> List[torch.Tensor]:
+        return [item.pixel_values for item in self.mm_items if item.is_image()]
 
     def merge(self, other: MultimodalInputs):
         """
        merge image inputs when requests are being merged
         """
-        if isinstance(self.pixel_values, list):
-            # in some rare cases, pixel values are list of patches with different shapes
-            # e.g. minicpm
-            self.pixel_values += other.pixel_values
-        else:
-            assert (
-                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
-            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-
-        # args would be stacked along first dim
-        # usually these are already tensors
-        stack_args = [
-            # TODO: merge with image_grid_thws, basically the same thing
-            "tgt_sizes",
-            "image_spatial_crop",
-        ]
-        for arg in stack_args:
-            if getattr(self, arg, None) is None:
-                setattr(self, arg, getattr(other, arg, None))
-            elif getattr(other, arg, None) is not None:
-                # self and other both not None
-                setattr(
-                    self,
-                    arg,
-                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
-                )
-
-        if self.image_grid_thws is None:
-            self.image_grid_thws = other.image_grid_thws
-        elif other.image_grid_thws is not None:
-            self.image_grid_thws = torch.concat(
-                [self.image_grid_thws, other.image_grid_thws]
-            )
 
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        self.data_hashes += other.data_hashes
-        self.pad_values = [x % (1 << 30) for x in self.data_hashes]
 
         # args needed to be merged
         optional_args = [
-            "
-            "image_sizes",
+            "items",
             "image_offsets",
             "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "images_emb_mask",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
@@ -599,6 +623,7 @@ class Req:
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
         self.req_pool_idx = None
+        self.already_computed = 0
 
     def __repr__(self):
         return (
@@ -740,11 +765,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
         return req_pool_indices
 
-    def alloc_token_slots(self, num_tokens: int):
+    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
         if self.token_to_kv_pool_allocator.available_size() < num_tokens:
             if self.tree_cache is not None:
                 self.tree_cache.evict(num_tokens)
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
         if out_cache_loc is None:
             phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
@@ -758,7 +786,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.tree_cache.pretty_print()
             raise RuntimeError(error_msg)
 
-
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def alloc_paged_token_slots_extend(
         self,
@@ -766,6 +797,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
+        backup_state: bool = False,
     ):
         if (
             self.token_to_kv_pool_allocator.available_size()
@@ -778,6 +810,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                     + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                 )
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
             prefix_lens, seq_lens, last_loc, extend_num_tokens
         )
@@ -791,23 +826,31 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
+        backup_state: bool = False,
     ):
-        if
-
-
-
-
+        if self.tree_cache is not None:
+            if (
+                self.token_to_kv_pool_allocator.available_size()
+                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+            ):
                 self.tree_cache.evict(
                     len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                 )
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
 
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
@@ -818,7 +861,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc
 
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
@@ -938,8 +985,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
 
-            if req.is_retracted:
-                req.already_computed = 0
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1095,17 +1140,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
-    def
-
-        if
-            return
+    def new_page_count_next_decode(self):
+        page_size = self.token_to_kv_pool_allocator.page_size
+        if page_size == 1:
+            return len(self.reqs)
+        return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
 
-
+    def check_decode_mem(self, buf_multiplier=1):
+        tokens_required = (
+            self.new_page_count_next_decode()
+            * buf_multiplier
+            * self.token_to_kv_pool_allocator.page_size
+        )
 
-        if self.token_to_kv_pool_allocator.available_size() >=
+        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
             return True
 
-
+        self.tree_cache.evict(tokens_required)
+
+        return self.token_to_kv_pool_allocator.available_size() >= tokens_required
 
     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
@@ -1167,7 +1220,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.req_to_token_pool.free(req.req_pool_idx)
             else:
                 # TODO: apply more fine-grained retraction
-                last_uncached_pos =
+                last_uncached_pos = (
+                    len(req.prefix_indices) // server_args.page_size
+                ) * server_args.page_size
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
@@ -1373,21 +1428,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if (
-                global_server_args_dict["enable_flashinfer_mla"]
-                or global_server_args_dict["enable_flashmla"]
-                or global_server_args_dict["attention_backend"] == "fa3"
-            ):
-                decode_seq_lens = self.seq_lens.cpu()
-            else:
-                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
-            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
+        # Create seq_lens_cpu when needed
+        if (
+            (
+                global_server_args_dict["use_mla_backend"]
+                and global_server_args_dict["attention_backend"] == "flashinfer"
+            )
+            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
+        ):
+            seq_lens_cpu = self.seq_lens.cpu()
+        else:
+            seq_lens_cpu = None
+
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
@@ -1410,7 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-
+            seq_lens_cpu=seq_lens_cpu,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1471,6 +1530,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
+    seq_lens_cpu: Optional[torch.Tensor]
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
 
@@ -1487,9 +1547,6 @@ class ModelWorkerBatch:
     global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool
 
-    # For decode
-    decode_seq_lens: Optional[torch.Tensor]
-
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
sglang/srt/managers/scheduler.py
CHANGED
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -1110,7 +1110,7 @@ class Scheduler(
         )
         if memory_leak:
             msg = (
-                "
+                "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
@@ -1121,7 +1121,7 @@ class Scheduler(
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
-                "
+                "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
@@ -1282,7 +1282,7 @@ class Scheduler(
         ]
 
         if self.enable_hierarchical_cache:
-            self.tree_cache.
+            self.tree_cache.ready_to_load_cache()
 
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -736,7 +736,7 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
-        ), "dp_size must be for update weights from distributed"
+        ), "dp_size must be 1 for update weights from distributed"
 
         # This means that weight sync
         # cannot run while requests are in progress.
sglang/srt/managers/utils.py
CHANGED
@@ -1,11 +1,6 @@
-import json
 import logging
-import time
-from collections import defaultdict
 from http import HTTPStatus
-from typing import
-
-import torch
+from typing import Optional
 
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
 