sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py
CHANGED
@@ -1,11 +1,12 @@
 import logging
-from typing import List
+from typing import Dict, List
 
 import torch
 from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
 
@@ -35,6 +36,10 @@ class Sampler(nn.Module):
     ):
         logits = logits_output.next_token_logits
 
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            self._apply_custom_logit_processor(logits, sampling_info)
+
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
@@ -121,6 +126,39 @@ class Sampler(nn.Module):
 
         return batch_next_token_ids
 
+    def _apply_custom_logit_processor(
+        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
+    ):
+        """Apply custom logit processors to the logits.
+        This function will modify the logits in-place."""
+
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
+        for _, (
+            processor,
+            batch_mask,
+        ) in sampling_batch_info.custom_logit_processor.items():
+            # Get the batch indices that need to be processed
+            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+                f"sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
+            # Apply the processor to the logits
+            logits[batch_mask] = processor(
+                logits[batch_mask],
+                [sampling_batch_info.custom_params[i] for i in batch_indices],
+            )
+
+            logger.debug(
+                f"Custom logit processor {processor.__class__.__name__} is applied."
+            )
+
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
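The new hook calls each registered processor with the logit rows selected by its batch mask plus a list of per-request custom parameter dicts, and assigns the returned tensor back in place. A minimal sketch of a callable that fits this calling convention; the class name and the "banned_token_ids" parameter are illustrative only, not part of sglang's API:

from typing import Dict, List

import torch


class BanTokenProcessor:
    # Illustrative only: matches processor(logits[batch_mask], [custom_params[i] ...]) above.
    def __call__(self, logits: torch.Tensor, custom_params: List[Dict]) -> torch.Tensor:
        # logits: [num_selected_requests, vocab_size], rows picked by this processor's batch_mask
        # custom_params: one dict per selected request, from sampling_batch_info.custom_params
        for row, params in enumerate(custom_params):
            for token_id in params.get("banned_token_ids", []):
                logits[row, token_id] = float("-inf")  # never sample these tokens
        return logits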
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -6,13 +6,13 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.use_presharded_weights = use_presharded_weights
+        if use_presharded_weights:
+            assert (
+                num_added_embeddings == 0
+            ), "Lora is not supported with presharded weights."
+
         self.org_vocab_size_padded = pad_vocab_size(
             self.org_vocab_size, self.padding_size
         )
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             start_idx = start_idx // packed_factor
             shard_size = shard_size // packed_factor
         else:
-            assert loaded_weight.shape[output_dim] ==
+            assert loaded_weight.shape[output_dim] == (
+                self.org_vocab_size
+                // (self.tp_size if self.use_presharded_weights else 1)
+            )
 
         # Copy the data.
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0] :].data.fill_(0)
 
@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
             padding_size,
             quant_config,
             prefix,
+            use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
         if bias:
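The new use_presharded_weights path skips the per-rank narrow in weight_loader because the checkpoint already stores only this rank's slice of the vocabulary dimension. A small sketch of the two loading paths, assuming a weight sharded on dim 0 with toy sizes; the variable names here are illustrative, not sglang code:

import torch

vocab_size, hidden_size, tp_size, tp_rank = 32000, 4096, 4, 1
shard_size = vocab_size // tp_size
start_idx = tp_rank * shard_size

full_weight = torch.randn(vocab_size, hidden_size)

# Default path: every rank loads the full-vocab tensor, then narrows out its shard.
shard_from_full = full_weight.narrow(0, start_idx, shard_size)

# Presharded path: the checkpoint for this rank already holds only the shard,
# so it is copied into the parameter as-is and the narrow above is skipped.
presharded_weight = full_weight[start_idx : start_idx + shard_size]

assert torch.equal(shard_from_full, presharded_weight)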
sglang/srt/lora/lora.py
CHANGED
@@ -19,18 +19,11 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
 
 
-import json
-import os
 import re
-from typing import Any, Dict, List, Optional, Tuple
 
-import safetensors.torch
 import torch
 from torch import nn
-from vllm.model_executor.layers.vocab_parallel_embedding import
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -38,7 +31,6 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
sglang/srt/managers/configure_logging.py
ADDED
@@ -0,0 +1,46 @@
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Configure the logging settings of a server.
+
+Usage:
+python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument("--log-requests", action="store_true")
+    parser.add_argument(
+        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
+    )
+    parser.add_argument("--dump-requests-threshold", type=int, default=1000)
+    args = parser.parse_args()
+
+    response = requests.post(
+        args.url + "/configure_logging",
+        json={
+            "log_requests": args.log_requests,
+            "log_requests_level": 1,  # Log full requests
+            "dump_requests_folder": args.dump_requests_folder,
+            "dump_requests_threshold": args.dump_requests_threshold,
+        },
+    )
+    assert response.status_code == 200
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -23,6 +23,7 @@ import psutil
 import setproctitle
 import zmq
 
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -55,6 +56,7 @@ class DataParallelController:
 
     def __init__(self, server_args, port_args) -> None:
         # Parse args
+        self.max_total_num_tokens = None
         self.server_args = server_args
         self.port_args = port_args
         self.load_balance_method = LoadBalanceMethod.from_str(
@@ -63,9 +65,10 @@ class DataParallelController:
 
         # Init inter-process communication
         self.context = zmq.Context(1 + server_args.dp_size)
-        self.recv_from_tokenizer = get_zmq_socket(
-            self.context, zmq.PULL, port_args.scheduler_input_ipc_name
-        )
+        if server_args.node_rank == 0:
+            self.recv_from_tokenizer = get_zmq_socket(
+                self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
+            )
 
         # Dispatch method
         self.round_robin_counter = 0
@@ -75,33 +78,50 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]
 
-        #
-
+        # Launch data parallel workers
+        self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
+        if not server_args.enable_dp_attention:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+        else:
+            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+
+        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                self.workers[dp_rank] = get_zmq_socket(
+                    self.context,
+                    zmq.PUSH,
+                    dp_port_args[dp_rank].scheduler_input_ipc_name,
+                    True,
+                )
+
+        self.max_req_input_len = None
+
+    def launch_dp_schedulers(self, server_args, port_args):
+        base_gpu_id = 0
+
         threads = []
         sockets = []
+        dp_port_args = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
+            dp_port_args.append(tmp_port_args)
 
-            if server_args.enable_dp_attention:
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                tmp_port_args.nccl_port = port_args.nccl_port
-            else:
-                # This port is checked free in PortArgs.init_new.
-                # We hold it first so that the next dp worker gets a different port
-                sockets.append(bind_port(tmp_port_args.nccl_port))
+            # This port is checked free in PortArgs.init_new.
+            # We hold it first so that the next dp worker gets a different port
+            sockets.append(bind_port(tmp_port_args.nccl_port))
 
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.
+                target=self.launch_tensor_parallel_group,
                 args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
             threads.append(thread)
-            base_gpu_id +=
+            base_gpu_id += server_args.tp_size
 
         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:
@@ -113,26 +133,14 @@ class DataParallelController:
         for thread in threads:
             thread.join()
 
-
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+        return dp_port_args
 
-
-
-
-
-
-
-            server_args,
-            port_args,
-            base_gpu_id,
-            dp_rank,
-        )
+    def launch_dp_attention_schedulers(self, server_args, port_args):
+        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
+        dp_port_args = []
+        for dp_rank in range(server_args.dp_size):
+            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
+        return dp_port_args
 
     def launch_tensor_parallel_group(
         self,
@@ -141,8 +149,10 @@ class DataParallelController:
         base_gpu_id: int,
         dp_rank: int,
     ):
+        if not server_args.enable_dp_attention:
+            logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
         # Launch tensor parallel scheduler processes
-        scheduler_procs = []
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
         tp_rank_range = range(
@@ -150,52 +160,39 @@ class DataParallelController:
             tp_size_per_node * (server_args.node_rank + 1),
         )
         for tp_rank in tp_rank_range:
+            rank_port_args = port_args
+
+            if server_args.enable_dp_attention:
+                # dp attention has different sharding logic
+                _, _, dp_rank = compute_dp_attention_world_info(
+                    server_args.enable_dp_attention,
+                    tp_rank,
+                    server_args.tp_size,
+                    server_args.dp_size,
+                )
+                # compute zmq ports for this dp rank
+                rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                # Data parallelism resues the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                rank_port_args.nccl_port = port_args.nccl_port
+
             reader, writer = mp.Pipe(duplex=False)
             gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
-                args=(server_args,
+                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
             proc.start()
-            scheduler_procs.append(proc)
+            self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        # Wait for model to finish loading and get max token nums
+        # Wait for model to finish loading
         scheduler_info = []
         for i in range(len(scheduler_pipe_readers)):
             scheduler_info.append(scheduler_pipe_readers[i].recv())
 
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
-
-        return send_to
-
-    def launch_tensor_parallel_process(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        reader, writer = mp.Pipe(duplex=False)
-        gpu_id = base_gpu_id
-        tp_rank = dp_rank
-        proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
-        )
-        proc.start()
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        scheduler_info = reader.recv()
-        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
-
-        return send_to
+        self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
 
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
@@ -221,8 +218,8 @@ class DataParallelController:
                 ):
                     self.dispatching(recv_req)
                 else:
-                    # Send other control messages to
-                    for worker in self.workers:
+                    # Send other control messages to first worker of tp group
+                    for worker in self.workers[:: self.server_args.tp_size]:
                         worker.send_pyobj(recv_req)
 
 
@@ -238,9 +235,19 @@ def run_data_parallel_controller_process(
     try:
         controller = DataParallelController(server_args, port_args)
        pipe_writer.send(
-            {
+            {
+                "status": "ready",
+                "max_total_num_tokens": controller.max_total_num_tokens,
+                "max_req_input_len": controller.max_req_input_len,
+            }
         )
-        controller.event_loop()
+        if server_args.node_rank == 0:
+            controller.event_loop()
+        for proc in controller.scheduler_procs:
+            proc.join()
+            logger.error(
+                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+            )
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
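Control messages are now broadcast with the extended slice self.workers[:: self.server_args.tp_size], which the diff's own comment describes as the first worker of each tp group. A minimal, self-contained illustration of what that slice selects; the toy list below stands in for the real ZMQ sockets:

tp_size = 2
workers = ["w0", "w1", "w2", "w3", "w4", "w5"]

# A step of tp_size keeps every tp_size-th entry, starting at index 0,
# i.e. the first worker of each consecutive group of tp_size workers.
assert workers[::tp_size] == ["w0", "w2", "w4"]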
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -15,6 +15,7 @@
 
 import dataclasses
 import logging
+import os
 import signal
 from collections import OrderedDict
 from typing import Dict, List, Union
@@ -35,6 +36,12 @@ from sglang.utils import find_printable_text, get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
+# Maximum number of request states that detokenizer can hold. When exceeded,
+# oldest request states will be evicted. Default: 65536 (1<<16).
+# For more details, see: https://github.com/sgl-project/sglang/issues/2812
+# Use power of 2 values for better memory allocation.
+DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16))
+
 
 @dataclasses.dataclass
 class DecodeStatus:
@@ -58,10 +65,10 @@ class DetokenizerManager:
         # Init inter-process communication
         context = zmq.Context(2)
         self.recv_from_scheduler = get_zmq_socket(
-            context, zmq.PULL, port_args.detokenizer_ipc_name
+            context, zmq.PULL, port_args.detokenizer_ipc_name, True
         )
         self.send_to_tokenizer = get_zmq_socket(
-            context, zmq.PUSH, port_args.tokenizer_ipc_name
+            context, zmq.PUSH, port_args.tokenizer_ipc_name, False
         )
 
         if server_args.skip_tokenizer_init:
@@ -71,9 +78,10 @@ class DetokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
 
-        self.decode_status = LimitedCapacityDict()
+        self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
 
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
@@ -155,7 +163,17 @@ class DetokenizerManager:
         # Incremental decoding
         output_strs = []
         for i in range(bs):
-            s = self.decode_status[recv_obj.rids[i]]
+            try:
+                s = self.decode_status[recv_obj.rids[i]]
+            except KeyError:
+                raise RuntimeError(
+                    f"Decode status not found for request {recv_obj.rids[i]}. "
+                    "It may be due to the request being evicted from the decode status due to memory pressure. "
+                    "Please increase the maximum number of requests by setting "
+                    "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. "
+                    f"The current value is {DETOKENIZER_MAX_STATES}. "
+                    "For more details, see: https://github.com/sgl-project/sglang/issues/2812"
+                )
             new_text = read_texts[i][len(surr_texts[i]) :]
             if recv_obj.finished_reasons[i] is None:
                 # Streaming chunk: update the decode status
@@ -181,8 +199,6 @@ class DetokenizerManager:
                     finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
                     prompt_tokens=recv_obj.prompt_tokens,
-                    origin_input_ids=recv_obj.origin_input_ids,
-                    output_ids=recv_obj.output_ids,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,
@@ -193,13 +209,12 @@ class DetokenizerManager:
                     input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
                     output_top_logprobs_val=recv_obj.output_top_logprobs_val,
                     output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
-                    normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
                 )
             )
 
 
 class LimitedCapacityDict(OrderedDict):
-    def __init__(self, capacity
+    def __init__(self, capacity: int, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.capacity = capacity
 
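LimitedCapacityDict now receives an explicit capacity (DETOKENIZER_MAX_STATES, overridable through the SGLANG_DETOKENIZER_MAX_STATES environment variable). The eviction logic itself falls outside these hunks; below is a rough sketch of how a capacity-bounded OrderedDict of this kind typically evicts its oldest entry on insert. It is illustrative only, not the sglang implementation:

from collections import OrderedDict


class BoundedDict(OrderedDict):
    """Illustrative stand-in: drop the oldest entry once the capacity is reached."""

    def __init__(self, capacity: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        if key not in self and len(self) >= self.capacity:
            self.popitem(last=False)  # evict the oldest (first-inserted) entry
        super().__setitem__(key, value)


d = BoundedDict(capacity=2)
d["a"], d["b"], d["c"] = 1, 2, 3
assert list(d) == ["b", "c"]  # "a" was evicted when "c" arrived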