sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/manager_multi.py (new file)
@@ -0,0 +1,195 @@
+"""
+A controller that manages multiple data parallel workers.
+Each data parallel worker can manage multiple tensor parallel workers.
+"""
+
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from enum import Enum, auto
+from typing import Dict
+
+import zmq
+import zmq.asyncio
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    FlushCacheReq,
+    TokenizedGenerateReqInput,
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+
+
+class LoadBalanceMethod(Enum):
+    ROUND_ROBIN = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, method: str):
+        method = method.upper()
+        try:
+            return cls[method]
+        except KeyError as exc:
+            raise ValueError(f"Invalid load balance method: {method}") from exc
+
+
+class Controller:
+    """A controller that manages multiple data parallel workers."""
+
+    def __init__(
+        self,
+        load_balance_method: str,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args,
+    ):
+        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        self.server_args = server_args
+        self.port_args = port_args
+
+        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
+            self.round_robin_counter = 0
+
+        self.dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = self.dispatch_lookup[self.load_balance_method]
+
+        # Init communication
+        context = zmq.asyncio.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        # Init status
+        self.recv_reqs = []
+
+        # Start data parallel workers
+        self.workers: Dict[int, DataParallelWorkerThread] = {}
+        tp_size = server_args.tp_size
+
+        def start_dp_worker(i):
+            try:
+                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
+                worker_thread = start_data_parallel_worker(
+                    server_args, port_args, model_overide_args, gpu_ids, i
+                )
+                self.workers[i] = worker_thread
+            except Exception:
+                logger.error(
+                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
+                )
+
+        for i in range(server_args.dp_size):
+            start_dp_worker(i)
+
+        # Parallel launch is slower, probably due to the disk bandwidth limitations.
+        # with ThreadPoolExecutor(server_args.dp_size) as executor:
+        #     executor.map(start_dp_worker, range(server_args.dp_size))
+
+    def have_any_live_worker(self):
+        return any(worker_thread.liveness for worker_thread in self.workers.values())
+
+    def put_req_to_worker(self, worker_id, req):
+        self.workers[worker_id].request_queue.put(req)
+
+    async def round_robin_scheduler(self, input_requests):
+        available_workers = list(self.workers.keys())
+        for r in input_requests:
+            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                available_workers
+            )
+        return
+
+    async def shortest_queue_scheduler(self, input_requests):
+        for r in input_requests:
+            worker = min(
+                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
+            )
+            self.put_req_to_worker(worker, r)
+        return
+
+    async def remove_dead_workers(self):
+        for i in list(self.workers.keys()):
+            worker_thread = self.workers[i]
+            if not worker_thread.liveness:
+                worker_thread.join()
+                # move unsuccessful requests back to the queue
+                while not worker_thread.request_queue.empty():
+                    self.recv_reqs.append(worker_thread.request_queue.get())
+                del self.workers[i]
+                logger.info(f"Stale worker {i} removed")
+
+    async def loop_for_forward(self):
+        while True:
+            await self.remove_dead_workers()
+
+            if self.have_any_live_worker():
+                next_step_input = list(self.recv_reqs)
+                self.recv_reqs = []
+                if next_step_input:
+                    await self.dispatching(next_step_input)
+            # else:
+            #     logger.error("There is no live worker.")
+
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def loop_for_recv_requests(self):
+        while True:
+            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            if isinstance(recv_req, FlushCacheReq):
+                # TODO(lsyin): apply more specific flushCacheReq
+                for worker_thread in self.workers.values():
+                    worker_thread.request_queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                self.recv_reqs.append(recv_req)
+            elif isinstance(recv_req, AbortReq):
+                in_queue = False
+                for i, req in enumerate(self.recv_reqs):
+                    if req.rid == recv_req.rid:
+                        self.recv_reqs[i] = recv_req
+                        in_queue = True
+                        break
+                if not in_queue:
+                    # Send abort req to all TP groups
+                    for worker in list(self.workers.keys()):
+                        self.put_req_to_worker(worker, recv_req)
+            else:
+                logger.error(f"Invalid object: {recv_req}")
+
+
+def start_controller_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer,
+    model_overide_args=None,
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = Controller(
+            server_args.load_balance_method, server_args, port_args, model_overide_args
+        )
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+    pipe_writer.send("init ok")
+
+    loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+
+    asyncio.set_event_loop(loop)
+    loop.create_task(controller.loop_for_recv_requests())
+    loop.run_until_complete(controller.loop_for_forward())
sglang/srt/managers/controller/manager_single.py (new file)
@@ -0,0 +1,177 @@
+"""A controller that manages a group of tensor parallel workers."""
+
+import multiprocessing
+import logging
+import os
+import pickle
+
+import torch
+import torch.distributed as dist
+import zmq
+import zmq.asyncio
+
+from sglang.srt.managers.controller.tp_worker import ModelTpServer
+from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+
+logger = logging.getLogger("srt.controller")
+
+
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    model_port_args: ModelPortArgs,
+    model_overide_args: dict,
+):
+    """Run a tp server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            model_port_args,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                      model_port_args, model_overide_args):
+    """Launch multiple tp servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(target=run_tp_server, args=(
+            gpu_ids[i], i, server_args, model_port_args, model_overide_args
+        ))
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
+
+
+class ControllerSingle:
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+        # Parse args
+        self.server_args = server_args
+
+        # Init communication
+        context = zmq.Context(2)
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(
+            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
+        )
+
+        # Init model server
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+
+        # Launch other tp ranks
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids, tp_rank_range, server_args,
+                port_args.model_port_args[0], model_overide_args)
+
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+
+            if self.server_args.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+    def recv_requests(self):
+        recv_reqs = []
+        while True:
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                recv_reqs.append(recv_req)
+            except zmq.ZMQError:
+                break
+        return recv_reqs
+
+
+def start_controller_process(
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = ControllerSingle(server_args, port_args, model_overide_args)
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
+        kill_parent_process()