sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/connector/serde/safe_serde.py
CHANGED
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):
 
 class SafeDeserializer(Deserializer):
 
-    def __init__(self
-
+    def __init__(self):
+        # TODO: dtype options
+        super().__init__(torch.float32)
 
     def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
-        return load(bytes(b))["tensor_bytes"]
+        return load(bytes(b))["tensor_bytes"]
 
     def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
         return self.from_bytes_normal(b)
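The upshot of this change is that SafeDeserializer no longer takes a dtype argument and always initializes with torch.float32 (the TODO marks dtype options as future work). For orientation, a minimal sketch of the safetensors round trip this serde pair performs, keyed on the "tensor_bytes" name visible in the diff; the helper names and the save-side call are illustrative assumptions, not sglang APIs:

import torch
from safetensors.torch import load, save

def tensor_to_bytes(t: torch.Tensor) -> bytes:
    # Serialize a single tensor under the key the deserializer reads back.
    return save({"tensor_bytes": t})

def tensor_from_bytes(b: bytes) -> torch.Tensor:
    # Mirrors SafeDeserializer.from_bytes_normal in the hunk above.
    return load(bytes(b))["tensor_bytes"]

restored = tensor_from_bytes(tensor_to_bytes(torch.ones(4, dtype=torch.float32)))
assert torch.equal(restored, torch.ones(4))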
sglang/srt/conversation.py
CHANGED
@@ -26,6 +26,8 @@ Key components:
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
+import json
+import os
 import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -959,16 +961,42 @@ register_conv_template(
 )
 
 
+MODEL_TYPE_TO_TEMPLATE = {
+    "internvl_chat": "internvl-2-5",
+    "deepseek_vl_v2": "deepseek-vl2",
+    "multi_modality": "janus-pro",
+    "phi4mm": "phi-4-mm",
+    "minicpmv": "minicpmv",
+    "minicpmo": "minicpmo",
+}
+
+
+def get_model_type(model_path: str) -> Optional[str]:
+    config_path = os.path.join(model_path, "config.json")
+    if not os.path.exists(config_path):
+        return None
+    try:
+        with open(config_path, "r", encoding="utf-8") as f:
+            config = json.load(f)
+        return config.get("model_type")
+    except (IOError, json.JSONDecodeError):
+        return None
+
+
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl", model_path, re.IGNORECASE):
         return "internvl-2-5"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)
 
 
 @register_conv_template_matching_function
 def match_deepseek_janus_pro(model_path: str):
     if re.search(r"janus", model_path, re.IGNORECASE):
         return "janus-pro"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)
 
 
 @register_conv_template_matching_function
@@ -981,6 +1009,8 @@ def match_vicuna(model_path: str):
 def match_deepseek_vl(model_path: str):
     if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
         return "deepseek-vl2"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)
 
 
 @register_conv_template_matching_function
@@ -994,14 +1024,17 @@ def match_qwen_chat_ml(model_path: str):
 
 
 @register_conv_template_matching_function
-def
-
-
-
-
+def match_minicpm(model_path: str):
+    match = re.search(r"minicpm-(v|o)", model_path, re.IGNORECASE)
+    if match:
+        return f"minicpm{match.group(1).lower()}"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)
 
 
 @register_conv_template_matching_function
 def match_phi_4_mm(model_path: str):
     if "phi-4-multimodal" in model_path.lower():
         return "phi-4-mm"
+    model_type = get_model_type(model_path)
+    return MODEL_TYPE_TO_TEMPLATE.get(model_type)
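Net effect: the conversation-template matchers no longer rely solely on the model path name; when the regex misses, they fall back to the model_type field in the checkpoint's config.json and map it through MODEL_TYPE_TO_TEMPLATE. A self-contained sketch of that fallback path (a simplified reimplementation for illustration, not the sglang registration machinery):

import json
import os
import tempfile

MODEL_TYPE_TO_TEMPLATE = {"minicpmv": "minicpmv", "phi4mm": "phi-4-mm"}

def resolve_template(model_path: str):
    # Fallback: read model_type from config.json, then map it to a template name.
    config_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_path):
        return None
    with open(config_path, "r", encoding="utf-8") as f:
        model_type = json.load(f).get("model_type")
    return MODEL_TYPE_TO_TEMPLATE.get(model_type)

with tempfile.TemporaryDirectory() as local_ckpt:  # directory name carries no model hint
    with open(os.path.join(local_ckpt, "config.json"), "w") as f:
        json.dump({"model_type": "minicpmv"}, f)
    assert resolve_template(local_ckpt) == "minicpmv"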
sglang/srt/disaggregation/ascend/conn.py
CHANGED
@@ -1,6 +1,12 @@
+import concurrent.futures
 import logging
+from typing import List, Tuple
+
+import numpy as np
+import numpy.typing as npt
 
 from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
+from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.mooncake.conn import (
     MooncakeKVBootstrapServer,
     MooncakeKVManager,
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         )
 
+    def send_kvcache(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        # Group by indices
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        layers_params = [
+            (
+                self.kv_args.kv_data_ptrs[layer_id],
+                dst_kv_ptrs[layer_id],
+                self.kv_args.kv_item_lens[layer_id],
+            )
+            for layer_id in range(num_layers)
+        ]
+
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
+            transfer_blocks = []
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks
+
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)
+
+        return 0
+
 
 class AscendKVSender(MooncakeKVSender):
     pass
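The new AscendKVManager.send_kvcache leans on two ideas: contiguous page indices are grouped so each run becomes a single transfer block, and per-layer transfers are fanned out to a thread pool when the custom memory pool is enabled (otherwise all layers are batched into one transfer call). A standalone sketch of the grouping step, written from scratch here for illustration under the assumption that it matches what group_concurrent_contiguous does:

import numpy as np

def group_contiguous_runs(src: np.ndarray, dst: np.ndarray):
    # Split the paired index arrays wherever either side stops being contiguous,
    # so each run can be sent as one (src_addr, dst_addr, length) block.
    src_groups, dst_groups = [], []
    start = 0
    for i in range(1, len(src)):
        if src[i] != src[i - 1] + 1 or dst[i] != dst[i - 1] + 1:
            src_groups.append(src[start:i])
            dst_groups.append(dst[start:i])
            start = i
    if len(src):
        src_groups.append(src[start:])
        dst_groups.append(dst[start:])
    return src_groups, dst_groups

src = np.array([3, 4, 5, 9, 10], dtype=np.int32)
dst = np.array([0, 1, 2, 7, 8], dtype=np.int32)
src_runs, dst_runs = group_contiguous_runs(src, dst)
# Two runs: (3,4,5)->(0,1,2) and (9,10)->(7,8), i.e. two transfer blocks
# instead of five single-page transfers.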
sglang/srt/disaggregation/launch_lb.py
CHANGED
@@ -6,7 +6,6 @@ from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
 
 @dataclasses.dataclass
 class LBArgs:
-    rust_lb: bool = False
     host: str = "0.0.0.0"
     port: int = 8000
     policy: str = "random"
@@ -17,11 +16,6 @@ class LBArgs:
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument(
-            "--rust-lb",
-            action="store_true",
-            help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
-        )
         parser.add_argument(
             "--host",
             type=str,
@@ -92,7 +86,6 @@ class LBArgs:
         ]
 
         return cls(
-            rust_lb=args.rust_lb,
             host=args.host,
             port=args.port,
             policy=args.policy,
@@ -102,12 +95,6 @@ class LBArgs:
             timeout=args.timeout,
         )
 
-    def __post_init__(self):
-        if not self.rust_lb:
-            assert (
-                self.policy == "random"
-            ), "Only random policy is supported for Python load balancer"
-
 
 def main():
     parser = argparse.ArgumentParser(
sglang/srt/disaggregation/mini_lb.py
CHANGED
@@ -7,6 +7,7 @@ import dataclasses
 import logging
 import random
 import urllib
+from http import HTTPStatus
 from itertools import chain
 from typing import List, Optional
 
@@ -262,14 +263,38 @@ async def get_server_info():
 
 @app.get("/get_model_info")
 async def get_model_info():
-
-
-
-
-
-
-
-
+    global load_balancer
+
+    if not load_balancer or not load_balancer.prefill_servers:
+        raise HTTPException(
+            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+            detail="There is no server registered",
+        )
+
+    target_server_url = load_balancer.prefill_servers[0]
+    endpoint_url = f"{target_server_url}/get_model_info"
+
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.get(endpoint_url) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    raise HTTPException(
+                        status_code=HTTPStatus.BAD_GATEWAY,
+                        detail=(
+                            f"Failed to get model info from {target_server_url}"
+                            f"Status: {response.status}, Response: {error_text}"
+                        ),
+                    )
+
+                model_info_json = await response.json()
+                return ORJSONResponse(content=model_info_json)
+
+        except aiohttp.ClientError as e:
+            raise HTTPException(
+                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+                detail=f"Failed to get model info from backend",
+            )
 
 
 @app.post("/generate")
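With this change the Python mini load balancer answers /get_model_info itself by proxying the first registered prefill server and relaying its JSON payload (502 on a bad upstream reply, 503 when no server is registered or the backend is unreachable). A client-side sketch, assuming the balancer listens on its default port 8000 from LBArgs above and that the response shape is whatever the prefill server returns:

import asyncio
import aiohttp

async def fetch_model_info(lb_url: str = "http://127.0.0.1:8000"):
    # The balancer forwards this to its first prefill server and returns the payload.
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{lb_url}/get_model_info") as resp:
            resp.raise_for_status()
            return await resp.json()

# info = asyncio.run(fetch_model_info())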
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin:
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             if self.enable_overlap:
                 # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                 self.chunked_req.tmp_end_idx = min(
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -52,6 +52,8 @@ from sglang.srt.utils import (
 
 _is_npu = is_npu()
 
+IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
+
 
 @dataclass
 class GraphCaptureContext:
@@ -223,10 +225,12 @@ class GroupCoordinator:
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
+        # Set group info
         group_name = group_name or "anonymous"
         self.unique_name = _get_unique_name(group_name)
         _register_group(self)
 
+        # Set rank info
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
         self.device_group = None
@@ -250,14 +254,16 @@ class GroupCoordinator:
         assert self.cpu_group is not None
         assert self.device_group is not None
 
+        device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
         if is_cuda_alike():
-            self.device = torch.device(f"cuda:{local_rank}")
+            self.device = torch.device(f"cuda:{device_id}")
         elif _is_npu:
-            self.device = torch.device(f"npu:{local_rank}")
+            self.device = torch.device(f"npu:{device_id}")
         else:
             self.device = torch.device("cpu")
         self.device_module = torch.get_device_module(self.device)
 
+        # Import communicators
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
@@ -270,6 +276,9 @@ class GroupCoordinator:
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce,
        )
+        from sglang.srt.distributed.device_communicators.pymscclpp import (
+            PyMscclppCommunicator,
+        )
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
@@ -287,10 +296,6 @@ class GroupCoordinator:
             device=self.device,
         )
 
-        from sglang.srt.distributed.device_communicators.pymscclpp import (
-            PyMscclppCommunicator,
-        )
-
         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
         if use_pymscclpp and self.world_size > 1:
             self.pymscclpp_comm = PyMscclppCommunicator(
@@ -325,30 +330,30 @@ class GroupCoordinator:
         except Exception as e:
             logger.warning(f"Failed to initialize QuickAllReduce: {e}")
 
+        # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+        from sglang.srt.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )
 
         self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)
 
-        from sglang.srt.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator,
-        )
-
         self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
 
-        from sglang.srt.distributed.device_communicators.npu_communicator import (
-            NpuCommunicator,
-        )
-
         self.npu_communicator: Optional[NpuCommunicator] = None
         if use_npu_communicator and self.world_size > 1:
             self.npu_communicator = NpuCommunicator(group=self.device_group)
 
+        # Create message queue
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
@@ -848,6 +853,11 @@ class GroupCoordinator:
         )
         return obj_list
 
+    def all_gather_object(self, obj: Any) -> List[Any]:
+        objs = [None] * self.world_size
+        torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
+        return objs
+
     def send_object(self, obj: Any, dst: int) -> None:
         """Send the input object list to the destination rank."""
         """NOTE: `dst` is the local rank of the destination rank."""
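Two functional additions sit alongside the reshuffled imports: the device index can be pinned to 0 per process via the SGLANG_ONE_DEVICE_PER_PROCESS environment variable, and GroupCoordinator gains an all_gather_object convenience wrapper over torch.distributed.all_gather_object on its CPU group. A single-process sketch of the primitive that wrapper uses (gloo backend, arbitrary local port, not the sglang coordinator itself):

import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
)
objs = [None] * dist.get_world_size()
# Each rank contributes one picklable object and receives the list from all ranks.
dist.all_gather_object(objs, {"rank": dist.get_rank(), "tokens": 123})
print(objs)  # [{'rank': 0, 'tokens': 123}]
dist.destroy_process_group()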
sglang/srt/entrypoints/engine.py
CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -672,7 +673,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.
+            "0.3.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -680,7 +681,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.
+            "0.3.7.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
@@ -814,18 +815,24 @@ def _launch_subprocesses(
         ),
     )
     detoken_proc.start()
+    if server_args.tokenizer_worker_num > 1:
+        # Launch multi-tokenizer router
+        tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
 
-
-
+        # Initialize templates
+        template_manager = None
+    else:
+        # Launch tokenizer process
+        tokenizer_manager = TokenizerManager(server_args, port_args)
 
-
-
-
-
-
-
-
-
+        # Initialize templates
+        template_manager = TemplateManager()
+        template_manager.initialize_templates(
+            tokenizer_manager=tokenizer_manager,
+            model_path=server_args.model_path,
+            chat_template=server_args.chat_template,
+            completion_template=server_args.completion_template,
+        )
 
     # Wait for the model to finish loading
     scheduler_infos = []
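Beyond the dependency floor bumps (flashinfer_python to 0.3.0, sgl-kernel to 0.3.7.post1), subprocess launch now branches on tokenizer_worker_num: with more than one worker a MultiTokenizerRouter stands in for the TokenizerManager and template initialization is deferred (template_manager = None). For reference, a minimal stand-in for the kind of check assert_pkg_version enforces; this is an illustrative sketch, not the sglang helper itself:

from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

def check_min_version(pkg: str, minimum: str) -> None:
    # Refuse to start when an installed dependency is older than the pinned floor.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed")
    if Version(installed) < Version(minimum):
        raise RuntimeError(f"{pkg}=={installed} is too old, need >= {minimum}")

# check_min_version("flashinfer_python", "0.3.0")
# check_min_version("sgl-kernel", "0.3.7.post1")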