sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/conversation.py
CHANGED
@@ -59,6 +59,7 @@ class SeparatorStyle(IntEnum):
     METAMATH = auto()
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
+    QWEN2_AUDIO = auto()
     GEMMA3 = auto()
     MPT = auto()
 
@@ -350,6 +351,23 @@ class Conversation:
             else:
                 ret += role
             return ret
+        elif self.sep_style == SeparatorStyle.QWEN2_AUDIO:
+            ret = "" if system_prompt == "" else system_prompt + self.sep
+
+            counter = 1
+            for role, message in self.messages:
+                if message:
+                    while self.audio_token in message:
+                        message = message.replace(
+                            self.audio_token, self.audio_token.format(idx=counter), 1
+                        )
+                        counter += 1
+
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+
+            return ret
         else:
            raise ValueError(f"Invalid style: {self.sep_style}")
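The QWEN2_AUDIO branch numbers audio placeholders by repeatedly replacing the first remaining occurrence of the literal audio-token template. A minimal standalone sketch of that loop, using the audio_token string registered for the qwen2-audio template further below:

audio_token = "Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"

def number_audio_tokens(message: str, counter: int = 1) -> str:
    # Each pass rewrites the first remaining literal "{idx}" template into its
    # numbered form, so two audio clips become "Audio 1: ..." and "Audio 2: ...".
    while audio_token in message:
        message = message.replace(audio_token, audio_token.format(idx=counter), 1)
        counter += 1
    return message

print(number_audio_tokens(audio_token + "Compare this clip with " + audio_token))
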
@@ -903,6 +921,46 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="mimo-vl",
+        system_message="You are MiMo, an AI assistant developed by Xiaomi.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
+
+
+register_conv_template(
+    Conversation(
+        name="qwen2-audio",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.QWEN2_AUDIO,
+        stop_str=["<|im_end|>"],
+        audio_token="Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n",
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="llama_4_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA4,
+        sep="",
+        stop_str="<|eot|>",
+        image_token="<|image|>",
+    )
+)
+
 
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
@@ -911,9 +969,11 @@ def match_internvl(model_path: str):
 
 
 @register_conv_template_matching_function
-def
+def match_llama_vision(model_path: str):
     if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
         return "llama_3_vision"
+    if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
+        return "llama_4_vision"
 
 
 @register_conv_template_matching_function
@@ -956,6 +1016,8 @@ def match_qwen_chat_ml(model_path: str):
         return "gme-qwen2-vl"
     if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
         return "qwen2-vl"
+    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
+        return "qwen2-audio"
     if re.search(
         r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
         model_path,
@@ -1000,3 +1062,9 @@ def match_phi_4_mm(model_path: str):
 def match_vila(model_path: str):
     if re.search(r"vila", model_path, re.IGNORECASE):
         return "chatml"
+
+
+@register_conv_template_matching_function
+def match_mimo_vl(model_path: str):
+    if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
+        return "mimo-vl"
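All of the match_* functions above share one mechanism: a decorator appends each matcher to a registry, and the first matcher that returns a template name selects the chat template for a model path. A simplified, self-contained sketch of that pattern (the fallback value here is illustrative, not the real default):

import re

_MATCHERS = []

def register_conv_template_matching_function(func):
    _MATCHERS.append(func)
    return func

@register_conv_template_matching_function
def match_mimo_vl(model_path: str):
    if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
        return "mimo-vl"

def resolve_template(model_path: str, default: str = "chatml") -> str:
    # First matcher that returns a non-None name wins.
    for matcher in _MATCHERS:
        name = matcher(model_path)
        if name is not None:
            return name
    return default

print(resolve_template("XiaomiMiMo/MiMo-VL-7B-RL"))  # -> "mimo-vl"
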
sglang/srt/disaggregation/decode.py
CHANGED
@@ -416,6 +416,12 @@ class DecodePreallocQueue:
 
         return preallocated_reqs
 
+    @property
+    def num_tokens_pre_allocated(self):
+        return sum(
+            len(decode_req.req.fill_ids) for decode_req in self.transfer_queue.queue
+        )
+
     def _allocatable_tokens(
         self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
     ) -> int:
@@ -433,9 +439,7 @@ class DecodePreallocQueue:
             else 0
         )
 
-
-
-        allocatable_tokens = available_size - max(
+        allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
             # preserve some space for future decode
             self.num_reserved_decode_tokens
             * (
@@ -606,9 +610,21 @@ class DecodeTransferQueue:
                         : decode_req.req.top_logprobs_num
                     ].tolist()
                 )
+
                 if hasattr(decode_req.kv_receiver, "clear"):
                     decode_req.kv_receiver.clear()
-
+
+                # special handling for sampling_params.max_new_tokens == 1
+                if decode_req.req.sampling_params.max_new_tokens == 1:
+                    # finish immediately
+                    decode_req.req.check_finished()
+                    self.scheduler.stream_output(
+                        [decode_req.req], decode_req.req.return_logprob
+                    )
+                    self.tree_cache.cache_finished_req(decode_req.req)
+                else:
+                    transferred_reqs.append(decode_req.req)
+
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
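The new branch short-circuits requests that only need a single output token: prefill already produced that token, so the request can be finished and streamed as soon as its KV transfer lands instead of entering a decode batch. A hedged sketch of just that routing decision (names simplified from the hunk above):

def route_transferred_request(decode_req, scheduler, tree_cache, transferred_reqs):
    # Sketch only: mirrors the max_new_tokens == 1 fast path above.
    if decode_req.req.sampling_params.max_new_tokens == 1:
        decode_req.req.check_finished()  # mark finished immediately
        scheduler.stream_output([decode_req.req], decode_req.req.return_logprob)
        tree_cache.cache_finished_req(decode_req.req)  # cache/release its KV
    else:
        transferred_reqs.append(decode_req.req)  # proceed to decode as usual
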
@@ -756,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch_in_queue = last_batch_in_queue
 
     def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
-        batch
+        batch = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)
sglang/srt/disaggregation/mooncake/conn.py
CHANGED
@@ -185,9 +185,11 @@ class MooncakeKVManager(BaseKVManager):
             threading.Thread(
                 target=self.transfer_worker, args=(queue, executor), daemon=True
             ).start()
-
-
-
+            # If a timeout happens on the prefill side, it means prefill instances
+            # fail to receive the KV indices from the decode instance of this request.
+            # These timeout requests should be aborted to release the tree cache.
+            self.bootstrap_timeout = get_int_env_var(
+                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
@@ -209,6 +211,12 @@ class MooncakeKVManager(BaseKVManager):
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
             self.prefill_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
+            # If a timeout happens on the decode side, it means decode instances
+            # fail to receive the KV Cache transfer done signal after bootstrapping.
+            # These timeout requests should be aborted to release the tree cache.
+            self.waiting_timeout = get_int_env_var(
+                "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
+            )
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -938,7 +946,12 @@ class MooncakeKVSender(BaseKVSender):
         if self.init_time is not None:
             now = time.time()
             elapsed = now - self.init_time
-            if elapsed >= self.kv_mgr.
+            if elapsed >= self.kv_mgr.bootstrap_timeout:
+                logger.warning_once(
+                    "Some requests timed out when bootstrapping, "
+                    "which means prefill instances fail to receive the KV indices from the decode instance of this request. "
+                    "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+                )
                 self.kv_mgr.record_failure(
                     self.bootstrap_room,
                     f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.Bootstrapping",
@@ -987,6 +1000,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
         self.conclude_state = None
+        self.init_time = None
         self.data_parallel_rank = data_parallel_rank
 
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
@@ -1222,14 +1236,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 str(self.required_dst_info_num).encode("ascii"),
             ]
         )
+        self.init_time = time.time()
 
     def poll(self) -> KVPoll:
         if self.conclude_state is None:
             status = self.kv_mgr.check_status(self.bootstrap_room)
             if status in (KVPoll.Success, KVPoll.Failed):
                 self.conclude_state = status
+            elif status == KVPoll.WaitingForInput:
+                if self.init_time is not None:
+                    now = time.time()
+                    elapsed = now - self.init_time
+                    if elapsed >= self.kv_mgr.waiting_timeout:
+                        logger.warning_once(
+                            "Some requests fail to receive KV Cache transfer done signal after bootstrapping. "
+                            "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+                        )
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput",
+                        )
+                        self.conclude_state = KVPoll.Failed
+                        return KVPoll.Failed
 
             return status
+
         else:
             return self.conclude_state
 
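Both new timeouts follow the same pattern: record an init_time when a request enters the state, then abort the request once the elapsed time exceeds an environment-configurable limit (300 s by default). A minimal sketch of that pattern; get_int_env_var is reimplemented here under the assumption that it simply parses the variable as an int:

import os
import time

def get_int_env_var(name: str, default: int) -> int:
    # Assumed behavior of sglang.srt.utils.get_int_env_var.
    return int(os.environ.get(name, default))

# Relaxable via, e.g., export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600
waiting_timeout = get_int_env_var("SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300)

init_time = time.time()  # set when the request enters KVPoll.WaitingForInput
# ... later, inside the poll loop:
if time.time() - init_time >= waiting_timeout:
    # The real code calls kv_mgr.record_failure(...) and marks the request
    # KVPoll.Failed so its tree-cache entries can be released.
    raise TimeoutError("KV transfer did not complete in time")
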
sglang/srt/disaggregation/nixl/conn.py
CHANGED
@@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
         logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
@@ -168,7 +168,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
         logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
@@ -215,8 +215,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -248,8 +248,8 @@ class NixlKVManager(CommonKVManager):
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
         src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
         dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
-        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()
 
         if require_mlp_sync(self.server_args):
-            batch
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
 
         if batch:
@@ -310,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
         batch = self.get_new_batch_prefill()
 
         if require_mlp_sync(self.server_args):
-            batch
+            batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
         if batch:
             result = self.run_batch(batch)
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -42,8 +42,10 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
+    get_int_env_var,
     is_cuda_alike,
     is_npu,
+    is_shm_available,
     supports_custom_op,
 )
 
@@ -222,6 +224,7 @@ class GroupCoordinator:
         self.local_rank = local_rank
         self.device_group = None
         self.cpu_group = None
+        self.local_size = get_int_env_var("LOCAL_SIZE", 0)
 
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -440,9 +443,12 @@ class GroupCoordinator:
             return input_
 
         if input_.is_cpu:
-
-
-
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                torch.ops.sgl_kernel.shm_allreduce(
+                    input_, torch.distributed.ReduceOp.SUM
+                )
+            else:
+                torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
         if not supports_custom_op():
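On CPU, the all-reduce (and the all-gather in the next hunk) now prefers a shared-memory kernel from sgl-kernel when every rank lives on one node, falling back to the regular process-group collective otherwise. A sketch of the dispatch; it assumes sgl-kernel is installed (providing torch.ops.sgl_kernel.shm_allreduce) and a process group is already initialized:

import torch
import torch.distributed as dist

from sglang.srt.utils import is_shm_available  # same helper the diff imports

def cpu_all_reduce(input_: torch.Tensor, group, world_size: int, local_size: int):
    if is_shm_available(input_.dtype, world_size, local_size):
        # Intra-node shared-memory all-reduce from sgl-kernel, in place.
        torch.ops.sgl_kernel.shm_allreduce(input_, dist.ReduceOp.SUM)
    else:
        # Fall back to the standard collective on the device group.
        dist.all_reduce(input_, group=group)
    return input_
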
@@ -570,6 +576,16 @@ class GroupCoordinator:
         output_tensor = torch.empty(
             output_size, dtype=input_.dtype, device=input_.device
         )
+
+        if input_.is_cpu:
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
+            return output_tensor
+
         # All-gather.
         self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
@@ -683,18 +699,25 @@ class GroupCoordinator:
         )
 
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+            device=torch.cuda.current_device()
+        )
 
         size_tensor = torch.tensor(
-            [object_tensor.numel()],
+            [object_tensor.numel()],
+            dtype=torch.long,
+            device=torch.cuda.current_device(),
         )
 
         # Send object size
-
-
+        torch.distributed.send(
+            size_tensor, dst=self.ranks[dst], group=self.device_group
+        )
 
         # Send object
-        torch.distributed.send(
+        torch.distributed.send(
+            object_tensor, dst=self.ranks[dst], group=self.device_group
+        )
 
         return None
 
@@ -708,29 +731,31 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
 
-        size_tensor = torch.empty(
+        size_tensor = torch.empty(
+            1, dtype=torch.long, device=torch.cuda.current_device()
+        )
 
         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.
+            size_tensor, src=self.ranks[src], group=self.device_group
         )
 
         # Tensor to receive serialized objects into.
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=
+            device=torch.cuda.current_device(),
         )
 
         rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.
+            object_tensor, src=self.ranks[src], group=self.device_group
        )
 
        assert (
            rank_object == rank_size
        ), "Received object sender rank does not match the size sender rank."
 
-        obj = pickle.loads(object_tensor.numpy().tobytes())
+        obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
 
         return obj
 
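send_object/recv_object now stage the pickled payload in CUDA memory and move it over the device group (e.g. NCCL) rather than the CPU group, which pairs with the D2D rationale documented in the next hunk. A condensed sketch of the round trip:

import pickle

import torch
import torch.distributed as dist

def send_object(obj, dst: int, device_group) -> None:
    payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda()
    size = torch.tensor([payload.numel()], dtype=torch.long, device="cuda")
    dist.send(size, dst=dst, group=device_group)     # send the length first
    dist.send(payload, dst=dst, group=device_group)  # then the payload itself

def recv_object(src: int, device_group):
    size = torch.empty(1, dtype=torch.long, device="cuda")
    dist.recv(size, src=src, group=device_group)
    payload = torch.empty(int(size.item()), dtype=torch.uint8, device="cuda")
    dist.recv(payload, src=src, group=device_group)
    # Deserialization requires copying back to host memory, as in the diff.
    return pickle.loads(payload.cpu().numpy().tobytes())
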
@@ -841,14 +866,16 @@ class GroupCoordinator:
         dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
-        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict, dict
         ), f"Expecting a dictionary, got {type(tensor_dict)}"
         metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
-        #
-        #
-        #
+        # Note: While switching to Device-to-Device (D2D) would introduce an extra
+        # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
+        # show better overall transmission performance with D2D due to:
+        # 1. Superior D2D transfer bandwidth
+        # 2. Ability to overlap send and recv operations
+        # Thus the net performance gain justifies this approach.
         self.send_object(metadata_list, dst=dst)
         for tensor in tensor_list:
             if tensor.numel() == 0:
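For context, send_tensor_dict first ships a small metadata list (via send_object above) and then the raw tensors one by one; _split_tensor_dict produces that pair. A simplified sketch of the split (the real helper also wraps dtype/shape in a metadata class and tolerates non-tensor values):

import torch

def split_tensor_dict(tensor_dict):
    # Returns (metadata, tensors): metadata describes each entry so the
    # receiver can pre-allocate buffers before the tensor sends arrive.
    metadata, tensors = [], []
    for key, value in tensor_dict.items():
        assert isinstance(value, torch.Tensor), f"unexpected value for {key}"
        metadata.append((key, value.dtype, tuple(value.shape)))
        tensors.append(value)
    return metadata, tensors
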
sglang/srt/entrypoints/EngineBase.py
CHANGED
@@ -48,6 +48,14 @@ class EngineBase(ABC):
         """Update model weights with in-memory tensor data."""
         pass
 
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+        pass
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+        pass
+
     @abstractmethod
     def release_memory_occupation(self):
         """Release GPU memory occupation temporarily."""
sglang/srt/entrypoints/engine.py
CHANGED
@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     ImageDataItem,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -416,12 +418,21 @@ class Engine(EngineBase):
             self.tokenizer_manager.init_weights_update_group(obj, None)
         )
 
-    def update_weights_from_distributed(
+    def update_weights_from_distributed(
+        self,
+        names: list[str],
+        dtypes: list[str],
+        shapes: list[list[int]],
+        group_name: str = "weight_update_group",
+        flush_cache: bool = True,
+    ):
         """Update weights from distributed source."""
         obj = UpdateWeightsFromDistributedReqInput(
-
-
-
+            names=names,
+            dtypes=dtypes,
+            shapes=shapes,
+            group_name=group_name,
+            flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
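The distributed weight-update entry point now takes parallel lists of tensor names, dtypes, and shapes, plus a process-group name and a cache-flush flag. A hedged call sketch, given an existing Engine instance; the tensor name, dtype, and shape below are placeholders, not values from this diff:

engine.update_weights_from_distributed(
    names=["model.embed_tokens.weight"],  # placeholder tensor name
    dtypes=["bfloat16"],                  # one dtype per name
    shapes=[[128256, 4096]],              # one shape per name
    group_name="weight_update_group",     # default from the new signature
    flush_cache=True,                     # invalidate cached KV after the update
)
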
@@ -478,6 +489,29 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+
+        obj = LoadLoRAAdapterReqInput(
+            lora_name=lora_name,
+            lora_path=lora_path,
+        )
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.load_lora_adapter(obj, None)
+        )
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+
+        obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.unload_lora_adapter(obj, None)
+        )
+
     def release_memory_occupation(self, tags: Optional[List[str]] = None):
         obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
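These wrappers make dynamic LoRA management available on the offline Engine as well as the HTTP server. A hedged usage sketch; only load_lora_adapter and unload_lora_adapter come from this diff, while the constructor arguments and the per-request lora_path selector are assumed from existing sglang usage:

import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # assumed setup

engine.load_lora_adapter(lora_name="my-adapter", lora_path="/path/to/adapter")
out = engine.generate("Hello", lora_path="my-adapter")  # assumed selector kwarg
engine.unload_lora_adapter(lora_name="my-adapter")
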
@@ -608,7 +642,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.7.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -616,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.
+            "0.2.4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
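Both hunks raise the minimum pinned dependency versions (flashinfer_python to 0.2.7.post1, sgl-kernel to 0.2.4) through assert_pkg_version. A hedged sketch of what such a check presumably does, using importlib.metadata; this is an illustration, not the actual sglang implementation:

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

def assert_pkg_version(pkg: str, min_version: str, message: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {message}")
    if Version(installed) < Version(min_version):
        raise RuntimeError(f"{pkg}=={installed} < {min_version}. {message}")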