sglang 0.4.5.post3__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -9
- sglang/compile_deep_gemm.py +45 -4
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +9 -3
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +59 -11
- sglang/srt/disaggregation/mini_lb.py +45 -8
- sglang/srt/disaggregation/mooncake/conn.py +198 -31
- sglang/srt/disaggregation/prefill.py +24 -9
- sglang/srt/entrypoints/http_server.py +8 -2
- sglang/srt/function_call_parser.py +77 -5
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +28 -10
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/layernorm.py +38 -16
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/deep_gemm.py +17 -10
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +170 -126
- sglang/srt/managers/data_parallel_controller.py +10 -3
- sglang/srt/managers/io_struct.py +7 -0
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +29 -12
- sglang/srt/managers/scheduler.py +31 -20
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +87 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -3
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +11 -24
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +144 -70
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpmo.py +5 -1
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +18 -8
- sglang/srt/server_args.py +15 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +2 -1
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +36 -18
- sglang/version.py +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/METADATA +4 -5
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/RECORD +70 -68
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post3.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+from typing import List, Union

 import torch

@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         self.IMAGE_TOKEN = "<image>"

     async def process_mm_data_async(
-        self,
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs
     ):
         if not image_data:
             return None
@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-
+            input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -1,7 +1,5 @@
 from typing import List, Union

-from transformers.utils import logging
-
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
 # will be removed in the future
-logger = logging.get_logger(__name__)


 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         processor = self._processor

         base_out = self.load_mm_data(
-            prompt=
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=processor.image_token
@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_text,
         request_obj,
         max_req_input_len,
+        **kwargs,
     ):
         audio_data = request_obj.audio_data
         if not image_data and not audio_data:
@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            prompt=
+            prompt=input_text,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
@@ -96,7 +97,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         audio_start_id = tokenizer.audio_start_id
         audio_end_id = tokenizer.audio_end_id

-        im_token_id = tokenizer.
+        im_token_id = tokenizer.unk_id
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]

@@ -5,6 +5,7 @@ from typing import List, Union
 import torch
 from PIL import Image

+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
         self.video_token_id = hf_config.video_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28
@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
         if isinstance(image_data, str):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
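Taken together, the processor hunks above converge on one keyword-compatible process_mm_data_async signature, which is what lets the tokenizer manager (see its hunk further below) call every per-model processor the same way. A minimal sketch of that shared shape; the class name and body are illustrative only, not code from the package:

from typing import List, Union


class ExampleImageProcessor:
    """Illustrative stand-in for the per-model processors changed above."""

    IMAGE_TOKEN = "<image>"

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
        # A real processor forwards input_text and image_data to self.load_mm_data(...)
        # and returns a dict of multimodal items plus (optionally) rewritten input_ids.
        return {"input_text": input_text, "num_images": len(image_data)}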
@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)

-
-
+        if base_output.images:
+            resize_tasks = [resize_image_async(image) for image in base_output.images]
+            base_output.images = await asyncio.gather(*resize_tasks)

         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=
+            images=base_output.images,
         )

-
-
-
-
+        items = []
+
+        input_ids = ret["input_ids"].flatten().tolist()
+        if "pixel_values" in ret:
+            items += [
                 MultimodalDataItem(
                     pixel_values=ret["pixel_values"],
-                    image_grid_thws=
+                    image_grid_thws=torch.concat([ret["image_grid_thw"]]),
                     # TODO
                     video_grid_thws=None,
                     second_per_grid_ts=ret.get("second_per_grid_ts", None),
                     modality=Modality.IMAGE,
                 )
-            ]
+            ]
+
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            model_type=self.hf_config.model_type,
+            tokens_per_second=getattr(
+                self.hf_config.vision_config, "tokens_per_second", None
+            ),
+            input_ids=torch.tensor(input_ids).unsqueeze(0),
+            image_grid_thw=ret.get("image_grid_thw", None),
+            video_grid_thw=ret.get("video_grid_thw", None),
+            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        return {
+            "input_ids": input_ids,
+            "mm_items": items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
         }
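The new mrope bookkeeping follows Qwen2-VL's convention of three position rows (temporal, height, width). The exact tensor layout is not visible in this diff, so the shapes below are an assumption used only to illustrate the unsqueeze(0)/squeeze(1) round trip in the hunk above:

import torch

# Assumed layout: get_rope_index returns positions shaped (3, batch, seq_len)
# plus one position delta per sequence.
seq_len = 8
mrope_positions = torch.zeros(3, 1, seq_len, dtype=torch.long)   # (3, batch=1, seq)
mrope_position_delta = torch.zeros(1, 1, dtype=torch.long)       # (batch=1, 1)

# The processor drops the batch dimension before returning, so each request
# ends up carrying a (3, seq_len) position tensor.
mrope_positions = mrope_positions.squeeze(1)
assert mrope_positions.shape == (3, seq_len)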
@@ -285,6 +285,7 @@ class MultimodalInputs:
     num_image_tokens: Optional[int] = None

     # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
@@ -310,16 +311,12 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

-        assert len(ret.mm_items) != 0
-
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
         for item in ret.mm_items:
             item.set_pad_value()

         optional_args = [
+            "mrope_positions",
+            "mrope_position_delta",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -350,11 +347,6 @@ class MultimodalInputs:
         merge image inputs when requests are being merged
         """

-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
-
         # args needed to be merged
         optional_args = [
             "mm_items",
@@ -364,6 +356,30 @@ class MultimodalInputs:
             self_arg = getattr(self, arg, None)
             if self_arg is not None:
                 setattr(self, arg, self_arg + getattr(other, arg))
+
+        mrope_positions = self.mrope_positions
+        if mrope_positions is not None:
+            if other.mrope_positions is None:
+                self.mrope_positions = mrope_positions
+            else:
+                self.mrope_positions = torch.cat(
+                    [self.mrope_positions, other.mrope_positions], dim=1
+                )
+
+        mrope_position_delta = self.mrope_position_delta
+        if mrope_position_delta is not None:
+            if other.mrope_position_delta is None:
+                self.mrope_position_delta = mrope_position_delta
+            else:
+                self.mrope_position_delta = torch.cat(
+                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
+                )
+
+        for key, val in other.__dict__.items():
+            if "_id" in key:
+                # set token_ids
+                if getattr(self, key, None) is None:
+                    setattr(self, key, getattr(other, key, None))
         # other args would be kept intact

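Note the two different concatenation axes in the merge path above; a tiny illustration with assumed per-request shapes (consistent with the squeeze(1) in the Qwen2.5-VL processor hunk):

import torch

# Assumed shapes: positions are (3, seq_len) per request, deltas are (1, 1).
a_pos = torch.zeros(3, 5, dtype=torch.long)
b_pos = torch.zeros(3, 7, dtype=torch.long)
merged_pos = torch.cat([a_pos, b_pos], dim=1)        # sequences join -> (3, 12)

a_delta = torch.zeros(1, 1, dtype=torch.long)
b_delta = torch.zeros(1, 1, dtype=torch.long)
merged_delta = torch.cat([a_delta, b_delta], dim=0)  # requests stack -> (2, 1)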
@@ -388,6 +404,7 @@ class Req:
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
+        bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
@@ -523,6 +540,7 @@ class Req:

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_port: Optional[int] = bootstrap_port
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

@@ -1450,7 +1468,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.model_config.is_encoder_decoder:
             self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
             self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
         self.req_pool_indices = torch.cat(
             [self.req_pool_indices, other.req_pool_indices]
         )
sglang/srt/managers/scheduler.py CHANGED
@@ -578,6 +578,10 @@ class Scheduler(
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                 transfer_backend=self.transfer_backend,
             )
+
+            # Metric for pre-allocation
+            self.num_tokens_pre_allocated = 0
+
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
@@ -593,7 +597,7 @@ class Scheduler(
             )
             metadata_buffers = [output_id_buffer]

-            self.
+            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
@@ -787,6 +791,7 @@ class Scheduler(
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
+            bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
         )
         req.tokenizer = self.tokenizer
@@ -901,7 +906,7 @@ class Scheduler(
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.time()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.
+            self.disagg_prefill_bootstrap_queue.add(req)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
@@ -991,8 +996,15 @@ class Scheduler(
             f"#cached-token: {adder.log_hit_tokens}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue)}, "
         )
+
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
+        else:
+            f += f"#queue-req: {len(self.waiting_queue)}"
+
         logger.info(f)

         if self.enable_metrics:
@@ -1028,15 +1040,14 @@ class Scheduler(
             gap_latency / self.server_args.decode_log_interval
         )

+        msg = (
+            f"Decode batch. "
+            f"#running-req: {num_running_reqs}, "
+            f"#token: {num_used}, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+        )
+
         if self.spec_algorithm.is_none():
-            msg = (
-                f"Decode batch. "
-                f"#running-req: {num_running_reqs}, "
-                f"#token: {num_used}, "
-                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}, "
-            )
             spec_accept_length = 0
         else:
             spec_accept_length = (
@@ -1045,15 +1056,15 @@ class Scheduler(
             self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-            msg
-
-
-
-
-
-
-
-
+            msg += f"accept len: {spec_accept_length:.2f}, "
+
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+
+        msg += (
+            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+            f"#queue-req: {len(self.waiting_queue)}"
+        )

         logger.info(msg)
         if self.enable_metrics:
@@ -419,7 +419,10 @@ class TokenizerManager:
             input_ids = self.tokenizer.encode(input_text)

         image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            obj.image_data,
+            image_data=obj.image_data,
+            input_text=input_text or input_ids,
+            request_obj=obj,
+            max_req_input_len=self.max_req_input_len,
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
@@ -495,6 +498,7 @@ class TokenizerManager:
             token_ids_logprob,
             obj.stream,
             bootstrap_host=obj.bootstrap_host,
+            bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
             lora_path=obj.lora_path,
             input_embeds=input_embeds,
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -504,6 +572,25 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
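In plain PyTorch terms, the new kernel plus the set_mla_kv_buffer wrapper scatter the concatenated (nope, rope) rows into selected KV-cache slots. A CPU-side sketch of that behaviour, using made-up DeepSeek-style dimensions rather than the pool's real buffer layout:

import torch

# Illustrative dimensions only: 512 latent "nope" channels + 64 rope channels per token.
nope_dim, rope_dim = 512, 64
kv_buffer = torch.zeros(1024, nope_dim + rope_dim)   # one row per KV-cache slot
loc = torch.tensor([3, 17, 42])                       # destination slots for 3 new tokens
cache_k_nope = torch.randn(3, nope_dim)
cache_k_rope = torch.randn(3, rope_dim)

# Reference behaviour of set_mla_kv_buffer_triton: each selected row gets the
# nope part in its first nope_dim channels and the rope part in the remainder.
kv_buffer[loc] = torch.cat([cache_k_nope, cache_k_rope], dim=-1)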
@@ -134,7 +134,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     )

     gpu_mem = get_device_memory_capacity()
-
+    # Batch size of each rank will not become so large when DP is on
+    if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
         capture_bs += list(range(160, 257, 8))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -278,9 +279,9 @@ class CudaGraphRunner:
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g.,
+                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable cuda graph by --disable-cuda-graph\n"
+                "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
