sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +12 -6
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +43 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +36 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +75 -68
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +2 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +26 -10
- sglang/srt/models/gpt_oss.py +0 -14
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +14 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -61,6 +61,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -509,6 +510,8 @@ def latency_test(
     bench_args,
     tp_rank,
 ):
+    initialize_moe_config(server_args)
+
     # Set CPU affinity
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank)
sglang/srt/configs/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+from sglang.srt.configs.longcat_flash import LongcatFlashConfig
 from sglang.srt.configs.step3_vl import (
     Step3TextConfig,
     Step3VisionEncoderConfig,
@@ -16,6 +17,7 @@ __all__ = [
     "ChatGLMConfig",
     "DbrxConfig",
     "DeepseekVL2Config",
+    "LongcatFlashConfig",
     "MultiModalityConfig",
     "KimiVLConfig",
     "MoonViTConfig",
sglang/srt/configs/longcat_flash.py
ADDED
@@ -0,0 +1,104 @@
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+class LongcatFlashConfig(PretrainedConfig):
+    model_type = "longcat_flash"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=131072,
+        hidden_size=6144,
+        intermediate_size=None,
+        ffn_hidden_size=12288,
+        expert_ffn_hidden_size=2048,
+        num_layers=28,
+        num_hidden_layers=None,
+        num_attention_heads=64,
+        ep_size=1,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=128,
+        qk_nope_head_dim=128,
+        v_head_dim=128,
+        n_routed_experts=512,
+        moe_topk=12,
+        norm_topk_prob=False,
+        max_position_embeddings=131072,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        mla_scale_q_lora=True,
+        mla_scale_kv_lora=True,
+        torch_dtype="bfloat16",
+        params_dtype="bfloat16",
+        rounter_params_dtype="float32",
+        router_bias=False,
+        topk_method=None,
+        routed_scaling_factor=6.0,
+        zero_expert_num=256,
+        zero_expert_type="identity",
+        nextn_use_scmoe=False,
+        num_nextn_predict_layers=1,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            torch_dtype=torch_dtype,
+            params_dtype=params_dtype,
+            rounter_params_dtype=rounter_params_dtype,
+            topk_method=topk_method,
+            router_bias=router_bias,
+            nextn_use_scmoe=nextn_use_scmoe,
+            num_nextn_predict_layers=num_nextn_predict_layers,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = (
+            num_hidden_layers if num_hidden_layers is not None else num_layers
+        )
+        self.intermediate_size = (
+            intermediate_size if intermediate_size is not None else ffn_hidden_size
+        )
+        self.moe_intermediate_size = expert_ffn_hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.ep_size = ep_size
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.n_routed_experts = n_routed_experts
+        self.moe_topk = moe_topk
+        self.norm_topk_prob = norm_topk_prob
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.mla_scale_q_lora = mla_scale_q_lora
+        self.mla_scale_kv_lora = mla_scale_kv_lora
+        self.zero_expert_num = zero_expert_num
+        self.zero_expert_type = zero_expert_type
+        self.routed_scaling_factor = routed_scaling_factor
+        self.hidden_act = "silu"
sglang/srt/configs/model_config.py
CHANGED
@@ -132,6 +132,13 @@ class ModelConfig:
         if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
             self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"

+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
+            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers
+
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
         if (
@@ -199,6 +206,8 @@ class ModelConfig:
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
+            or "LongcatFlashForCausalLM" in self.hf_config.architectures
+            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
         ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -270,6 +279,9 @@ class ModelConfig:
         self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.num_attention_layers = self.num_hidden_layers
+        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
+            self.num_attention_layers = self.num_hidden_layers * 2
         self.num_nextn_predict_layers = getattr(
             self.hf_text_config, "num_nextn_predict_layers", None
         )
sglang/srt/connector/__init__.py
CHANGED
@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
     KV = "KV"


-def create_remote_connector(url,
+def create_remote_connector(url, **kwargs) -> BaseConnector:
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        return RedisConnector(url)
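Note: the new signature forwards extra keyword arguments instead of the removed device parameter. A minimal usage sketch (the URL and model name are illustrative, and a redis client package is assumed to be installed):

from sglang.srt.connector import create_remote_connector

# Parses the "redis" scheme from the URL and returns a RedisConnector bound to it.
connector = create_remote_connector("redis://localhost:6379/my-model")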
sglang/srt/connector/base_connector.py
CHANGED
@@ -20,9 +20,8 @@ class BaseConnector(ABC):
    <connector_type://<host>:<port>/<model_name>/files/<filename>
    """

-    def __init__(self, url: str
+    def __init__(self, url: str):
        self.url = url
-        self.device = device
        self.closed = False
        self.local_dir = tempfile.mkdtemp()
        for sig in (signal.SIGINT, signal.SIGTERM):
sglang/srt/connector/redis.py
CHANGED
@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)


 class RedisConnector(BaseKVConnector):

-    def __init__(self, url: str
+    def __init__(self, url: str):
        import redis

-        super().__init__(url
+        super().__init__(url)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
sglang/srt/connector/serde/safe_serde.py
CHANGED
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):


 class SafeDeserializer(Deserializer):

-    def __init__(self
-
+    def __init__(self):
+        # TODO: dtype options
+        super().__init__(torch.float32)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
-        return load(bytes(b))["tensor_bytes"]
+        return load(bytes(b))["tensor_bytes"]

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        return self.from_bytes_normal(b)
sglang/srt/disaggregation/ascend/conn.py
CHANGED
@@ -1,6 +1,12 @@
+import concurrent.futures
 import logging
+from typing import List, Tuple
+
+import numpy as np
+import numpy.typing as npt

 from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
+from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.mooncake.conn import (
     MooncakeKVBootstrapServer,
     MooncakeKVManager,
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        )

+    def send_kvcache(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        # Group by indices
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        layers_params = [
+            (
+                self.kv_args.kv_data_ptrs[layer_id],
+                dst_kv_ptrs[layer_id],
+                self.kv_args.kv_item_lens[layer_id],
+            )
+            for layer_id in range(num_layers)
+        ]
+
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
+            transfer_blocks = []
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks
+
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)
+
+        return 0
+

 class AscendKVSender(MooncakeKVSender):
     pass
sglang/srt/disaggregation/launch_lb.py
CHANGED
@@ -6,7 +6,6 @@ from sglang.srt.disaggregation.mini_lb import PrefillConfig, run

 @dataclasses.dataclass
 class LBArgs:
-    rust_lb: bool = False
     host: str = "0.0.0.0"
     port: int = 8000
     policy: str = "random"
@@ -17,11 +16,6 @@ class LBArgs:

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
-        parser.add_argument(
-            "--rust-lb",
-            action="store_true",
-            help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
-        )
         parser.add_argument(
             "--host",
             type=str,
@@ -92,7 +86,6 @@ class LBArgs:
        ]

        return cls(
-            rust_lb=args.rust_lb,
            host=args.host,
            port=args.port,
            policy=args.policy,
@@ -102,12 +95,6 @@ class LBArgs:
            timeout=args.timeout,
        )

-    def __post_init__(self):
-        if not self.rust_lb:
-            assert (
-                self.policy == "random"
-            ), "Only random policy is supported for Python load balancer"
-

 def main():
     parser = argparse.ArgumentParser(
sglang/srt/disaggregation/mini_lb.py
CHANGED
@@ -7,6 +7,7 @@ import dataclasses
 import logging
 import random
 import urllib
+from http import HTTPStatus
 from itertools import chain
 from typing import List, Optional

@@ -262,14 +263,38 @@

 @app.get("/get_model_info")
 async def get_model_info():
-
-
-
-
-
-
-
-
+    global load_balancer
+
+    if not load_balancer or not load_balancer.prefill_servers:
+        raise HTTPException(
+            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+            detail="There is no server registered",
+        )
+
+    target_server_url = load_balancer.prefill_servers[0]
+    endpoint_url = f"{target_server_url}/get_model_info"
+
+    async with aiohttp.ClientSession() as session:
+        try:
+            async with session.get(endpoint_url) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    raise HTTPException(
+                        status_code=HTTPStatus.BAD_GATEWAY,
+                        detail=(
+                            f"Failed to get model info from {target_server_url}"
+                            f"Status: {response.status}, Response: {error_text}"
+                        ),
+                    )
+
+                model_info_json = await response.json()
+                return ORJSONResponse(content=model_info_json)
+
+        except aiohttp.ClientError as e:
+            raise HTTPException(
+                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+                detail=f"Failed to get model info from backend",
+            )


 @app.post("/generate")
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin:
        # Move the chunked request out of the batch so that we can merge
        # only finished requests to running_batch.
        self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-        self.tree_cache.cache_unfinished_req(self.chunked_req)
+        self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
        if self.enable_overlap:
            # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
            self.chunked_req.tmp_end_idx = min(
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -52,6 +52,8 @@ from sglang.srt.utils import (

 _is_npu = is_npu()

+IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
+

 @dataclass
 class GraphCaptureContext:
@@ -223,10 +225,12 @@ class GroupCoordinator:
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
+        # Set group info
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

+        # Set rank info
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
@@ -250,14 +254,16 @@ class GroupCoordinator:
        assert self.cpu_group is not None
        assert self.device_group is not None

+        device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
        if is_cuda_alike():
-            self.device = torch.device(f"cuda:{
+            self.device = torch.device(f"cuda:{device_id}")
        elif _is_npu:
-            self.device = torch.device(f"npu:{
+            self.device = torch.device(f"npu:{device_id}")
        else:
            self.device = torch.device("cpu")
        self.device_module = torch.get_device_module(self.device)

+        # Import communicators
        self.use_pynccl = use_pynccl
        self.use_pymscclpp = use_pymscclpp
        self.use_custom_allreduce = use_custom_allreduce
@@ -270,6 +276,9 @@ class GroupCoordinator:
        from sglang.srt.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce,
        )
+        from sglang.srt.distributed.device_communicators.pymscclpp import (
+            PyMscclppCommunicator,
+        )
        from sglang.srt.distributed.device_communicators.pynccl import (
            PyNcclCommunicator,
        )
@@ -287,10 +296,6 @@ class GroupCoordinator:
                device=self.device,
            )

-        from sglang.srt.distributed.device_communicators.pymscclpp import (
-            PyMscclppCommunicator,
-        )
-
        self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
        if use_pymscclpp and self.world_size > 1:
            self.pymscclpp_comm = PyMscclppCommunicator(
@@ -325,30 +330,30 @@ class GroupCoordinator:
        except Exception as e:
            logger.warning(f"Failed to initialize QuickAllReduce: {e}")

+        # Create communicator for other hardware backends
        from sglang.srt.distributed.device_communicators.hpu_communicator import (
            HpuCommunicator,
        )
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+        from sglang.srt.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )

        self.hpu_communicator: Optional[HpuCommunicator] = None
        if use_hpu_communicator and self.world_size > 1:
            self.hpu_communicator = HpuCommunicator(group=self.device_group)

-        from sglang.srt.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator,
-        )
-
        self.xpu_communicator: Optional[XpuCommunicator] = None
        if use_xpu_communicator and self.world_size > 1:
            self.xpu_communicator = XpuCommunicator(group=self.device_group)

-        from sglang.srt.distributed.device_communicators.npu_communicator import (
-            NpuCommunicator,
-        )
-
        self.npu_communicator: Optional[NpuCommunicator] = None
        if use_npu_communicator and self.world_size > 1:
            self.npu_communicator = NpuCommunicator(group=self.device_group)

+        # Create message queue
        from sglang.srt.distributed.device_communicators.shm_broadcast import (
            MessageQueue,
        )
@@ -848,6 +853,11 @@ class GroupCoordinator:
        )
        return obj_list

+    def all_gather_object(self, obj: Any) -> List[Any]:
+        objs = [None] * self.world_size
+        torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
+        return objs
+
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""
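Note: a minimal usage sketch for the new all_gather_object helper, assuming an initialized distributed process group and assuming get_tp_group() is available in this module as in previous releases; the gathered payload is illustrative:

from sglang.srt.distributed.parallel_state import get_tp_group

group = get_tp_group()
# Each rank contributes one Python object over the CPU group; the result is a
# list of world_size entries, one per rank.
per_rank_info = group.all_gather_object({"rank": group.rank})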
sglang/srt/entrypoints/engine.py
CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -672,7 +673,7 @@ def _set_envs_and_config(server_args: ServerArgs):
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer_python",
-            "0.
+            "0.3.0",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
@@ -680,7 +681,7 @@ def _set_envs_and_config(server_args: ServerArgs):
    if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
        assert_pkg_version(
            "sgl-kernel",
-            "0.3.7",
+            "0.3.7.post1",
            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
        )

@@ -814,18 +815,24 @@ def _launch_subprocesses(
        ),
    )
    detoken_proc.start()
+    if server_args.tokenizer_worker_num > 1:
+        # Launch multi-tokenizer router
+        tokenizer_manager = MultiTokenizerRouter(server_args, port_args)

-
-
+        # Initialize templates
+        template_manager = None
+    else:
+        # Launch tokenizer process
+        tokenizer_manager = TokenizerManager(server_args, port_args)

-
-
-
-
-
-
-
-
+        # Initialize templates
+        template_manager = TemplateManager()
+        template_manager.initialize_templates(
+            tokenizer_manager=tokenizer_manager,
+            model_path=server_args.model_path,
+            chat_template=server_args.chat_template,
+            completion_template=server_args.completion_template,
+        )

    # Wait for the model to finish loading
    scheduler_infos = []