sglang-0.4.8.post1-py3-none-any.whl → sglang-0.4.9.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/conversation.py

@@ -59,6 +59,7 @@ class SeparatorStyle(IntEnum):
  METAMATH = auto()
  DeepSeekVL2 = auto()
  QWEN2_VL_EMBED = auto()
+ QWEN2_AUDIO = auto()
  GEMMA3 = auto()
  MPT = auto()

@@ -350,6 +351,23 @@ class Conversation:
  else:
  ret += role
  return ret
+ elif self.sep_style == SeparatorStyle.QWEN2_AUDIO:
+ ret = "" if system_prompt == "" else system_prompt + self.sep
+
+ counter = 1
+ for role, message in self.messages:
+ if message:
+ while self.audio_token in message:
+ message = message.replace(
+ self.audio_token, self.audio_token.format(idx=counter), 1
+ )
+ counter += 1
+
+ ret += role + "\n" + message + self.sep
+ else:
+ ret += role + "\n"
+
+ return ret
  else:
  raise ValueError(f"Invalid style: {self.sep_style}")

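The QWEN2_AUDIO branch above numbers each audio placeholder in a message before rendering the prompt. Below is a minimal standalone sketch of that replacement loop, using the audio_token string registered for the qwen2-audio template further down; the helper name and the placement of the counter increment are ours, not part of the package.

# Illustrative sketch only (not sglang API): number audio placeholders left to right.
audio_token = "Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"

def number_audio_tokens(message: str) -> str:
    counter = 1
    # The raw template still contains "{idx}", so each pass rewrites exactly one
    # remaining placeholder and the loop stops once all of them are numbered.
    while audio_token in message:
        message = message.replace(audio_token, audio_token.format(idx=counter), 1)
        counter += 1
    return message

print(number_audio_tokens("Please compare:\n" + audio_token + audio_token))
# -> "... Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nAudio 2: ..."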
@@ -903,6 +921,46 @@ register_conv_template(
  )
  )

+ register_conv_template(
+ Conversation(
+ name="mimo-vl",
+ system_message="You are MiMo, an AI assistant developed by Xiaomi.",
+ system_template="<|im_start|>system\n{system_message}",
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
+ sep="<|im_end|>\n",
+ sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+ stop_str=["<|im_end|>"],
+ image_token="<|vision_start|><|image_pad|><|vision_end|>",
+ )
+ )
+
+
+ register_conv_template(
+ Conversation(
+ name="qwen2-audio",
+ system_template="<|im_start|>system\n{system_message}",
+ system_message="You are a helpful assistant.",
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
+ sep="<|im_end|>\n",
+ sep_style=SeparatorStyle.QWEN2_AUDIO,
+ stop_str=["<|im_end|>"],
+ audio_token="Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n",
+ )
+ )
+
+ register_conv_template(
+ Conversation(
+ name="llama_4_vision",
+ system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+ system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+ roles=("user", "assistant"),
+ sep_style=SeparatorStyle.LLAMA4,
+ sep="",
+ stop_str="<|eot|>",
+ image_token="<|image|>",
+ )
+ )
+

  @register_conv_template_matching_function
  def match_internvl(model_path: str):
@@ -911,9 +969,11 @@ def match_internvl(model_path: str):


  @register_conv_template_matching_function
- def match_llama_3_vision(model_path: str):
+ def match_llama_vision(model_path: str):
  if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
  return "llama_3_vision"
+ if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
+ return "llama_4_vision"


  @register_conv_template_matching_function
@@ -956,6 +1016,8 @@ def match_qwen_chat_ml(model_path: str):
  return "gme-qwen2-vl"
  if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
  return "qwen2-vl"
+ if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
+ return "qwen2-audio"
  if re.search(
  r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
  model_path,
@@ -1000,3 +1062,9 @@ def match_phi_4_mm(model_path: str):
  def match_vila(model_path: str):
  if re.search(r"vila", model_path, re.IGNORECASE):
  return "chatml"
+
+
+ @register_conv_template_matching_function
+ def match_mimo_vl(model_path: str):
+ if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
+ return "mimo-vl"
sglang/srt/disaggregation/decode.py

@@ -416,6 +416,12 @@ class DecodePreallocQueue:

  return preallocated_reqs

+ @property
+ def num_tokens_pre_allocated(self):
+ return sum(
+ len(decode_req.req.fill_ids) for decode_req in self.transfer_queue.queue
+ )
+
  def _allocatable_tokens(
  self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
  ) -> int:
@@ -433,9 +439,7 @@
  else 0
  )

- available_size = self.token_to_kv_pool_allocator.available_size()
-
- allocatable_tokens = available_size - max(
+ allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
  # preserve some space for future decode
  self.num_reserved_decode_tokens
  * (
@@ -606,9 +610,21 @@ class DecodeTransferQueue:
  : decode_req.req.top_logprobs_num
  ].tolist()
  )
+
  if hasattr(decode_req.kv_receiver, "clear"):
  decode_req.kv_receiver.clear()
- transferred_reqs.append(decode_req.req)
+
+ # special handling for sampling_params.max_new_tokens == 1
+ if decode_req.req.sampling_params.max_new_tokens == 1:
+ # finish immediately
+ decode_req.req.check_finished()
+ self.scheduler.stream_output(
+ [decode_req.req], decode_req.req.return_logprob
+ )
+ self.tree_cache.cache_finished_req(decode_req.req)
+ else:
+ transferred_reqs.append(decode_req.req)
+
  indices_to_remove.add(i)
  elif poll in [
  KVPoll.Bootstrapping,
@@ -756,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
  self.last_batch_in_queue = last_batch_in_queue

  def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
- batch, _ = self.prepare_mlp_sync_batch(batch)
+ batch = self.prepare_mlp_sync_batch(batch)
  result = None
  if batch:
  result = self.run_batch(batch)
sglang/srt/disaggregation/mooncake/conn.py

@@ -185,9 +185,11 @@ class MooncakeKVManager(BaseKVManager):
  threading.Thread(
  target=self.transfer_worker, args=(queue, executor), daemon=True
  ).start()
-
- self.bootstrap_time_out = get_int_env_var(
- "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
+ # If a timeout happens on the prefill side, it means prefill instances
+ # fail to receive the KV indices from the decode instance of this request.
+ # These timeout requests should be aborted to release the tree cache.
+ self.bootstrap_timeout = get_int_env_var(
+ "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
  )
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
  self.heartbeat_failures = {}
@@ -209,6 +211,12 @@ class MooncakeKVManager(BaseKVManager):
  self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
  self.prefill_tp_size_table: Dict[str, int] = {}
  self.prefill_dp_size_table: Dict[str, int] = {}
+ # If a timeout happens on the decode side, it means decode instances
+ # fail to receive the KV Cache transfer done signal after bootstrapping.
+ # These timeout requests should be aborted to release the tree cache.
+ self.waiting_timeout = get_int_env_var(
+ "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
+ )
  else:
  raise ValueError(
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -938,7 +946,12 @@ class MooncakeKVSender(BaseKVSender):
  if self.init_time is not None:
  now = time.time()
  elapsed = now - self.init_time
- if elapsed >= self.kv_mgr.bootstrap_time_out:
+ if elapsed >= self.kv_mgr.bootstrap_timeout:
+ logger.warning_once(
+ "Some requests timed out when bootstrapping, "
+ "which means prefill instances fail to receive the KV indices from the decode instance of this request. "
+ "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+ )
  self.kv_mgr.record_failure(
  self.bootstrap_room,
  f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.Bootstrapping",
@@ -987,6 +1000,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
  self.session_id = self.kv_mgr.get_session_id()
  self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
  self.conclude_state = None
+ self.init_time = None
  self.data_parallel_rank = data_parallel_rank

  if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
@@ -1222,14 +1236,31 @@
  str(self.required_dst_info_num).encode("ascii"),
  ]
  )
+ self.init_time = time.time()

  def poll(self) -> KVPoll:
  if self.conclude_state is None:
  status = self.kv_mgr.check_status(self.bootstrap_room)
  if status in (KVPoll.Success, KVPoll.Failed):
  self.conclude_state = status
+ elif status == KVPoll.WaitingForInput:
+ if self.init_time is not None:
+ now = time.time()
+ elapsed = now - self.init_time
+ if elapsed >= self.kv_mgr.waiting_timeout:
+ logger.warning_once(
+ "Some requests fail to receive KV Cache transfer done signal after bootstrapping. "
+ "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+ )
+ self.kv_mgr.record_failure(
+ self.bootstrap_room,
+ f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput",
+ )
+ self.conclude_state = KVPoll.Failed
+ return KVPoll.Failed

  return status
+
  else:
  return self.conclude_state

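Both new timeouts default to 300 s, and both warning messages point at environment variables. A hedged sketch of relaxing them before launching the prefill and decode instances; the 600 s values are only the ones suggested by the warning text above.

import os

# Prefill side: how long to wait for KV indices from the decode instance.
os.environ["SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT"] = "600"
# Decode side: how long to wait for the KV cache transfer-done signal.
os.environ["SGLANG_DISAGGREGATION_WAITING_TIMEOUT"] = "600"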
sglang/srt/disaggregation/nixl/conn.py

@@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager):
  self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
  ):
  kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
- self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
+ self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
  logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
  if not self.kv_descs:
  raise Exception("NIXL memory registration failed for kv tensors")
@@ -168,7 +168,7 @@
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
  ):
  aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
- self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
+ self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
  logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
  if not self.aux_descs:
  raise Exception("NIXL memory registration failed for aux tensors")
@@ -215,8 +215,8 @@
  logger.debug(
  f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
  )
- src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
+ src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
  # Transfer data
  xfer_handle = self.agent.initialize_xfer(
  "WRITE",
@@ -248,8 +248,8 @@
  decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
  src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
  dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
- src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
+ src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
  # Transfer data
  xfer_handle = self.agent.initialize_xfer(
  "WRITE",
sglang/srt/disaggregation/prefill.py

@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
  batch = self.get_new_batch_prefill()

  if require_mlp_sync(self.server_args):
- batch, _ = self.prepare_mlp_sync_batch(batch)
+ batch = self.prepare_mlp_sync_batch(batch)
  self.cur_batch = batch

  if batch:
@@ -310,7 +310,7 @@
  batch = self.get_new_batch_prefill()

  if require_mlp_sync(self.server_args):
- batch, _ = self.prepare_mlp_sync_batch(batch)
+ batch = self.prepare_mlp_sync_batch(batch)
  self.cur_batch = batch
  if batch:
  result = self.run_batch(batch)
sglang/srt/disaggregation/utils.py

@@ -74,7 +74,7 @@ class ReqToMetadataIdxAllocator:
  def available_size(self):
  return len(self.free_slots)

- def alloc(self) -> List[int]:
+ def alloc(self) -> Optional[int]:
  if len(self.free_slots) == 0:
  return None

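The annotation change above reflects the existing behavior: alloc() hands out a single metadata slot index and returns None when no slot is free. A minimal stand-in illustrating the contract callers should code against; this class body is ours, not the package's.

from typing import List, Optional

class MetadataIdxAllocatorSketch:
    """Illustrative stand-in for the alloc()/free-slot contract above."""

    def __init__(self, size: int):
        self.free_slots: List[int] = list(range(size))

    def alloc(self) -> Optional[int]:
        if len(self.free_slots) == 0:
            return None  # callers must check for None before using the index
        return self.free_slots.pop(0)

allocator = MetadataIdxAllocatorSketch(size=1)
assert allocator.alloc() == 0 and allocator.alloc() is None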
sglang/srt/distributed/parallel_state.py

@@ -42,8 +42,10 @@ from torch.distributed import Backend, ProcessGroup
  from sglang.srt.utils import (
  direct_register_custom_op,
  get_bool_env_var,
+ get_int_env_var,
  is_cuda_alike,
  is_npu,
+ is_shm_available,
  supports_custom_op,
  )

@@ -222,6 +224,7 @@ class GroupCoordinator:
  self.local_rank = local_rank
  self.device_group = None
  self.cpu_group = None
+ self.local_size = get_int_env_var("LOCAL_SIZE", 0)

  for ranks in group_ranks:
  device_group = torch.distributed.new_group(
@@ -440,9 +443,12 @@
  return input_

  if input_.is_cpu:
- import intel_extension_for_pytorch as ipex
-
- ipex.distributed.all_reduce(input_, group=self.device_group)
+ if is_shm_available(input_.dtype, self.world_size, self.local_size):
+ torch.ops.sgl_kernel.shm_allreduce(
+ input_, torch.distributed.ReduceOp.SUM
+ )
+ else:
+ torch.distributed.all_reduce(input_, group=self.device_group)
  return input_

  if not supports_custom_op():
@@ -570,6 +576,16 @@
  output_tensor = torch.empty(
  output_size, dtype=input_.dtype, device=input_.device
  )
+
+ if input_.is_cpu:
+ if is_shm_available(input_.dtype, self.world_size, self.local_size):
+ return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+ else:
+ torch.distributed.all_gather_into_tensor(
+ output_tensor, input_, group=self.device_group
+ )
+ return output_tensor
+
  # All-gather.
  self.all_gather_into_tensor(output_tensor, input_)
  # Reshape
@@ -683,18 +699,25 @@
  )

  # Serialize object to tensor and get the size as well
- object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+ object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+ device=torch.cuda.current_device()
+ )

  size_tensor = torch.tensor(
- [object_tensor.numel()], dtype=torch.long, device="cpu"
+ [object_tensor.numel()],
+ dtype=torch.long,
+ device=torch.cuda.current_device(),
  )

  # Send object size
-
- torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+ torch.distributed.send(
+ size_tensor, dst=self.ranks[dst], group=self.device_group
+ )

  # Send object
- torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
+ torch.distributed.send(
+ object_tensor, dst=self.ranks[dst], group=self.device_group
+ )

  return None

@@ -708,29 +731,31 @@
  src != self.rank_in_group
  ), "Invalid source rank. Source rank is the same as the current rank."

- size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+ size_tensor = torch.empty(
+ 1, dtype=torch.long, device=torch.cuda.current_device()
+ )

  # Receive object size
  rank_size = torch.distributed.recv(
- size_tensor, src=self.ranks[src], group=self.cpu_group
+ size_tensor, src=self.ranks[src], group=self.device_group
  )

  # Tensor to receive serialized objects into.
  object_tensor = torch.empty(  # type: ignore[call-overload]
  size_tensor.item(),  # type: ignore[arg-type]
  dtype=torch.uint8,
- device="cpu",
+ device=torch.cuda.current_device(),
  )

  rank_object = torch.distributed.recv(
- object_tensor, src=self.ranks[src], group=self.cpu_group
+ object_tensor, src=self.ranks[src], group=self.device_group
  )

  assert (
  rank_object == rank_size
  ), "Received object sender rank does not match the size sender rank."

- obj = pickle.loads(object_tensor.numpy().tobytes())
+ obj = pickle.loads(object_tensor.cpu().numpy().tobytes())

  return obj

@@ -841,14 +866,16 @@
  dst = (self.rank_in_group + 1) % self.world_size
  assert dst < self.world_size, f"Invalid dst rank ({dst})"

- metadata_list: List[Tuple[Any, Any]] = []
  assert isinstance(
  tensor_dict, dict
  ), f"Expecting a dictionary, got {type(tensor_dict)}"
  metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
- # `metadata_list` lives in CPU memory.
- # `send_object_list` has serialization & deserialization,
- # all happening on CPU. Therefore, we can use the CPU group.
+ # Note: While switching to Device-to-Device (D2D) would introduce an extra
+ # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
+ # show better overall transmission performance with D2D due to:
+ # 1. Superior D2D transfer bandwidth
+ # 2. Ability to overlap send and recv operations
+ # Thus the net performance gain justifies this approach.
  self.send_object(metadata_list, dst=dst)
  for tensor in tensor_list:
  if tensor.numel() == 0:
sglang/srt/entrypoints/EngineBase.py

@@ -48,6 +48,14 @@ class EngineBase(ABC):
  """Update model weights with in-memory tensor data."""
  pass

+ def load_lora_adapter(self, lora_name: str, lora_path: str):
+ """Load a new LoRA adapter without re-launching the engine."""
+ pass
+
+ def unload_lora_adapter(self, lora_name: str):
+ """Unload a LoRA adapter without re-launching the engine."""
+ pass
+
  @abstractmethod
  def release_memory_occupation(self):
  """Release GPU memory occupation temporarily."""
sglang/srt/entrypoints/engine.py

@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
  GetWeightsByNameReqInput,
  ImageDataItem,
  InitWeightsUpdateGroupReqInput,
+ LoadLoRAAdapterReqInput,
  ReleaseMemoryOccupationReqInput,
  ResumeMemoryOccupationReqInput,
  RpcReqInput,
  RpcReqOutput,
+ UnloadLoRAAdapterReqInput,
  UpdateWeightFromDiskReqInput,
  UpdateWeightsFromDistributedReqInput,
  UpdateWeightsFromTensorReqInput,
@@ -416,12 +418,21 @@ class Engine(EngineBase):
  self.tokenizer_manager.init_weights_update_group(obj, None)
  )

- def update_weights_from_distributed(self, name: str, dtype, shape):
+ def update_weights_from_distributed(
+ self,
+ names: list[str],
+ dtypes: list[str],
+ shapes: list[list[int]],
+ group_name: str = "weight_update_group",
+ flush_cache: bool = True,
+ ):
  """Update weights from distributed source."""
  obj = UpdateWeightsFromDistributedReqInput(
- name=name,
- dtype=dtype,
- shape=shape,
+ names=names,
+ dtypes=dtypes,
+ shapes=shapes,
+ group_name=group_name,
+ flush_cache=flush_cache,
  )
  loop = asyncio.get_event_loop()
  return loop.run_until_complete(
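update_weights_from_distributed now takes batched lists plus a group name and a flush flag instead of a single name/dtype/shape. A hedged usage sketch: the model path, tensor name, and shape are placeholders, and a weight-update group must already have been created via init_weights_update_group on both ends.

import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
# ... init_weights_update_group(...) is assumed to have been called first ...
engine.update_weights_from_distributed(
    names=["model.layers.0.self_attn.q_proj.weight"],
    dtypes=["bfloat16"],
    shapes=[[4096, 4096]],
    group_name="weight_update_group",  # default shown in the new signature
    flush_cache=True,                  # flush cached state after the update
)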
@@ -478,6 +489,29 @@
  self.tokenizer_manager.get_weights_by_name(obj, None)
  )

+ def load_lora_adapter(self, lora_name: str, lora_path: str):
+ """Load a new LoRA adapter without re-launching the engine."""
+
+ obj = LoadLoRAAdapterReqInput(
+ lora_name=lora_name,
+ lora_path=lora_path,
+ )
+
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.load_lora_adapter(obj, None)
+ )
+
+ def unload_lora_adapter(self, lora_name: str):
+ """Unload a LoRA adapter without re-launching the engine."""
+
+ obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
+
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.unload_lora_adapter(obj, None)
+ )
+
  def release_memory_occupation(self, tags: Optional[List[str]] = None):
  obj = ReleaseMemoryOccupationReqInput(tags=tags)
  loop = asyncio.get_event_loop()
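These wrappers forward LoadLoRAAdapterReqInput / UnloadLoRAAdapterReqInput to the tokenizer manager, so adapters can be swapped at runtime. A hedged usage sketch; the model path, adapter name, and adapter path are placeholders, and the engine is assumed to be launched with LoRA support configured.

import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # plus LoRA launch args
engine.load_lora_adapter(lora_name="my-adapter", lora_path="/path/to/adapter")
# ... serve requests that reference "my-adapter" ...
engine.unload_lora_adapter(lora_name="my-adapter")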
@@ -608,7 +642,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if server_args.attention_backend == "flashinfer":
  assert_pkg_version(
  "flashinfer_python",
- "0.2.6.post1",
+ "0.2.7.post1",
  "Please uninstall the old version and "
  "reinstall the latest version by following the instructions "
  "at https://docs.flashinfer.ai/installation.html.",
@@ -616,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if _is_cuda:
  assert_pkg_version(
  "sgl-kernel",
- "0.1.9",
+ "0.2.4",
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
  )