sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -286,12 +288,12 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if CUDA_CAPABILITY[0] >= 9:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
             BLOCK_M, BLOCK_N = (128, 64)
         else:
             BLOCK_M, BLOCK_N = (32, 64)
-    elif CUDA_CAPABILITY[0] >= 8:
+    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         if Lq <= 128:
             BLOCK_M, BLOCK_N = (128, 128)
         elif Lq <= 256:
sglang/srt/layers/attention/triton_ops/prefill_attention.py
CHANGED
@@ -24,7 +24,9 @@ import torch
 import triton
 import triton.language as tl
 
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -145,7 +147,7 @@ def _fwd_kernel(
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    if CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
         BLOCK = 64
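Both Triton files previously queried the device capability unconditionally at import time, which raises on hosts without CUDA; the new guard defers that query and the kernels fall back to the smaller block size when no capability is known. A minimal standalone sketch of the pattern (the block_size_for helper is illustrative and not part of sglang):

import torch

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


def block_size_for(head_dim: int) -> int:
    # Illustrative only: pick a larger tile on SM80+ GPUs, otherwise fall back.
    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
        return 128
    return 64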
sglang/srt/layers/sampler.py
CHANGED
@@ -21,6 +21,10 @@ logger = logging.getLogger(__name__)
 
 
 class Sampler(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+
     def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
@@ -36,13 +40,13 @@ class Sampler(nn.Module):
         logits = None
         del logits
 
-        if torch.any(torch.isnan(probs)):
+        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
             logger.warning("Detected errors during sampling! NaN in the probability.")
             probs = torch.where(
                 torch.isnan(probs), torch.full_like(probs, 1e-10), probs
             )
 
-        if sampling_info.
+        if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(probs, -1)
         elif global_server_args_dict["sampling_backend"] == "flashinfer":
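The sampler now reads the global disable_nan_detection flag once in its constructor and only scrubs NaNs when detection is enabled, and it short-circuits to argmax when every request in the batch is greedy. A toy sketch of the scrub-then-argmax path on a small made-up tensor, not sglang's actual batch:

import torch

probs = torch.tensor([[0.7, float("nan"), 0.3], [0.2, 0.5, 0.3]])

# Replace NaNs with a tiny probability, mirroring the scrub in Sampler.forward.
probs = torch.where(torch.isnan(probs), torch.full_like(probs, 1e-10), probs)

# When every request is greedy, sampling reduces to an argmax over the vocab dim.
next_token_ids = torch.argmax(probs, -1)
print(next_token_ids)  # tensor([0, 1])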
sglang/srt/managers/data_parallel_controller.py
ADDED
@@ -0,0 +1,177 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""A controller that dispatches requests to multiple data parallel workers."""
+
+import logging
+import multiprocessing as mp
+from enum import Enum, auto
+
+import zmq
+
+from sglang.srt.managers.io_struct import (
+    TokenizedEmbeddingReqInput,
+    TokenizedGenerateReqInput,
+    TokenizedRewardReqInput,
+)
+from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+    configure_logger,
+    kill_parent_process,
+    suppress_other_loggers,
+)
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+
+class LoadBalanceMethod(Enum):
+    """Load balance method."""
+
+    ROUND_ROBIN = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, method: str):
+        method = method.upper()
+        try:
+            return cls[method]
+        except KeyError as exc:
+            raise ValueError(f"Invalid load balance method: {method}") from exc
+
+
+class DataParallelController:
+    """A controller that dispatches requests to multiple data parallel workers."""
+
+    def __init__(self, server_args, port_args) -> None:
+        # Parse args
+        self.server_args = server_args
+        self.port_args = port_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method
+        )
+
+        # Init inter-process communication
+        self.context = zmq.Context(1 + server_args.dp_size)
+        self.recv_from_tokenizer = self.context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
+
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = dispatch_lookup[self.load_balance_method]
+
+        # Start data parallel workers
+        base_gpu_id = 0
+        self.workers = []
+        for dp_rank in range(server_args.dp_size):
+            tmp_port_args = PortArgs.init_new(server_args)
+            tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
+
+            send_to = self.launch_tensor_parallel_group(
+                server_args,
+                tmp_port_args,
+                base_gpu_id,
+                dp_rank,
+            )
+
+            self.workers.append(send_to)
+            base_gpu_id += server_args.tp_size
+
+    def launch_tensor_parallel_group(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        # Launch tensor parallel scheduler processes
+        scheduler_procs = []
+        scheduler_pipe_readers = []
+        tp_size_per_node = server_args.tp_size // server_args.nnodes
+        tp_rank_range = range(
+            tp_size_per_node * server_args.node_rank,
+            tp_size_per_node * (server_args.node_rank + 1),
+        )
+        for tp_rank in tp_rank_range:
+            reader, writer = mp.Pipe(duplex=False)
+            gpu_id = base_gpu_id + tp_rank % tp_size_per_node
+            proc = mp.Process(
+                target=run_scheduler_process,
+                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+            )
+            proc.start()
+            scheduler_procs.append(proc)
+            scheduler_pipe_readers.append(reader)
+
+        send_to = self.context.socket(zmq.PUSH)
+        send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
+
+        # Wait for model to finish loading
+        for i in range(len(scheduler_pipe_readers)):
+            scheduler_pipe_readers[i].recv()
+
+        return send_to
+
+    def round_robin_scheduler(self, req):
+        self.workers[self.round_robin_counter].send_pyobj(req)
+        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
+
+    def shortest_queue_scheduler(self, input_requests):
+        raise NotImplementedError()
+
+    def event_loop(self):
+        while True:
+            while True:
+                try:
+                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+
+                if isinstance(
+                    recv_req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        TokenizedRewardReqInput,
+                    ),
+                ):
+                    self.dispatching(recv_req)
+                else:
+                    # Send other control messages to all workers
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+
+
+def run_data_parallel_controller_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer,
+):
+    configure_logger(server_args)
+    suppress_other_loggers()
+
+    try:
+        controller = DataParallelController(server_args, port_args)
+        pipe_writer.send("ready")
+        controller.event_loop()
+    except Exception:
+        msg = get_exception_traceback()
+        logger.error(msg)
+        kill_parent_process()
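The controller runs as its own process: run_data_parallel_controller_process builds the controller, reports readiness over a pipe, then pulls requests off the tokenizer socket forever, round-robining tokenized requests to one replica and broadcasting everything else. A hedged sketch of how a parent process could launch it; the real wiring lives in sglang/srt/server.py (also changed in this release but not shown here), and the model path below is only a placeholder:

import multiprocessing as mp

from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.server_args import PortArgs, ServerArgs

# Sketch only: arguments for two data parallel replicas.
server_args = ServerArgs(model_path="/path/to/model", dp_size=2)
port_args = PortArgs.init_new(server_args)

reader, writer = mp.Pipe(duplex=False)
proc = mp.Process(
    target=run_data_parallel_controller_process,
    args=(server_args, port_args, writer),
)
proc.start()
reader.recv()  # blocks until the controller has launched its replicas and reports "ready"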
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -18,7 +18,7 @@ limitations under the License.
 import dataclasses
 import logging
 from collections import OrderedDict
-from typing import List
+from typing import List, Union
 
 import zmq
 
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     UpdateWeightReqOutput,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
+    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
+        if no_stop_trim:
+            return output
+
+        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
+        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
+            pos = output.find(finished_reason.matched)
+            return output[:pos] if pos != -1 else output
+        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
+            output, list
+        ):
+            assert len(output) > 0
+            return output[:-1]
+        return output
+
     def event_loop(self):
         """The event loop that handles requests"""
 
@@ -122,7 +137,13 @@ class DetokenizerManager:
                 s = self.decode_status[rid]
                 s.decode_ids = recv_obj.decode_ids[i]
 
-                read_ids.append(
+                read_ids.append(
+                    self.trim_eos(
+                        s.decode_ids[s.surr_offset :],
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
                 surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
                 # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
                 else:
                     new_text = find_printable_text(new_text)
 
-                output_strs.append(
-
-
-
-
-
-
+                output_strs.append(
+                    self.trim_eos(
+                        s.decoded_text + new_text,
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
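With trim_eos, the detokenizer now strips the matched stop string from streamed text (or drops the final matched stop token from token-id output) unless the request set no_stop_trim. A small standalone illustration of the two trimming rules, using made-up values in place of sglang's finish-reason objects:

# Stop-string case: everything from the matched stop string onward is dropped.
text = "The answer is 42.</s> extra"
stop_str = "</s>"
pos = text.find(stop_str)
print(text[:pos] if pos != -1 else text)  # "The answer is 42."

# Stop-token case: the last id (the matched stop token) is dropped.
token_ids = [414, 1234, 29889, 2]  # hypothetical ids, 2 standing in for the stop token
print(token_ids[:-1])  # [414, 1234, 29889]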
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,6 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 
 import uuid
 from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -55,6 +56,9 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -119,8 +123,7 @@ class GenerateReqInput:
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
             elif isinstance(self.image_data, list):
-
-                self.image_data = self.image_data * num
+                pass
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
@@ -295,6 +298,7 @@ class BatchTokenIDOut:
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
+    no_stop_trim: List[bool]
 
 
 @dataclass
@@ -344,3 +348,8 @@ class UpdateWeightReqOutput:
 class AbortReq:
     # The request id
     rid: str
+
+
+class ProfileReq(Enum):
+    START_PROFILE = 1
+    STOP_PROFILE = 2
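ProfileReq is a plain control message: since it is not one of the tokenized request types, the data parallel controller's event loop routes it to the broadcast branch rather than the load-balanced dispatch. A small sketch of that routing decision (pick_route is illustrative, not an sglang function):

from sglang.srt.managers.io_struct import (
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
)


def pick_route(recv_req) -> str:
    # Mirrors the isinstance check in DataParallelController.event_loop.
    if isinstance(
        recv_req,
        (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput, TokenizedRewardReqInput),
    ):
        return "dispatch to one worker"
    return "broadcast to all workers"


print(pick_route(ProfileReq.START_PROFILE))  # broadcast to all workers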