sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +36 -2
- sglang/srt/conversation.py +56 -3
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +50 -18
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +20 -5
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
- sglang/srt/layers/moe/ep_moe/layer.py +141 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +35 -3
- sglang/srt/managers/mm_utils.py +59 -96
- sglang/srt/managers/schedule_batch.py +17 -6
- sglang/srt/managers/scheduler.py +38 -6
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +176 -101
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +78 -19
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +372 -82
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +63 -61
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +26 -4
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +191 -48
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -814,9 +814,9 @@ def sample_mmmu_requests(
|
|
814
814
|
List of tuples (prompt, prompt_token_len, output_token_len).
|
815
815
|
"""
|
816
816
|
try:
|
817
|
-
import base64
|
818
817
|
import io
|
819
818
|
|
819
|
+
import pybase64
|
820
820
|
from datasets import load_dataset
|
821
821
|
except ImportError:
|
822
822
|
raise ImportError("Please install datasets: pip install datasets")
|
@@ -867,7 +867,7 @@ def sample_mmmu_requests(
|
|
867
867
|
# Encode image to base64
|
868
868
|
buffered = io.BytesIO()
|
869
869
|
image.save(buffered, format="JPEG")
|
870
|
-
img_str =
|
870
|
+
img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
|
871
871
|
image_data = f"data:image/jpeg;base64,{img_str}"
|
872
872
|
else:
|
873
873
|
continue
|
@@ -25,6 +25,7 @@ from transformers import PretrainedConfig
|
|
25
25
|
from sglang.srt.hf_transformers_utils import (
|
26
26
|
get_config,
|
27
27
|
get_context_length,
|
28
|
+
get_generation_config,
|
28
29
|
get_hf_text_config,
|
29
30
|
)
|
30
31
|
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
@@ -83,6 +84,13 @@ class ModelConfig:
|
|
83
84
|
**kwargs,
|
84
85
|
)
|
85
86
|
|
87
|
+
self.hf_generation_config = get_generation_config(
|
88
|
+
self.model_path,
|
89
|
+
trust_remote_code=trust_remote_code,
|
90
|
+
revision=revision,
|
91
|
+
**kwargs,
|
92
|
+
)
|
93
|
+
|
86
94
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
87
95
|
self.attention_chunk_size = getattr(
|
88
96
|
self.hf_text_config, "attention_chunk_size", None
|
@@ -359,7 +367,17 @@ class ModelConfig:
|
|
359
367
|
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
|
360
368
|
quant_cfg = modelopt_quant_config
|
361
369
|
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
|
362
|
-
|
370
|
+
quant_config_file = os.path.join(
|
371
|
+
self.model_path, "hf_quant_config.json"
|
372
|
+
)
|
373
|
+
with open(quant_config_file) as f:
|
374
|
+
quant_config_dict = json.load(f)
|
375
|
+
json_quant_configs = quant_config_dict["quantization"]
|
376
|
+
quant_algo = json_quant_configs.get("quant_algo", None)
|
377
|
+
if quant_algo == "MIXED_PRECISION":
|
378
|
+
quant_cfg = {"quant_method": "w4afp8"}
|
379
|
+
else:
|
380
|
+
quant_cfg = modelopt_quant_config
|
363
381
|
return quant_cfg
|
364
382
|
|
365
383
|
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
@@ -389,6 +407,7 @@ class ModelConfig:
|
|
389
407
|
"w8a8_fp8",
|
390
408
|
"moe_wna16",
|
391
409
|
"qoq",
|
410
|
+
"w4afp8",
|
392
411
|
]
|
393
412
|
compatible_quantization_methods = {
|
394
413
|
"modelopt_fp4": ["modelopt"],
|
@@ -402,7 +421,9 @@ class ModelConfig:
|
|
402
421
|
quant_cfg = self._parse_quant_hf_config()
|
403
422
|
|
404
423
|
if quant_cfg is not None:
|
405
|
-
quant_method = quant_cfg.get(
|
424
|
+
quant_method = quant_cfg.get(
|
425
|
+
"quant_method", "" if not self.quantization else self.quantization
|
426
|
+
).lower()
|
406
427
|
|
407
428
|
# Detect which checkpoint is it
|
408
429
|
for _, method in QUANTIZATION_METHODS.items():
|
@@ -454,6 +475,19 @@ class ModelConfig:
|
|
454
475
|
if eos_ids:
|
455
476
|
# it can be either int or list of int
|
456
477
|
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
478
|
+
if eos_ids is None:
|
479
|
+
eos_ids = set()
|
480
|
+
if self.hf_generation_config:
|
481
|
+
generation_eos_ids = getattr(
|
482
|
+
self.hf_generation_config, "eos_token_id", None
|
483
|
+
)
|
484
|
+
if generation_eos_ids:
|
485
|
+
generation_eos_ids = (
|
486
|
+
{generation_eos_ids}
|
487
|
+
if isinstance(generation_eos_ids, int)
|
488
|
+
else set(generation_eos_ids)
|
489
|
+
)
|
490
|
+
eos_ids = eos_ids | generation_eos_ids
|
457
491
|
return eos_ids
|
458
492
|
|
459
493
|
def maybe_pull_model_tokenizer_from_remote(self) -> None:
|
sglang/srt/conversation.py
CHANGED
@@ -88,9 +88,11 @@ class Conversation:
|
|
88
88
|
stop_str: Union[str, List[str]] = None
|
89
89
|
# The string that represents an image token in the prompt
|
90
90
|
image_token: str = "<image>"
|
91
|
+
video_token: str = "<video>"
|
91
92
|
audio_token: str = "<audio>"
|
92
93
|
|
93
94
|
image_data: Optional[List[str]] = None
|
95
|
+
video_data: Optional[List[str]] = None
|
94
96
|
modalities: Optional[List[str]] = None
|
95
97
|
stop_token_ids: Optional[int] = None
|
96
98
|
|
@@ -380,11 +382,15 @@ class Conversation:
|
|
380
382
|
self.messages.append([role, message])
|
381
383
|
|
382
384
|
def append_image(self, image: str):
|
383
|
-
"""Append a new
|
385
|
+
"""Append a new image."""
|
384
386
|
self.image_data.append(image)
|
385
387
|
|
388
|
+
def append_video(self, video: str):
|
389
|
+
"""Append a new video."""
|
390
|
+
self.video_data.append(video)
|
391
|
+
|
386
392
|
def append_audio(self, audio: str):
|
387
|
-
"""Append a new
|
393
|
+
"""Append a new audio."""
|
388
394
|
self.audio_data.append(audio)
|
389
395
|
|
390
396
|
def update_last_message(self, message: str):
|
@@ -433,6 +439,7 @@ class Conversation:
|
|
433
439
|
sep2=self.sep2,
|
434
440
|
stop_str=self.stop_str,
|
435
441
|
image_token=self.image_token,
|
442
|
+
video_token=self.video_token,
|
436
443
|
audio_token=self.audio_token,
|
437
444
|
)
|
438
445
|
|
@@ -495,8 +502,12 @@ def generate_embedding_convs(
|
|
495
502
|
sep2=conv_template.sep2,
|
496
503
|
stop_str=conv_template.stop_str,
|
497
504
|
image_data=[],
|
505
|
+
video_data=[],
|
506
|
+
audio_data=[],
|
498
507
|
modalities=[],
|
499
508
|
image_token=conv_template.image_token,
|
509
|
+
video_token=conv_template.video_token,
|
510
|
+
audio_token=conv_template.audio_token,
|
500
511
|
)
|
501
512
|
real_content = ""
|
502
513
|
|
@@ -557,10 +568,12 @@ def generate_chat_conv(
|
|
557
568
|
sep2=conv.sep2,
|
558
569
|
stop_str=conv.stop_str,
|
559
570
|
image_data=[],
|
571
|
+
video_data=[],
|
560
572
|
audio_data=[],
|
561
573
|
modalities=[],
|
562
574
|
image_token=conv.image_token,
|
563
575
|
audio_token=conv.audio_token,
|
576
|
+
video_token=conv.video_token,
|
564
577
|
)
|
565
578
|
|
566
579
|
if isinstance(request.messages, str):
|
@@ -602,6 +615,7 @@ def generate_chat_conv(
|
|
602
615
|
image_token = ""
|
603
616
|
|
604
617
|
audio_token = conv.audio_token
|
618
|
+
video_token = conv.video_token
|
605
619
|
for content in message.content:
|
606
620
|
if content.type == "text":
|
607
621
|
if num_image_url > 16:
|
@@ -614,6 +628,9 @@ def generate_chat_conv(
|
|
614
628
|
else:
|
615
629
|
real_content += image_token
|
616
630
|
conv.append_image(content.image_url.url)
|
631
|
+
elif content.type == "video_url":
|
632
|
+
real_content += video_token
|
633
|
+
conv.append_video(content.video_url.url)
|
617
634
|
elif content.type == "audio_url":
|
618
635
|
real_content += audio_token
|
619
636
|
conv.append_audio(content.audio_url.url)
|
@@ -810,6 +827,7 @@ register_conv_template(
|
|
810
827
|
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
811
828
|
stop_str=["<|im_end|>"],
|
812
829
|
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
830
|
+
video_token="<|vision_start|><|video_pad|><|vision_end|>",
|
813
831
|
)
|
814
832
|
)
|
815
833
|
|
@@ -870,6 +888,7 @@ register_conv_template(
|
|
870
888
|
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
871
889
|
stop_str=("<|im_end|>", "<|endoftext|>"),
|
872
890
|
image_token="(<image>./</image>)",
|
891
|
+
video_token="(<video>./</video>)",
|
873
892
|
)
|
874
893
|
)
|
875
894
|
|
@@ -921,6 +940,19 @@ register_conv_template(
|
|
921
940
|
)
|
922
941
|
)
|
923
942
|
|
943
|
+
register_conv_template(
|
944
|
+
Conversation(
|
945
|
+
name="mimo-vl",
|
946
|
+
system_message="You are MiMo, an AI assistant developed by Xiaomi.",
|
947
|
+
system_template="<|im_start|>system\n{system_message}",
|
948
|
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
949
|
+
sep="<|im_end|>\n",
|
950
|
+
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
951
|
+
stop_str=["<|im_end|>"],
|
952
|
+
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
953
|
+
)
|
954
|
+
)
|
955
|
+
|
924
956
|
|
925
957
|
register_conv_template(
|
926
958
|
Conversation(
|
@@ -935,6 +967,19 @@ register_conv_template(
|
|
935
967
|
)
|
936
968
|
)
|
937
969
|
|
970
|
+
register_conv_template(
|
971
|
+
Conversation(
|
972
|
+
name="llama_4_vision",
|
973
|
+
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
|
974
|
+
system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
|
975
|
+
roles=("user", "assistant"),
|
976
|
+
sep_style=SeparatorStyle.LLAMA4,
|
977
|
+
sep="",
|
978
|
+
stop_str="<|eot|>",
|
979
|
+
image_token="<|image|>",
|
980
|
+
)
|
981
|
+
)
|
982
|
+
|
938
983
|
|
939
984
|
@register_conv_template_matching_function
|
940
985
|
def match_internvl(model_path: str):
|
@@ -943,9 +988,11 @@ def match_internvl(model_path: str):
|
|
943
988
|
|
944
989
|
|
945
990
|
@register_conv_template_matching_function
|
946
|
-
def
|
991
|
+
def match_llama_vision(model_path: str):
|
947
992
|
if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
|
948
993
|
return "llama_3_vision"
|
994
|
+
if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
|
995
|
+
return "llama_4_vision"
|
949
996
|
|
950
997
|
|
951
998
|
@register_conv_template_matching_function
|
@@ -1034,3 +1081,9 @@ def match_phi_4_mm(model_path: str):
|
|
1034
1081
|
def match_vila(model_path: str):
|
1035
1082
|
if re.search(r"vila", model_path, re.IGNORECASE):
|
1036
1083
|
return "chatml"
|
1084
|
+
|
1085
|
+
|
1086
|
+
@register_conv_template_matching_function
|
1087
|
+
def match_mimo_vl(model_path: str):
|
1088
|
+
if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
|
1089
|
+
return "mimo-vl"
|
@@ -0,0 +1,44 @@
|
|
1
|
+
import logging
|
2
|
+
|
3
|
+
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
|
4
|
+
from sglang.srt.disaggregation.mooncake.conn import (
|
5
|
+
MooncakeKVBootstrapServer,
|
6
|
+
MooncakeKVManager,
|
7
|
+
MooncakeKVReceiver,
|
8
|
+
MooncakeKVSender,
|
9
|
+
)
|
10
|
+
from sglang.srt.utils import get_local_ip_by_remote
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
class AscendKVManager(MooncakeKVManager):
|
16
|
+
def init_engine(self):
|
17
|
+
# TransferEngine initialized on ascend.
|
18
|
+
local_ip = get_local_ip_by_remote()
|
19
|
+
self.engine = AscendTransferEngine(
|
20
|
+
hostname=local_ip,
|
21
|
+
npu_id=self.kv_args.gpu_id,
|
22
|
+
disaggregation_mode=self.disaggregation_mode,
|
23
|
+
)
|
24
|
+
|
25
|
+
def register_buffer_to_engine(self):
|
26
|
+
self.engine.register(
|
27
|
+
self.kv_args.kv_data_ptrs[0], sum(self.kv_args.kv_data_lens)
|
28
|
+
)
|
29
|
+
# The Ascend backend optimize batch registration for small memory blocks.
|
30
|
+
self.engine.batch_register(
|
31
|
+
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
32
|
+
)
|
33
|
+
|
34
|
+
|
35
|
+
class AscendKVSender(MooncakeKVSender):
|
36
|
+
pass
|
37
|
+
|
38
|
+
|
39
|
+
class AscendKVReceiver(MooncakeKVReceiver):
|
40
|
+
pass
|
41
|
+
|
42
|
+
|
43
|
+
class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
|
44
|
+
pass
|
@@ -0,0 +1,58 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from typing import List, Optional
|
4
|
+
|
5
|
+
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
6
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
7
|
+
|
8
|
+
logger = logging.getLogger(__name__)
|
9
|
+
|
10
|
+
|
11
|
+
class AscendTransferEngine(MooncakeTransferEngine):
|
12
|
+
|
13
|
+
def __init__(
|
14
|
+
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
|
15
|
+
):
|
16
|
+
try:
|
17
|
+
from mf_adapter import TransferEngine
|
18
|
+
except ImportError as e:
|
19
|
+
raise ImportError(
|
20
|
+
"Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
|
21
|
+
) from e
|
22
|
+
|
23
|
+
self.engine = TransferEngine()
|
24
|
+
self.hostname = hostname
|
25
|
+
self.npu_id = npu_id
|
26
|
+
|
27
|
+
# Centralized storage address of the AscendTransferEngine
|
28
|
+
self.store_url = os.getenv("ASCEND_MF_STORE_URL")
|
29
|
+
if disaggregation_mode == DisaggregationMode.PREFILL:
|
30
|
+
self.role = "Prefill"
|
31
|
+
elif disaggregation_mode == DisaggregationMode.DECODE:
|
32
|
+
self.role = "Decode"
|
33
|
+
else:
|
34
|
+
logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
|
35
|
+
raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")
|
36
|
+
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
|
37
|
+
self.initialize()
|
38
|
+
|
39
|
+
def initialize(self) -> None:
|
40
|
+
"""Initialize the ascend transfer instance."""
|
41
|
+
ret_value = self.engine.initialize(
|
42
|
+
self.store_url,
|
43
|
+
self.session_id,
|
44
|
+
self.role,
|
45
|
+
self.npu_id,
|
46
|
+
)
|
47
|
+
if ret_value != 0:
|
48
|
+
logger.error("Ascend Transfer Engine initialization failed.")
|
49
|
+
raise RuntimeError("Ascend Transfer Engine initialization failed.")
|
50
|
+
|
51
|
+
def batch_register(self, ptrs: List[int], lengths: List[int]):
|
52
|
+
try:
|
53
|
+
ret_value = self.engine.batch_register_memory(ptrs, lengths)
|
54
|
+
except Exception:
|
55
|
+
# Mark register as failed
|
56
|
+
ret_value = -1
|
57
|
+
if ret_value != 0:
|
58
|
+
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
|
@@ -132,13 +132,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
132
132
|
):
|
133
133
|
self.kv_args = args
|
134
134
|
self.local_ip = get_local_ip_auto()
|
135
|
-
self.engine = MooncakeTransferEngine(
|
136
|
-
hostname=self.local_ip,
|
137
|
-
gpu_id=self.kv_args.gpu_id,
|
138
|
-
ib_device=self.kv_args.ib_device,
|
139
|
-
)
|
140
135
|
self.is_mla_backend = is_mla_backend
|
141
136
|
self.disaggregation_mode = disaggregation_mode
|
137
|
+
self.init_engine()
|
142
138
|
# for p/d multi node infer
|
143
139
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
144
140
|
self.dist_init_addr = server_args.dist_init_addr
|
@@ -185,9 +181,11 @@ class MooncakeKVManager(BaseKVManager):
|
|
185
181
|
threading.Thread(
|
186
182
|
target=self.transfer_worker, args=(queue, executor), daemon=True
|
187
183
|
).start()
|
188
|
-
|
189
|
-
|
190
|
-
|
184
|
+
# If a timeout happens on the prefill side, it means prefill instances
|
185
|
+
# fail to receive the KV indices from the decode instance of this request.
|
186
|
+
# These timeout requests should be aborted to release the tree cache.
|
187
|
+
self.bootstrap_timeout = get_int_env_var(
|
188
|
+
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
|
191
189
|
)
|
192
190
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
193
191
|
self.heartbeat_failures = {}
|
@@ -209,6 +207,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
209
207
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
210
208
|
self.prefill_tp_size_table: Dict[str, int] = {}
|
211
209
|
self.prefill_dp_size_table: Dict[str, int] = {}
|
210
|
+
# If a timeout happens on the decode side, it means decode instances
|
211
|
+
# fail to receive the KV Cache transfer done signal after bootstrapping.
|
212
|
+
# These timeout requests should be aborted to release the tree cache.
|
213
|
+
self.waiting_timeout = get_int_env_var(
|
214
|
+
"SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
|
215
|
+
)
|
212
216
|
else:
|
213
217
|
raise ValueError(
|
214
218
|
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
@@ -217,6 +221,13 @@ class MooncakeKVManager(BaseKVManager):
|
|
217
221
|
self.failure_records: Dict[int, str] = {}
|
218
222
|
self.failure_lock = threading.Lock()
|
219
223
|
|
224
|
+
def init_engine(self):
|
225
|
+
self.engine = MooncakeTransferEngine(
|
226
|
+
hostname=self.local_ip,
|
227
|
+
gpu_id=self.kv_args.gpu_id,
|
228
|
+
ib_device=self.kv_args.ib_device,
|
229
|
+
)
|
230
|
+
|
220
231
|
def register_buffer_to_engine(self):
|
221
232
|
for kv_data_ptr, kv_data_len in zip(
|
222
233
|
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
|
@@ -259,19 +270,17 @@ class MooncakeKVManager(BaseKVManager):
|
|
259
270
|
|
260
271
|
# Worker function for processing a single layer
|
261
272
|
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
262
|
-
src_addr_list = []
|
263
|
-
dst_addr_list = []
|
264
|
-
length_list = []
|
265
273
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
266
274
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
267
275
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
268
276
|
length = item_len * len(prefill_index)
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
277
|
+
|
278
|
+
status = self.engine.transfer_sync(
|
279
|
+
mooncake_session_id, src_addr, dst_addr, length
|
280
|
+
)
|
281
|
+
if status != 0:
|
282
|
+
return status
|
283
|
+
return 0
|
275
284
|
|
276
285
|
futures = [
|
277
286
|
executor.submit(
|
@@ -938,7 +947,12 @@ class MooncakeKVSender(BaseKVSender):
|
|
938
947
|
if self.init_time is not None:
|
939
948
|
now = time.time()
|
940
949
|
elapsed = now - self.init_time
|
941
|
-
if elapsed >= self.kv_mgr.
|
950
|
+
if elapsed >= self.kv_mgr.bootstrap_timeout:
|
951
|
+
logger.warning_once(
|
952
|
+
"Some requests timed out when bootstrapping, "
|
953
|
+
"which means prefill instances fail to receive the KV indices from the decode instance of this request. "
|
954
|
+
"If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
|
955
|
+
)
|
942
956
|
self.kv_mgr.record_failure(
|
943
957
|
self.bootstrap_room,
|
944
958
|
f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.Bootstrapping",
|
@@ -987,6 +1001,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
987
1001
|
self.session_id = self.kv_mgr.get_session_id()
|
988
1002
|
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
989
1003
|
self.conclude_state = None
|
1004
|
+
self.init_time = None
|
990
1005
|
self.data_parallel_rank = data_parallel_rank
|
991
1006
|
|
992
1007
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
@@ -1222,14 +1237,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1222
1237
|
str(self.required_dst_info_num).encode("ascii"),
|
1223
1238
|
]
|
1224
1239
|
)
|
1240
|
+
self.init_time = time.time()
|
1225
1241
|
|
1226
1242
|
def poll(self) -> KVPoll:
|
1227
1243
|
if self.conclude_state is None:
|
1228
1244
|
status = self.kv_mgr.check_status(self.bootstrap_room)
|
1229
1245
|
if status in (KVPoll.Success, KVPoll.Failed):
|
1230
1246
|
self.conclude_state = status
|
1247
|
+
elif status == KVPoll.WaitingForInput:
|
1248
|
+
if self.init_time is not None:
|
1249
|
+
now = time.time()
|
1250
|
+
elapsed = now - self.init_time
|
1251
|
+
if elapsed >= self.kv_mgr.waiting_timeout:
|
1252
|
+
logger.warning_once(
|
1253
|
+
"Some requests fail to receive KV Cache transfer done signal after bootstrapping. "
|
1254
|
+
"If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
|
1255
|
+
)
|
1256
|
+
self.kv_mgr.record_failure(
|
1257
|
+
self.bootstrap_room,
|
1258
|
+
f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput",
|
1259
|
+
)
|
1260
|
+
self.conclude_state = KVPoll.Failed
|
1261
|
+
return KVPoll.Failed
|
1231
1262
|
|
1232
1263
|
return status
|
1264
|
+
|
1233
1265
|
else:
|
1234
1266
|
return self.conclude_state
|
1235
1267
|
|
@@ -1,8 +1,8 @@
|
|
1
|
-
import json
|
2
1
|
import logging
|
3
|
-
from dataclasses import dataclass
|
4
2
|
from typing import List, Optional
|
5
3
|
|
4
|
+
from sglang.srt.utils import get_bool_env_var, get_free_port
|
5
|
+
|
6
6
|
logger = logging.getLogger(__name__)
|
7
7
|
|
8
8
|
|
@@ -55,12 +55,21 @@ class MooncakeTransferEngine:
|
|
55
55
|
device_name: Optional[str],
|
56
56
|
) -> None:
|
57
57
|
"""Initialize the mooncake instance."""
|
58
|
-
|
59
|
-
hostname
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
58
|
+
if get_bool_env_var("ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE", "false"):
|
59
|
+
hostname += f":{get_free_port()}:npu_{self.gpu_id}"
|
60
|
+
ret_value = self.engine.initialize(
|
61
|
+
hostname,
|
62
|
+
"P2PHANDSHAKE",
|
63
|
+
"ascend",
|
64
|
+
device_name if device_name is not None else "",
|
65
|
+
)
|
66
|
+
else:
|
67
|
+
ret_value = self.engine.initialize(
|
68
|
+
hostname,
|
69
|
+
"P2PHANDSHAKE",
|
70
|
+
"rdma",
|
71
|
+
device_name if device_name is not None else "",
|
72
|
+
)
|
64
73
|
if ret_value != 0:
|
65
74
|
logger.error("Mooncake Transfer Engine initialization failed.")
|
66
75
|
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
|
@@ -15,7 +15,7 @@ import requests
|
|
15
15
|
import torch
|
16
16
|
import torch.distributed as dist
|
17
17
|
|
18
|
-
from sglang.srt.utils import get_ip
|
18
|
+
from sglang.srt.utils import get_ip, is_npu
|
19
19
|
|
20
20
|
if TYPE_CHECKING:
|
21
21
|
from sglang.srt.managers.schedule_batch import Req
|
@@ -94,8 +94,12 @@ class MetadataBuffers:
|
|
94
94
|
custom_mem_pool: torch.cuda.MemPool = None,
|
95
95
|
):
|
96
96
|
self.custom_mem_pool = custom_mem_pool
|
97
|
-
device = "
|
98
|
-
|
97
|
+
device = "cpu"
|
98
|
+
if is_npu():
|
99
|
+
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
|
100
|
+
device = "npu"
|
101
|
+
elif self.custom_mem_pool:
|
102
|
+
device = "cuda"
|
99
103
|
with (
|
100
104
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
101
105
|
if self.custom_mem_pool
|
@@ -200,6 +204,7 @@ class MetadataBuffers:
|
|
200
204
|
class TransferBackend(Enum):
|
201
205
|
MOONCAKE = "mooncake"
|
202
206
|
NIXL = "nixl"
|
207
|
+
ASCEND = "ascend"
|
203
208
|
FAKE = "fake"
|
204
209
|
|
205
210
|
|
@@ -231,6 +236,23 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
|
231
236
|
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
232
237
|
}
|
233
238
|
return class_mapping.get(class_type)
|
239
|
+
elif transfer_backend == TransferBackend.ASCEND:
|
240
|
+
from sglang.srt.disaggregation.ascend import (
|
241
|
+
AscendKVBootstrapServer,
|
242
|
+
AscendKVManager,
|
243
|
+
AscendKVReceiver,
|
244
|
+
AscendKVSender,
|
245
|
+
)
|
246
|
+
from sglang.srt.disaggregation.base import KVArgs
|
247
|
+
|
248
|
+
class_mapping = {
|
249
|
+
KVClassType.KVARGS: KVArgs,
|
250
|
+
KVClassType.MANAGER: AscendKVManager,
|
251
|
+
KVClassType.SENDER: AscendKVSender,
|
252
|
+
KVClassType.RECEIVER: (AscendKVReceiver),
|
253
|
+
KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
|
254
|
+
}
|
255
|
+
return class_mapping.get(class_type)
|
234
256
|
elif transfer_backend == TransferBackend.NIXL:
|
235
257
|
from sglang.srt.disaggregation.base import KVArgs
|
236
258
|
from sglang.srt.disaggregation.nixl import (
|
sglang/srt/entrypoints/engine.py
CHANGED
@@ -418,6 +418,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
|
|
418
418
|
|
419
419
|
await _global_state.tokenizer_manager.start_profile(
|
420
420
|
output_dir=obj.output_dir,
|
421
|
+
start_step=obj.start_step,
|
421
422
|
num_steps=obj.num_steps,
|
422
423
|
activities=obj.activities,
|
423
424
|
with_stack=obj.with_stack,
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import base64
|
2
1
|
import copy
|
3
2
|
import dataclasses
|
4
3
|
import multiprocessing
|
@@ -7,6 +6,7 @@ import threading
|
|
7
6
|
import time
|
8
7
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
9
8
|
|
9
|
+
import pybase64
|
10
10
|
import requests
|
11
11
|
import torch
|
12
12
|
import torch.distributed as dist
|
@@ -267,6 +267,10 @@ class ChatCompletionMessageContentImageURL(BaseModel):
|
|
267
267
|
detail: Optional[Literal["auto", "low", "high"]] = "auto"
|
268
268
|
|
269
269
|
|
270
|
+
class ChatCompletionMessageContentVideoURL(BaseModel):
|
271
|
+
url: str
|
272
|
+
|
273
|
+
|
270
274
|
class ChatCompletionMessageContentAudioURL(BaseModel):
|
271
275
|
url: str
|
272
276
|
|
@@ -277,6 +281,11 @@ class ChatCompletionMessageContentImagePart(BaseModel):
|
|
277
281
|
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
|
278
282
|
|
279
283
|
|
284
|
+
class ChatCompletionMessageContentVideoPart(BaseModel):
|
285
|
+
type: Literal["video_url"]
|
286
|
+
video_url: ChatCompletionMessageContentVideoURL
|
287
|
+
|
288
|
+
|
280
289
|
class ChatCompletionMessageContentAudioPart(BaseModel):
|
281
290
|
type: Literal["audio_url"]
|
282
291
|
audio_url: ChatCompletionMessageContentAudioURL
|
@@ -285,6 +294,7 @@ class ChatCompletionMessageContentAudioPart(BaseModel):
|
|
285
294
|
ChatCompletionMessageContentPart = Union[
|
286
295
|
ChatCompletionMessageContentTextPart,
|
287
296
|
ChatCompletionMessageContentImagePart,
|
297
|
+
ChatCompletionMessageContentVideoPart,
|
288
298
|
ChatCompletionMessageContentAudioPart,
|
289
299
|
]
|
290
300
|
|
@@ -629,6 +639,7 @@ class MessageProcessingResult:
|
|
629
639
|
prompt_ids: Union[str, List[int]]
|
630
640
|
image_data: Optional[Any]
|
631
641
|
audio_data: Optional[Any]
|
642
|
+
video_data: Optional[Any]
|
632
643
|
modalities: List[str]
|
633
644
|
stop: List[str]
|
634
645
|
tool_call_constraint: Optional[Any] = None
|