sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -814,9 +814,9 @@ def sample_mmmu_requests(
814
814
  List of tuples (prompt, prompt_token_len, output_token_len).
815
815
  """
816
816
  try:
817
- import base64
818
817
  import io
819
818
 
819
+ import pybase64
820
820
  from datasets import load_dataset
821
821
  except ImportError:
822
822
  raise ImportError("Please install datasets: pip install datasets")
@@ -867,7 +867,7 @@ def sample_mmmu_requests(
867
867
  # Encode image to base64
868
868
  buffered = io.BytesIO()
869
869
  image.save(buffered, format="JPEG")
870
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
870
+ img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
871
871
  image_data = f"data:image/jpeg;base64,{img_str}"
872
872
  else:
873
873
  continue
@@ -25,6 +25,7 @@ from transformers import PretrainedConfig
25
25
  from sglang.srt.hf_transformers_utils import (
26
26
  get_config,
27
27
  get_context_length,
28
+ get_generation_config,
28
29
  get_hf_text_config,
29
30
  )
30
31
  from sglang.srt.layers.quantization import QUANTIZATION_METHODS
@@ -83,6 +84,13 @@ class ModelConfig:
83
84
  **kwargs,
84
85
  )
85
86
 
87
+ self.hf_generation_config = get_generation_config(
88
+ self.model_path,
89
+ trust_remote_code=trust_remote_code,
90
+ revision=revision,
91
+ **kwargs,
92
+ )
93
+
86
94
  self.hf_text_config = get_hf_text_config(self.hf_config)
87
95
  self.attention_chunk_size = getattr(
88
96
  self.hf_text_config, "attention_chunk_size", None
@@ -359,7 +367,17 @@ class ModelConfig:
359
367
  if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
360
368
  quant_cfg = modelopt_quant_config
361
369
  elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
362
- quant_cfg = modelopt_quant_config
370
+ quant_config_file = os.path.join(
371
+ self.model_path, "hf_quant_config.json"
372
+ )
373
+ with open(quant_config_file) as f:
374
+ quant_config_dict = json.load(f)
375
+ json_quant_configs = quant_config_dict["quantization"]
376
+ quant_algo = json_quant_configs.get("quant_algo", None)
377
+ if quant_algo == "MIXED_PRECISION":
378
+ quant_cfg = {"quant_method": "w4afp8"}
379
+ else:
380
+ quant_cfg = modelopt_quant_config
363
381
  return quant_cfg
364
382
 
365
383
  # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -389,6 +407,7 @@ class ModelConfig:
389
407
  "w8a8_fp8",
390
408
  "moe_wna16",
391
409
  "qoq",
410
+ "w4afp8",
392
411
  ]
393
412
  compatible_quantization_methods = {
394
413
  "modelopt_fp4": ["modelopt"],
@@ -402,7 +421,9 @@ class ModelConfig:
402
421
  quant_cfg = self._parse_quant_hf_config()
403
422
 
404
423
  if quant_cfg is not None:
405
- quant_method = quant_cfg.get("quant_method", "").lower()
424
+ quant_method = quant_cfg.get(
425
+ "quant_method", "" if not self.quantization else self.quantization
426
+ ).lower()
406
427
 
407
428
  # Detect which checkpoint is it
408
429
  for _, method in QUANTIZATION_METHODS.items():
@@ -454,6 +475,19 @@ class ModelConfig:
454
475
  if eos_ids:
455
476
  # it can be either int or list of int
456
477
  eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
478
+ if eos_ids is None:
479
+ eos_ids = set()
480
+ if self.hf_generation_config:
481
+ generation_eos_ids = getattr(
482
+ self.hf_generation_config, "eos_token_id", None
483
+ )
484
+ if generation_eos_ids:
485
+ generation_eos_ids = (
486
+ {generation_eos_ids}
487
+ if isinstance(generation_eos_ids, int)
488
+ else set(generation_eos_ids)
489
+ )
490
+ eos_ids = eos_ids | generation_eos_ids
457
491
  return eos_ids
458
492
 
459
493
  def maybe_pull_model_tokenizer_from_remote(self) -> None:
@@ -88,9 +88,11 @@ class Conversation:
88
88
  stop_str: Union[str, List[str]] = None
89
89
  # The string that represents an image token in the prompt
90
90
  image_token: str = "<image>"
91
+ video_token: str = "<video>"
91
92
  audio_token: str = "<audio>"
92
93
 
93
94
  image_data: Optional[List[str]] = None
95
+ video_data: Optional[List[str]] = None
94
96
  modalities: Optional[List[str]] = None
95
97
  stop_token_ids: Optional[int] = None
96
98
 
@@ -380,11 +382,15 @@ class Conversation:
380
382
  self.messages.append([role, message])
381
383
 
382
384
  def append_image(self, image: str):
383
- """Append a new message."""
385
+ """Append a new image."""
384
386
  self.image_data.append(image)
385
387
 
388
+ def append_video(self, video: str):
389
+ """Append a new video."""
390
+ self.video_data.append(video)
391
+
386
392
  def append_audio(self, audio: str):
387
- """Append a new message."""
393
+ """Append a new audio."""
388
394
  self.audio_data.append(audio)
389
395
 
390
396
  def update_last_message(self, message: str):
@@ -433,6 +439,7 @@ class Conversation:
433
439
  sep2=self.sep2,
434
440
  stop_str=self.stop_str,
435
441
  image_token=self.image_token,
442
+ video_token=self.video_token,
436
443
  audio_token=self.audio_token,
437
444
  )
438
445
 
@@ -495,8 +502,12 @@ def generate_embedding_convs(
495
502
  sep2=conv_template.sep2,
496
503
  stop_str=conv_template.stop_str,
497
504
  image_data=[],
505
+ video_data=[],
506
+ audio_data=[],
498
507
  modalities=[],
499
508
  image_token=conv_template.image_token,
509
+ video_token=conv_template.video_token,
510
+ audio_token=conv_template.audio_token,
500
511
  )
501
512
  real_content = ""
502
513
 
@@ -557,10 +568,12 @@ def generate_chat_conv(
557
568
  sep2=conv.sep2,
558
569
  stop_str=conv.stop_str,
559
570
  image_data=[],
571
+ video_data=[],
560
572
  audio_data=[],
561
573
  modalities=[],
562
574
  image_token=conv.image_token,
563
575
  audio_token=conv.audio_token,
576
+ video_token=conv.video_token,
564
577
  )
565
578
 
566
579
  if isinstance(request.messages, str):
@@ -602,6 +615,7 @@ def generate_chat_conv(
602
615
  image_token = ""
603
616
 
604
617
  audio_token = conv.audio_token
618
+ video_token = conv.video_token
605
619
  for content in message.content:
606
620
  if content.type == "text":
607
621
  if num_image_url > 16:
@@ -614,6 +628,9 @@ def generate_chat_conv(
614
628
  else:
615
629
  real_content += image_token
616
630
  conv.append_image(content.image_url.url)
631
+ elif content.type == "video_url":
632
+ real_content += video_token
633
+ conv.append_video(content.video_url.url)
617
634
  elif content.type == "audio_url":
618
635
  real_content += audio_token
619
636
  conv.append_audio(content.audio_url.url)
@@ -810,6 +827,7 @@ register_conv_template(
810
827
  sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
811
828
  stop_str=["<|im_end|>"],
812
829
  image_token="<|vision_start|><|image_pad|><|vision_end|>",
830
+ video_token="<|vision_start|><|video_pad|><|vision_end|>",
813
831
  )
814
832
  )
815
833
 
@@ -870,6 +888,7 @@ register_conv_template(
870
888
  sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
871
889
  stop_str=("<|im_end|>", "<|endoftext|>"),
872
890
  image_token="(<image>./</image>)",
891
+ video_token="(<video>./</video>)",
873
892
  )
874
893
  )
875
894
 
@@ -921,6 +940,19 @@ register_conv_template(
921
940
  )
922
941
  )
923
942
 
943
+ register_conv_template(
944
+ Conversation(
945
+ name="mimo-vl",
946
+ system_message="You are MiMo, an AI assistant developed by Xiaomi.",
947
+ system_template="<|im_start|>system\n{system_message}",
948
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
949
+ sep="<|im_end|>\n",
950
+ sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
951
+ stop_str=["<|im_end|>"],
952
+ image_token="<|vision_start|><|image_pad|><|vision_end|>",
953
+ )
954
+ )
955
+
924
956
 
925
957
  register_conv_template(
926
958
  Conversation(
@@ -935,6 +967,19 @@ register_conv_template(
935
967
  )
936
968
  )
937
969
 
970
+ register_conv_template(
971
+ Conversation(
972
+ name="llama_4_vision",
973
+ system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
974
+ system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
975
+ roles=("user", "assistant"),
976
+ sep_style=SeparatorStyle.LLAMA4,
977
+ sep="",
978
+ stop_str="<|eot|>",
979
+ image_token="<|image|>",
980
+ )
981
+ )
982
+
938
983
 
939
984
  @register_conv_template_matching_function
940
985
  def match_internvl(model_path: str):
@@ -943,9 +988,11 @@ def match_internvl(model_path: str):
943
988
 
944
989
 
945
990
  @register_conv_template_matching_function
946
- def match_llama_3_vision(model_path: str):
991
+ def match_llama_vision(model_path: str):
947
992
  if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
948
993
  return "llama_3_vision"
994
+ if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
995
+ return "llama_4_vision"
949
996
 
950
997
 
951
998
  @register_conv_template_matching_function
@@ -1034,3 +1081,9 @@ def match_phi_4_mm(model_path: str):
1034
1081
  def match_vila(model_path: str):
1035
1082
  if re.search(r"vila", model_path, re.IGNORECASE):
1036
1083
  return "chatml"
1084
+
1085
+
1086
+ @register_conv_template_matching_function
1087
+ def match_mimo_vl(model_path: str):
1088
+ if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
1089
+ return "mimo-vl"
@@ -0,0 +1,6 @@
1
+ from sglang.srt.disaggregation.ascend.conn import (
2
+ AscendKVBootstrapServer,
3
+ AscendKVManager,
4
+ AscendKVReceiver,
5
+ AscendKVSender,
6
+ )
@@ -0,0 +1,44 @@
1
+ import logging
2
+
3
+ from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
4
+ from sglang.srt.disaggregation.mooncake.conn import (
5
+ MooncakeKVBootstrapServer,
6
+ MooncakeKVManager,
7
+ MooncakeKVReceiver,
8
+ MooncakeKVSender,
9
+ )
10
+ from sglang.srt.utils import get_local_ip_by_remote
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class AscendKVManager(MooncakeKVManager):
16
+ def init_engine(self):
17
+ # TransferEngine initialized on ascend.
18
+ local_ip = get_local_ip_by_remote()
19
+ self.engine = AscendTransferEngine(
20
+ hostname=local_ip,
21
+ npu_id=self.kv_args.gpu_id,
22
+ disaggregation_mode=self.disaggregation_mode,
23
+ )
24
+
25
+ def register_buffer_to_engine(self):
26
+ self.engine.register(
27
+ self.kv_args.kv_data_ptrs[0], sum(self.kv_args.kv_data_lens)
28
+ )
29
+ # The Ascend backend optimize batch registration for small memory blocks.
30
+ self.engine.batch_register(
31
+ self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
32
+ )
33
+
34
+
35
+ class AscendKVSender(MooncakeKVSender):
36
+ pass
37
+
38
+
39
+ class AscendKVReceiver(MooncakeKVReceiver):
40
+ pass
41
+
42
+
43
+ class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
44
+ pass
@@ -0,0 +1,58 @@
1
+ import logging
2
+ import os
3
+ from typing import List, Optional
4
+
5
+ from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
6
+ from sglang.srt.disaggregation.utils import DisaggregationMode
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class AscendTransferEngine(MooncakeTransferEngine):
12
+
13
+ def __init__(
14
+ self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
15
+ ):
16
+ try:
17
+ from mf_adapter import TransferEngine
18
+ except ImportError as e:
19
+ raise ImportError(
20
+ "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
21
+ ) from e
22
+
23
+ self.engine = TransferEngine()
24
+ self.hostname = hostname
25
+ self.npu_id = npu_id
26
+
27
+ # Centralized storage address of the AscendTransferEngine
28
+ self.store_url = os.getenv("ASCEND_MF_STORE_URL")
29
+ if disaggregation_mode == DisaggregationMode.PREFILL:
30
+ self.role = "Prefill"
31
+ elif disaggregation_mode == DisaggregationMode.DECODE:
32
+ self.role = "Decode"
33
+ else:
34
+ logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
35
+ raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")
36
+ self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
37
+ self.initialize()
38
+
39
+ def initialize(self) -> None:
40
+ """Initialize the ascend transfer instance."""
41
+ ret_value = self.engine.initialize(
42
+ self.store_url,
43
+ self.session_id,
44
+ self.role,
45
+ self.npu_id,
46
+ )
47
+ if ret_value != 0:
48
+ logger.error("Ascend Transfer Engine initialization failed.")
49
+ raise RuntimeError("Ascend Transfer Engine initialization failed.")
50
+
51
+ def batch_register(self, ptrs: List[int], lengths: List[int]):
52
+ try:
53
+ ret_value = self.engine.batch_register_memory(ptrs, lengths)
54
+ except Exception:
55
+ # Mark register as failed
56
+ ret_value = -1
57
+ if ret_value != 0:
58
+ logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
@@ -132,13 +132,9 @@ class MooncakeKVManager(BaseKVManager):
132
132
  ):
133
133
  self.kv_args = args
134
134
  self.local_ip = get_local_ip_auto()
135
- self.engine = MooncakeTransferEngine(
136
- hostname=self.local_ip,
137
- gpu_id=self.kv_args.gpu_id,
138
- ib_device=self.kv_args.ib_device,
139
- )
140
135
  self.is_mla_backend = is_mla_backend
141
136
  self.disaggregation_mode = disaggregation_mode
137
+ self.init_engine()
142
138
  # for p/d multi node infer
143
139
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
144
140
  self.dist_init_addr = server_args.dist_init_addr
@@ -185,9 +181,11 @@ class MooncakeKVManager(BaseKVManager):
185
181
  threading.Thread(
186
182
  target=self.transfer_worker, args=(queue, executor), daemon=True
187
183
  ).start()
188
-
189
- self.bootstrap_time_out = get_int_env_var(
190
- "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
184
+ # If a timeout happens on the prefill side, it means prefill instances
185
+ # fail to receive the KV indices from the decode instance of this request.
186
+ # These timeout requests should be aborted to release the tree cache.
187
+ self.bootstrap_timeout = get_int_env_var(
188
+ "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
191
189
  )
192
190
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
193
191
  self.heartbeat_failures = {}
@@ -209,6 +207,12 @@ class MooncakeKVManager(BaseKVManager):
209
207
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
210
208
  self.prefill_tp_size_table: Dict[str, int] = {}
211
209
  self.prefill_dp_size_table: Dict[str, int] = {}
210
+ # If a timeout happens on the decode side, it means decode instances
211
+ # fail to receive the KV Cache transfer done signal after bootstrapping.
212
+ # These timeout requests should be aborted to release the tree cache.
213
+ self.waiting_timeout = get_int_env_var(
214
+ "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
215
+ )
212
216
  else:
213
217
  raise ValueError(
214
218
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -217,6 +221,13 @@ class MooncakeKVManager(BaseKVManager):
217
221
  self.failure_records: Dict[int, str] = {}
218
222
  self.failure_lock = threading.Lock()
219
223
 
224
+ def init_engine(self):
225
+ self.engine = MooncakeTransferEngine(
226
+ hostname=self.local_ip,
227
+ gpu_id=self.kv_args.gpu_id,
228
+ ib_device=self.kv_args.ib_device,
229
+ )
230
+
220
231
  def register_buffer_to_engine(self):
221
232
  for kv_data_ptr, kv_data_len in zip(
222
233
  self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
@@ -259,19 +270,17 @@ class MooncakeKVManager(BaseKVManager):
259
270
 
260
271
  # Worker function for processing a single layer
261
272
  def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
262
- src_addr_list = []
263
- dst_addr_list = []
264
- length_list = []
265
273
  for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
266
274
  src_addr = src_ptr + int(prefill_index[0]) * item_len
267
275
  dst_addr = dst_ptr + int(decode_index[0]) * item_len
268
276
  length = item_len * len(prefill_index)
269
- src_addr_list.append(src_addr)
270
- dst_addr_list.append(dst_addr)
271
- length_list.append(length)
272
- return self.engine.batch_transfer_sync(
273
- mooncake_session_id, src_addr_list, dst_addr_list, length_list
274
- )
277
+
278
+ status = self.engine.transfer_sync(
279
+ mooncake_session_id, src_addr, dst_addr, length
280
+ )
281
+ if status != 0:
282
+ return status
283
+ return 0
275
284
 
276
285
  futures = [
277
286
  executor.submit(
@@ -938,7 +947,12 @@ class MooncakeKVSender(BaseKVSender):
938
947
  if self.init_time is not None:
939
948
  now = time.time()
940
949
  elapsed = now - self.init_time
941
- if elapsed >= self.kv_mgr.bootstrap_time_out:
950
+ if elapsed >= self.kv_mgr.bootstrap_timeout:
951
+ logger.warning_once(
952
+ "Some requests timed out when bootstrapping, "
953
+ "which means prefill instances fail to receive the KV indices from the decode instance of this request. "
954
+ "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
955
+ )
942
956
  self.kv_mgr.record_failure(
943
957
  self.bootstrap_room,
944
958
  f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.Bootstrapping",
@@ -987,6 +1001,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
987
1001
  self.session_id = self.kv_mgr.get_session_id()
988
1002
  self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
989
1003
  self.conclude_state = None
1004
+ self.init_time = None
990
1005
  self.data_parallel_rank = data_parallel_rank
991
1006
 
992
1007
  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
@@ -1222,14 +1237,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
1222
1237
  str(self.required_dst_info_num).encode("ascii"),
1223
1238
  ]
1224
1239
  )
1240
+ self.init_time = time.time()
1225
1241
 
1226
1242
  def poll(self) -> KVPoll:
1227
1243
  if self.conclude_state is None:
1228
1244
  status = self.kv_mgr.check_status(self.bootstrap_room)
1229
1245
  if status in (KVPoll.Success, KVPoll.Failed):
1230
1246
  self.conclude_state = status
1247
+ elif status == KVPoll.WaitingForInput:
1248
+ if self.init_time is not None:
1249
+ now = time.time()
1250
+ elapsed = now - self.init_time
1251
+ if elapsed >= self.kv_mgr.waiting_timeout:
1252
+ logger.warning_once(
1253
+ "Some requests fail to receive KV Cache transfer done signal after bootstrapping. "
1254
+ "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
1255
+ )
1256
+ self.kv_mgr.record_failure(
1257
+ self.bootstrap_room,
1258
+ f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput",
1259
+ )
1260
+ self.conclude_state = KVPoll.Failed
1261
+ return KVPoll.Failed
1231
1262
 
1232
1263
  return status
1264
+
1233
1265
  else:
1234
1266
  return self.conclude_state
1235
1267
 
@@ -1,8 +1,8 @@
1
- import json
2
1
  import logging
3
- from dataclasses import dataclass
4
2
  from typing import List, Optional
5
3
 
4
+ from sglang.srt.utils import get_bool_env_var, get_free_port
5
+
6
6
  logger = logging.getLogger(__name__)
7
7
 
8
8
 
@@ -55,12 +55,21 @@ class MooncakeTransferEngine:
55
55
  device_name: Optional[str],
56
56
  ) -> None:
57
57
  """Initialize the mooncake instance."""
58
- ret_value = self.engine.initialize(
59
- hostname,
60
- "P2PHANDSHAKE",
61
- "rdma",
62
- device_name if device_name is not None else "",
63
- )
58
+ if get_bool_env_var("ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE", "false"):
59
+ hostname += f":{get_free_port()}:npu_{self.gpu_id}"
60
+ ret_value = self.engine.initialize(
61
+ hostname,
62
+ "P2PHANDSHAKE",
63
+ "ascend",
64
+ device_name if device_name is not None else "",
65
+ )
66
+ else:
67
+ ret_value = self.engine.initialize(
68
+ hostname,
69
+ "P2PHANDSHAKE",
70
+ "rdma",
71
+ device_name if device_name is not None else "",
72
+ )
64
73
  if ret_value != 0:
65
74
  logger.error("Mooncake Transfer Engine initialization failed.")
66
75
  raise RuntimeError("Mooncake Transfer Engine initialization failed.")
@@ -15,7 +15,7 @@ import requests
15
15
  import torch
16
16
  import torch.distributed as dist
17
17
 
18
- from sglang.srt.utils import get_ip
18
+ from sglang.srt.utils import get_ip, is_npu
19
19
 
20
20
  if TYPE_CHECKING:
21
21
  from sglang.srt.managers.schedule_batch import Req
@@ -94,8 +94,12 @@ class MetadataBuffers:
94
94
  custom_mem_pool: torch.cuda.MemPool = None,
95
95
  ):
96
96
  self.custom_mem_pool = custom_mem_pool
97
- device = "cuda" if self.custom_mem_pool else "cpu"
98
-
97
+ device = "cpu"
98
+ if is_npu():
99
+ # For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
100
+ device = "npu"
101
+ elif self.custom_mem_pool:
102
+ device = "cuda"
99
103
  with (
100
104
  torch.cuda.use_mem_pool(self.custom_mem_pool)
101
105
  if self.custom_mem_pool
@@ -200,6 +204,7 @@ class MetadataBuffers:
200
204
  class TransferBackend(Enum):
201
205
  MOONCAKE = "mooncake"
202
206
  NIXL = "nixl"
207
+ ASCEND = "ascend"
203
208
  FAKE = "fake"
204
209
 
205
210
 
@@ -231,6 +236,23 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
231
236
  KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
232
237
  }
233
238
  return class_mapping.get(class_type)
239
+ elif transfer_backend == TransferBackend.ASCEND:
240
+ from sglang.srt.disaggregation.ascend import (
241
+ AscendKVBootstrapServer,
242
+ AscendKVManager,
243
+ AscendKVReceiver,
244
+ AscendKVSender,
245
+ )
246
+ from sglang.srt.disaggregation.base import KVArgs
247
+
248
+ class_mapping = {
249
+ KVClassType.KVARGS: KVArgs,
250
+ KVClassType.MANAGER: AscendKVManager,
251
+ KVClassType.SENDER: AscendKVSender,
252
+ KVClassType.RECEIVER: (AscendKVReceiver),
253
+ KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
254
+ }
255
+ return class_mapping.get(class_type)
234
256
  elif transfer_backend == TransferBackend.NIXL:
235
257
  from sglang.srt.disaggregation.base import KVArgs
236
258
  from sglang.srt.disaggregation.nixl import (
@@ -650,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
650
650
  if _is_cuda:
651
651
  assert_pkg_version(
652
652
  "sgl-kernel",
653
- "0.2.4",
653
+ "0.2.5",
654
654
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
655
655
  )
656
656
 
@@ -418,6 +418,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
418
418
 
419
419
  await _global_state.tokenizer_manager.start_profile(
420
420
  output_dir=obj.output_dir,
421
+ start_step=obj.start_step,
421
422
  num_steps=obj.num_steps,
422
423
  activities=obj.activities,
423
424
  with_stack=obj.with_stack,
@@ -1,4 +1,3 @@
1
- import base64
2
1
  import copy
3
2
  import dataclasses
4
3
  import multiprocessing
@@ -7,6 +6,7 @@ import threading
7
6
  import time
8
7
  from typing import Any, Dict, List, Optional, Tuple, Union
9
8
 
9
+ import pybase64
10
10
  import requests
11
11
  import torch
12
12
  import torch.distributed as dist
@@ -267,6 +267,10 @@ class ChatCompletionMessageContentImageURL(BaseModel):
267
267
  detail: Optional[Literal["auto", "low", "high"]] = "auto"
268
268
 
269
269
 
270
+ class ChatCompletionMessageContentVideoURL(BaseModel):
271
+ url: str
272
+
273
+
270
274
  class ChatCompletionMessageContentAudioURL(BaseModel):
271
275
  url: str
272
276
 
@@ -277,6 +281,11 @@ class ChatCompletionMessageContentImagePart(BaseModel):
277
281
  modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
278
282
 
279
283
 
284
+ class ChatCompletionMessageContentVideoPart(BaseModel):
285
+ type: Literal["video_url"]
286
+ video_url: ChatCompletionMessageContentVideoURL
287
+
288
+
280
289
  class ChatCompletionMessageContentAudioPart(BaseModel):
281
290
  type: Literal["audio_url"]
282
291
  audio_url: ChatCompletionMessageContentAudioURL
@@ -285,6 +294,7 @@ class ChatCompletionMessageContentAudioPart(BaseModel):
285
294
  ChatCompletionMessageContentPart = Union[
286
295
  ChatCompletionMessageContentTextPart,
287
296
  ChatCompletionMessageContentImagePart,
297
+ ChatCompletionMessageContentVideoPart,
288
298
  ChatCompletionMessageContentAudioPart,
289
299
  ]
290
300
 
@@ -629,6 +639,7 @@ class MessageProcessingResult:
629
639
  prompt_ids: Union[str, List[int]]
630
640
  image_data: Optional[Any]
631
641
  audio_data: Optional[Any]
642
+ video_data: Optional[Any]
632
643
  modalities: List[str]
633
644
  stop: List[str]
634
645
  tool_call_constraint: Optional[Any] = None