sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/manager_multi.py
@@ -0,0 +1,191 @@
+"""
+A controller that manages multiple data parallel workers.
+Each data parallel worker can manage multiple tensor parallel workers.
+"""
+
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from enum import Enum, auto
+from typing import Dict
+
+import zmq
+import zmq.asyncio
+
+from sglang.global_config import global_config
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    FlushCacheReq,
+    TokenizedGenerateReqInput,
+)
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+
+
+class LoadBalanceMethod(Enum):
+    ROUND_ROBIN = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, method: str):
+        method = method.upper()
+        try:
+            return cls[method]
+        except KeyError as exc:
+            raise ValueError(f"Invalid load balance method: {method}") from exc
+
+
+class Controller:
+    def __init__(
+        self,
+        load_balance_method: str,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args,
+    ):
+        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        self.server_args = server_args
+        self.port_args = port_args
+
+        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
+            self.round_robin_counter = 0
+
+        self.dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = self.dispatch_lookup[self.load_balance_method]
+
+        # Init communication
+        context = zmq.asyncio.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        # Init status
+        self.recv_reqs = []
+
+        # Start data parallel workers
+        self.workers: Dict[int, DataParallelWorkerThread] = {}
+        tp_size = server_args.tp_size
+
+        def start_dp_worker(i):
+            try:
+                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
+                worker_thread = start_data_parallel_worker(
+                    server_args, port_args, model_overide_args, gpu_ids, i
+                )
+                self.workers[i] = worker_thread
+            except Exception:
+                logger.error(
+                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
+                )
+
+        for i in range(server_args.dp_size):
+            start_dp_worker(i)
+
+        # Parallel launch is slower, probably due to the disk bandwidth limitations.
+        # with ThreadPoolExecutor(server_args.dp_size) as executor:
+        #     executor.map(start_dp_worker, range(server_args.dp_size))
+
+    def have_any_live_worker(self):
+        return any(worker_thread.liveness for worker_thread in self.workers.values())
+
+    def put_req_to_worker(self, worker_id, req):
+        self.workers[worker_id].request_queue.put(req)
+
+    async def round_robin_scheduler(self, input_requests):
+        available_workers = list(self.workers.keys())
+        for r in input_requests:
+            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                available_workers
+            )
+        return
+
+    async def shortest_queue_scheduler(self, input_requests):
+        for r in input_requests:
+            worker = min(
+                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
+            )
+            self.put_req_to_worker(worker, r)
+        return
+
+    async def remove_dead_workers(self):
+        for i in list(self.workers.keys()):
+            worker_thread = self.workers[i]
+            if not worker_thread.liveness:
+                worker_thread.join()
+                # move unsuccessful requests back to the queue
+                while not worker_thread.request_queue.empty():
+                    self.recv_reqs.append(worker_thread.request_queue.get())
+                del self.workers[i]
+                logger.info(f"Stale worker {i} removed")
+
+    async def loop_for_forward(self):
+        while True:
+            await self.remove_dead_workers()
+
+            if self.have_any_live_worker():
+                next_step_input = list(self.recv_reqs)
+                self.recv_reqs = []
+                if next_step_input:
+                    await self.dispatching(next_step_input)
+            #else:
+            #    logger.error("There is no live worker.")
+
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def loop_for_recv_requests(self):
+        while True:
+            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            if isinstance(recv_req, FlushCacheReq):
+                # TODO(lsyin): apply more specific flushCacheReq
+                for worker_thread in self.workers.values():
+                    worker_thread.request_queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                self.recv_reqs.append(recv_req)
+            elif isinstance(recv_req, AbortReq):
+                in_queue = False
+                for i, req in enumerate(self.recv_reqs):
+                    if req.rid == recv_req.rid:
+                        self.recv_reqs[i] = recv_req
+                        in_queue = True
+                        break
+                if not in_queue:
+                    # Send abort req to all TP groups
+                    for worker in list(self.workers.keys()):
+                        self.put_req_to_worker(worker, recv_req)
+            else:
+                logger.error(f"Invalid object: {recv_req}")
+
+
+def start_controller_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer,
+    model_overide_args=None,
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = Controller(
+            server_args.load_balance_method, server_args, port_args, model_overide_args
+        )
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+    loop = asyncio.get_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.create_task(controller.loop_for_recv_requests())
+    loop.run_until_complete(controller.loop_for_forward())
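As a quick illustration of the new dispatch selection, here is a minimal standalone sketch (not part of the diff) that reproduces the LoadBalanceMethod enum from the file above and shows how from_str maps the server argument string onto a scheduler choice; the print calls and the "random" example are illustrative only.

# Standalone sketch: how the new controller picks a dispatch policy.
from enum import Enum, auto

class LoadBalanceMethod(Enum):
    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, method: str):
        try:
            return cls[method.upper()]
        except KeyError as exc:
            raise ValueError(f"Invalid load balance method: {method}") from exc

print(LoadBalanceMethod.from_str("round_robin"))     # LoadBalanceMethod.ROUND_ROBIN
print(LoadBalanceMethod.from_str("shortest_queue"))  # LoadBalanceMethod.SHORTEST_QUEUE
# LoadBalanceMethod.from_str("random") raises ValueError("Invalid load balance method: random")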
sglang/srt/managers/controller/manager_single.py
@@ -0,0 +1,97 @@
+"""A controller that manages a group of tensor parallel workers."""
+import asyncio
+import logging
+import time
+
+import uvloop
+import zmq
+import zmq.asyncio
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+logger = logging.getLogger("srt.controller")
+
+
+class ControllerSingle:
+    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
+        # Init communication
+        context = zmq.asyncio.Context(2)
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(
+            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
+        )
+
+        # Init status
+        self.model_client = model_client
+        self.recv_reqs = []
+
+        # Init some configs
+        self.request_dependency_delay = global_config.request_dependency_delay
+
+    async def loop_for_forward(self):
+        while True:
+            next_step_input = list(self.recv_reqs)
+            self.recv_reqs = []
+            out_pyobjs = await self.model_client.step(next_step_input)
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+            # async sleep for receiving the subsequent request and avoiding cache miss
+            slept = False
+            if len(out_pyobjs) != 0:
+                has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+                if has_finished:
+                    if self.request_dependency_delay > 0:
+                        slept = True
+                        await asyncio.sleep(self.request_dependency_delay)
+
+            if not slept:
+                await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def loop_for_recv_requests(self):
+        while True:
+            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            self.recv_reqs.append(recv_req)
+
+
+def start_controller_process(
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        model_client = ModelTpClient(
+            list(range(server_args.tp_size)),
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        controller = ControllerSingle(model_client, port_args)
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.create_task(controller.loop_for_recv_requests())
+    try:
+        loop.run_until_complete(controller.loop_for_forward())
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        kill_parent_process()
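Both new controllers share the same two-coroutine structure: loop_for_recv_requests buffers incoming requests, while loop_for_forward repeatedly drains that buffer and forwards the whole batch (to the dispatcher in the multi-worker controller, to model_client.step in the single-worker one). Below is a minimal standalone sketch of that pattern; the names, stand-in data source, and sleep durations are illustrative only and are not part of the diff.

# Standalone sketch of the receive/forward loop pattern used by both controllers.
import asyncio

recv_reqs = []

async def loop_for_recv_requests(source):
    # Stands in for: recv_req = await recv_from_tokenizer.recv_pyobj()
    for req in source:
        recv_reqs.append(req)
        await asyncio.sleep(0.001)

async def loop_for_forward(ticks=5):
    for _ in range(ticks):
        # Drain the buffer and forward the batch in one step.
        batch, recv_reqs[:] = list(recv_reqs), []
        if batch:
            print("dispatching", batch)   # stands in for dispatching()/model_client.step()
        await asyncio.sleep(0.002)        # analogue of wait_for_new_request_delay

async def main():
    await asyncio.gather(loop_for_recv_requests(range(6)), loop_for_forward())

asyncio.run(main())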